tobby48 4 years ago
parent
commit
f833d0f6c7

+ 30
- 0
pom.xml View File

@@ -334,6 +334,36 @@
334 334
 		    <version>4.5.3</version>
335 335
 		</dependency>
336 336
 		
337
		<!-- https://mvnrepository.com/artifact/org.deeplearning4j/deeplearning4j-core -->
		<dependency>
		    <groupId>org.deeplearning4j</groupId>
		    <artifactId>deeplearning4j-core</artifactId>
		    <version>1.0.0-beta7</version>
		</dependency>

		<!-- CPU backend for ND4J; its version must stay in lockstep with
		     deeplearning4j-core above. -->
		<dependency>
			<groupId>org.nd4j</groupId>
			<artifactId>nd4j-native-platform</artifactId>
			<version>1.0.0-beta7</version>
		</dependency>
		<!-- NOTE(review): cassandra-all 1.1.4 is a 2012-era release — confirm this
		     very old version is intentional. The log4j artifacts are excluded,
		     presumably to avoid clashing with the project's own logging setup;
		     verify against the rest of the pom. -->
		<dependency>
		    <groupId>org.apache.cassandra</groupId>
		    <artifactId>cassandra-all</artifactId>
		    <version>1.1.4</version>
		    <exclusions>
		        <exclusion>
		            <groupId>org.slf4j</groupId>
		            <artifactId>slf4j-log4j12</artifactId>
		        </exclusion>
		        <exclusion>
		            <groupId>log4j</groupId>
		            <artifactId>log4j</artifactId>
		        </exclusion>
		    </exclusions>

		</dependency>

337 367
 	</dependencies>
338 368
 	
339 369
 	<build>

+ 133
- 0
src/main/java/kr/co/swh/lecture/opensource/deepleaning4j/IrisClassification.java View File

@@ -0,0 +1,133 @@
1
+package kr.co.swh.lecture.opensource.deepleaning4j; 
2
+
3
+import org.apache.log4j.BasicConfigurator;
4
+import org.datavec.api.records.reader.RecordReader;
5
+import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
6
+import org.datavec.api.split.FileSplit;
7
+import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
8
+import org.deeplearning4j.eval.Evaluation;
9
+import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
10
+import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
11
+import org.deeplearning4j.nn.conf.layers.DenseLayer;
12
+import org.deeplearning4j.nn.conf.layers.OutputLayer;
13
+import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
14
+import org.deeplearning4j.nn.weights.WeightInit;
15
+import org.nd4j.common.io.ClassPathResource;
16
+import org.nd4j.linalg.activations.Activation;
17
+import org.nd4j.linalg.api.ndarray.INDArray;
18
+import org.nd4j.linalg.dataset.DataSet;
19
+import org.nd4j.linalg.dataset.SplitTestAndTrain;
20
+import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
21
+import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization;
22
+import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize;
23
+import org.nd4j.linalg.learning.config.Nesterovs;
24
+import org.nd4j.linalg.lossfunctions.LossFunctions;
25
+
26
+/**
27
+ * <pre>
28
+ * kr.co.swh.lecture.opensource.deepleaning4j 
29
+ * IrisClassification.java
30
+ *
31
+ * 설명 :
32
+ * 	Iris 꽃 분류 : https://deeplearning4j.konduit.ai/
33
+ * 	개/고양이 인식 분류 Keras 기반 CNN 알고리즘: https://chealin93.tistory.com/69
34
+ * 
35
+ * 	자바 윤곽선 추출 : https://jin-sung.tistory.com/entry/%EC%9D%B4%EB%AF%B8%EC%A7%80-%ED%8C%8C%EC%9D%BC-%EC%9C%A4%EA%B3%BD%EC%84%A0-%EC%B6%94%EC%B6%9C
36
+ * </pre>
37
+ * 
38
+ * @since : 2020. 11. 8.
39
+ * @author : tobby48
40
+ * @version : v1.0
41
+ */
42
+public class IrisClassification {
43
+
44
+    private static final int FEATURES_COUNT = 4;
45
+    private static final int CLASSES_COUNT = 3;
46
+
47
+    public static void main(String[] args) {
48
+
49
+        BasicConfigurator.configure();
50
+        loadData();
51
+
52
+    }
53
+
54
+    private static void loadData() {
55
+        try{
56
+        	//	csv 파일을 Read
57
+        	RecordReader recordReader = new CSVRecordReader(0,',');
58
+            recordReader.initialize(new FileSplit(
59
+            		//	class패스에 추가 되어있는 경로로부터 Read (Build Path에 추가되어 있는 경로)
60
+                    new ClassPathResource("ml/iris.csv").getFile()
61
+            ));
62
+
63
+            //	데이터를 분류하는 기준이 되는 Feature 갯수
64
+            //	데이터를 분류 갯수 (데이터에 분류를 위한 라벨링이 되어 있어야 함)
65
+            //	특정 시드(날짜 정보를 통해 일정한 패턴을 가진 난수) 을 통해 데이터를 셔플
66
+            DataSetIterator iterator = new RecordReaderDataSetIterator(recordReader, 150, FEATURES_COUNT, CLASSES_COUNT);
67
+            DataSet allData = iterator.next();
68
+            allData.shuffle(123);
69
+
70
+            //	정규화(Normalization) : 머신러닝에서 Input되는 데이터는 최초 정규화를 거치게 된다.
71
+            //	정규분포
72
+            //	ex. 우리나라 여성 평균 외모 점수(만약 그런 것이 존재한다면)가 50점이라면, 당연히 50점 부근인 사람이 가장 많을 것이다. 80점 이상인 사람은 그만큼 적을 것이다. 
73
+            //	그런데도 불구하고 어떻게든 80점 이상의 여자와 결혼하려고 한다면 그만큼 결혼확률이 낮아지는 것이다.
74
+            //	종 모양
75
+            //	사용 이유: 
76
+            //		1. 표준화 된 입력을 통해 Gradient Descent 및 Bayesian estimation을 보다 편리하게 수행
77
+            //		2. 0~1로 표준화
78
+            //	Gradient Descent : 깊은 골짜기를 찾고 싶을 때에는 가장 가파른 내리막 방향으로 산을 내려가면 될 것. 미분을 통해 계산
79
+            //	Bayesian estimation : 자신이 생각한 시나리오나 지식이 실제 정보를 통해 Update되고 설득되어지는 알고리즘.
80
+            DataNormalization normalizer = new NormalizerStandardize();
81
+            normalizer.fit(allData);
82
+            normalizer.transform(allData);
83
+
84
+            //	데이터 분할 : 교육용 65%(0.65)와 테스트 용 나머지 35%(0.35)
85
+            SplitTestAndTrain testAndTrain = allData.splitTestAndTrain(0.65);
86
+            DataSet trainingData = testAndTrain.getTrain();
87
+            DataSet testingData = testAndTrain.getTest();
88
+
89
+            irisNNetwork(trainingData, testingData);
90
+
91
+        } catch (Exception e) {
92
+            Thread.dumpStack();
93
+            new Exception("Stack trace").printStackTrace();
94
+            System.out.println("Error: " + e.getLocalizedMessage());
95
+        }
96
+    }
97
+
98
+    private static void irisNNetwork(DataSet trainingData, DataSet testData) {
99
+
100
+    	//	뉴런 신경망
101
+    	//	https://miro.medium.com/max/625/1*VBRB-_ukJfaZ3HHN1CgJCg.png
102
+    	//	https://youtu.be/bfmFfD2RIcg
103
+        MultiLayerConfiguration configuration = new NeuralNetConfiguration.Builder()
104
+                .activation(Activation.TANH)
105
+                .weightInit(WeightInit.XAVIER)
106
+                .updater(new Nesterovs(0.1, 0.9))
107
+                .l2(0.0001)
108
+                .list()
109
+                .layer(0, new DenseLayer.Builder().nIn(FEATURES_COUNT).nOut(3).build())
110
+                .layer(1, new DenseLayer.Builder().nIn(3).nOut(3).build())
111
+                .layer(2, new OutputLayer.Builder(
112
+                        LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD).activation(Activation.SOFTMAX)
113
+                        .nIn(3).nOut(CLASSES_COUNT).build())
114
+//                .backprop(true).pretrain(false)
115
+                .build();
116
+
117
+        //	신경망 모델 생성 (훈련 데이터을 통해)
118
+        MultiLayerNetwork model = new MultiLayerNetwork(configuration);
119
+        model.init();
120
+        model.fit(trainingData);
121
+
122
+        //	훈련된 모델을 테스트 데이터를 통해 평가
123
+        INDArray output = model.output(testData.getFeatures());
124
+        Evaluation eval = new Evaluation(3);
125
+        eval.eval(testData.getLabels(), output);
126
+        System.out.println(eval.stats());
127
+
128
+//        Accuracy:        정확도
129
+//        Precision:       정밀도
130
+//        Recall:          재현율
131
+//        F1 Score:        테스트의 정밀도와 재현율에서 계산되는 정확도를 나타내는 척도
132
+    }
133
+}

+ 150
- 0
src/main/resources/ml/iris.csv View File

@@ -0,0 +1,150 @@
1
+5.1,3.5,1.4,0.2,0
2
+4.9,3,1.4,0.2,0
3
+4.7,3.2,1.3,0.2,0
4
+4.6,3.1,1.5,0.2,0
5
+5,3.6,1.4,0.2,0
6
+5.4,3.9,1.7,0.4,0
7
+4.6,3.4,1.4,0.3,0
8
+5,3.4,1.5,0.2,0
9
+4.4,2.9,1.4,0.2,0
10
+4.9,3.1,1.5,0.1,0
11
+5.4,3.7,1.5,0.2,0
12
+4.8,3.4,1.6,0.2,0
13
+4.8,3,1.4,0.1,0
14
+4.3,3,1.1,0.1,0
15
+5.8,4,1.2,0.2,0
16
+5.7,4.4,1.5,0.4,0
17
+5.4,3.9,1.3,0.4,0
18
+5.1,3.5,1.4,0.3,0
19
+5.7,3.8,1.7,0.3,0
20
+5.1,3.8,1.5,0.3,0
21
+5.4,3.4,1.7,0.2,0
22
+5.1,3.7,1.5,0.4,0
23
+4.6,3.6,1,0.2,0
24
+5.1,3.3,1.7,0.5,0
25
+4.8,3.4,1.9,0.2,0
26
+5,3,1.6,0.2,0
27
+5,3.4,1.6,0.4,0
28
+5.2,3.5,1.5,0.2,0
29
+5.2,3.4,1.4,0.2,0
30
+4.7,3.2,1.6,0.2,0
31
+4.8,3.1,1.6,0.2,0
32
+5.4,3.4,1.5,0.4,0
33
+5.2,4.1,1.5,0.1,0
34
+5.5,4.2,1.4,0.2,0
35
+4.9,3.1,1.5,0.1,0
36
+5,3.2,1.2,0.2,0
37
+5.5,3.5,1.3,0.2,0
38
+4.9,3.1,1.5,0.1,0
39
+4.4,3,1.3,0.2,0
40
+5.1,3.4,1.5,0.2,0
41
+5,3.5,1.3,0.3,0
42
+4.5,2.3,1.3,0.3,0
43
+4.4,3.2,1.3,0.2,0
44
+5,3.5,1.6,0.6,0
45
+5.1,3.8,1.9,0.4,0
46
+4.8,3,1.4,0.3,0
47
+5.1,3.8,1.6,0.2,0
48
+4.6,3.2,1.4,0.2,0
49
+5.3,3.7,1.5,0.2,0
50
+5,3.3,1.4,0.2,0
51
+7,3.2,4.7,1.4,1
52
+6.4,3.2,4.5,1.5,1
53
+6.9,3.1,4.9,1.5,1
54
+5.5,2.3,4,1.3,1
55
+6.5,2.8,4.6,1.5,1
56
+5.7,2.8,4.5,1.3,1
57
+6.3,3.3,4.7,1.6,1
58
+4.9,2.4,3.3,1,1
59
+6.6,2.9,4.6,1.3,1
60
+5.2,2.7,3.9,1.4,1
61
+5,2,3.5,1,1
62
+5.9,3,4.2,1.5,1
63
+6,2.2,4,1,1
64
+6.1,2.9,4.7,1.4,1
65
+5.6,2.9,3.6,1.3,1
66
+6.7,3.1,4.4,1.4,1
67
+5.6,3,4.5,1.5,1
68
+5.8,2.7,4.1,1,1
69
+6.2,2.2,4.5,1.5,1
70
+5.6,2.5,3.9,1.1,1
71
+5.9,3.2,4.8,1.8,1
72
+6.1,2.8,4,1.3,1
73
+6.3,2.5,4.9,1.5,1
74
+6.1,2.8,4.7,1.2,1
75
+6.4,2.9,4.3,1.3,1
76
+6.6,3,4.4,1.4,1
77
+6.8,2.8,4.8,1.4,1
78
+6.7,3,5,1.7,1
79
+6,2.9,4.5,1.5,1
80
+5.7,2.6,3.5,1,1
81
+5.5,2.4,3.8,1.1,1
82
+5.5,2.4,3.7,1,1
83
+5.8,2.7,3.9,1.2,1
84
+6,2.7,5.1,1.6,1
85
+5.4,3,4.5,1.5,1
86
+6,3.4,4.5,1.6,1
87
+6.7,3.1,4.7,1.5,1
88
+6.3,2.3,4.4,1.3,1
89
+5.6,3,4.1,1.3,1
90
+5.5,2.5,4,1.3,1
91
+5.5,2.6,4.4,1.2,1
92
+6.1,3,4.6,1.4,1
93
+5.8,2.6,4,1.2,1
94
+5,2.3,3.3,1,1
95
+5.6,2.7,4.2,1.3,1
96
+5.7,3,4.2,1.2,1
97
+5.7,2.9,4.2,1.3,1
98
+6.2,2.9,4.3,1.3,1
99
+5.1,2.5,3,1.1,1
100
+5.7,2.8,4.1,1.3,1
101
+6.3,3.3,6,2.5,2
102
+5.8,2.7,5.1,1.9,2
103
+7.1,3,5.9,2.1,2
104
+6.3,2.9,5.6,1.8,2
105
+6.5,3,5.8,2.2,2
106
+7.6,3,6.6,2.1,2
107
+4.9,2.5,4.5,1.7,2
108
+7.3,2.9,6.3,1.8,2
109
+6.7,2.5,5.8,1.8,2
110
+7.2,3.6,6.1,2.5,2
111
+6.5,3.2,5.1,2,2
112
+6.4,2.7,5.3,1.9,2
113
+6.8,3,5.5,2.1,2
114
+5.7,2.5,5,2,2
115
+5.8,2.8,5.1,2.4,2
116
+6.4,3.2,5.3,2.3,2
117
+6.5,3,5.5,1.8,2
118
+7.7,3.8,6.7,2.2,2
119
+7.7,2.6,6.9,2.3,2
120
+6,2.2,5,1.5,2
121
+6.9,3.2,5.7,2.3,2
122
+5.6,2.8,4.9,2,2
123
+7.7,2.8,6.7,2,2
124
+6.3,2.7,4.9,1.8,2
125
+6.7,3.3,5.7,2.1,2
126
+7.2,3.2,6,1.8,2
127
+6.2,2.8,4.8,1.8,2
128
+6.1,3,4.9,1.8,2
129
+6.4,2.8,5.6,2.1,2
130
+7.2,3,5.8,1.6,2
131
+7.4,2.8,6.1,1.9,2
132
+7.9,3.8,6.4,2,2
133
+6.4,2.8,5.6,2.2,2
134
+6.3,2.8,5.1,1.5,2
135
+6.1,2.6,5.6,1.4,2
136
+7.7,3,6.1,2.3,2
137
+6.3,3.4,5.6,2.4,2
138
+6.4,3.1,5.5,1.8,2
139
+6,3,4.8,1.8,2
140
+6.9,3.1,5.4,2.1,2
141
+6.7,3.1,5.6,2.4,2
142
+6.9,3.1,5.1,2.3,2
143
+5.8,2.7,5.1,1.9,2
144
+6.8,3.2,5.9,2.3,2
145
+6.7,3.3,5.7,2.5,2
146
+6.7,3,5.2,2.3,2
147
+6.3,2.5,5,1.9,2
148
+6.5,3,5.2,2,2
149
+6.2,3.4,5.4,2.3,2
150
+5.9,3,5.1,1.8,2