|
@@ -0,0 +1,133 @@
|
|
1
|
+package kr.co.swh.lecture.opensource.deepleaning4j;
|
|
2
|
+
|
|
3
|
+import org.apache.log4j.BasicConfigurator;
|
|
4
|
+import org.datavec.api.records.reader.RecordReader;
|
|
5
|
+import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
|
|
6
|
+import org.datavec.api.split.FileSplit;
|
|
7
|
+import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
|
|
8
|
+import org.deeplearning4j.eval.Evaluation;
|
|
9
|
+import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
|
|
10
|
+import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
|
11
|
+import org.deeplearning4j.nn.conf.layers.DenseLayer;
|
|
12
|
+import org.deeplearning4j.nn.conf.layers.OutputLayer;
|
|
13
|
+import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
|
14
|
+import org.deeplearning4j.nn.weights.WeightInit;
|
|
15
|
+import org.nd4j.common.io.ClassPathResource;
|
|
16
|
+import org.nd4j.linalg.activations.Activation;
|
|
17
|
+import org.nd4j.linalg.api.ndarray.INDArray;
|
|
18
|
+import org.nd4j.linalg.dataset.DataSet;
|
|
19
|
+import org.nd4j.linalg.dataset.SplitTestAndTrain;
|
|
20
|
+import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
|
|
21
|
+import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization;
|
|
22
|
+import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize;
|
|
23
|
+import org.nd4j.linalg.learning.config.Nesterovs;
|
|
24
|
+import org.nd4j.linalg.lossfunctions.LossFunctions;
|
|
25
|
+
|
|
26
|
+/**
|
|
27
|
+ * <pre>
|
|
28
|
+ * kr.co.swh.lecture.opensource.deepleaning4j
|
|
29
|
+ * IrisClassification.java
|
|
30
|
+ *
|
|
31
|
+ * 설명 :
|
|
32
|
+ * Iris 꽃 분류 : https://deeplearning4j.konduit.ai/
|
|
33
|
+ * 개/고양이 인식 분류 Keras 기반 CNN 알고리즘: https://chealin93.tistory.com/69
|
|
34
|
+ *
|
|
35
|
+ * 자바 윤곽선 추출 : https://jin-sung.tistory.com/entry/%EC%9D%B4%EB%AF%B8%EC%A7%80-%ED%8C%8C%EC%9D%BC-%EC%9C%A4%EA%B3%BD%EC%84%A0-%EC%B6%94%EC%B6%9C
|
|
36
|
+ * </pre>
|
|
37
|
+ *
|
|
38
|
+ * @since : 2020. 11. 8.
|
|
39
|
+ * @author : tobby48
|
|
40
|
+ * @version : v1.0
|
|
41
|
+ */
|
|
42
|
+public class IrisClassification {
|
|
43
|
+
|
|
44
|
+ private static final int FEATURES_COUNT = 4;
|
|
45
|
+ private static final int CLASSES_COUNT = 3;
|
|
46
|
+
|
|
47
|
+ public static void main(String[] args) {
|
|
48
|
+
|
|
49
|
+ BasicConfigurator.configure();
|
|
50
|
+ loadData();
|
|
51
|
+
|
|
52
|
+ }
|
|
53
|
+
|
|
54
|
+ private static void loadData() {
|
|
55
|
+ try{
|
|
56
|
+ // csv 파일을 Read
|
|
57
|
+ RecordReader recordReader = new CSVRecordReader(0,',');
|
|
58
|
+ recordReader.initialize(new FileSplit(
|
|
59
|
+ // class패스에 추가 되어있는 경로로부터 Read (Build Path에 추가되어 있는 경로)
|
|
60
|
+ new ClassPathResource("ml/iris.csv").getFile()
|
|
61
|
+ ));
|
|
62
|
+
|
|
63
|
+ // 데이터를 분류하는 기준이 되는 Feature 갯수
|
|
64
|
+ // 데이터를 분류 갯수 (데이터에 분류를 위한 라벨링이 되어 있어야 함)
|
|
65
|
+ // 특정 시드(날짜 정보를 통해 일정한 패턴을 가진 난수) 을 통해 데이터를 셔플
|
|
66
|
+ DataSetIterator iterator = new RecordReaderDataSetIterator(recordReader, 150, FEATURES_COUNT, CLASSES_COUNT);
|
|
67
|
+ DataSet allData = iterator.next();
|
|
68
|
+ allData.shuffle(123);
|
|
69
|
+
|
|
70
|
+ // 정규화(Normalization) : 머신러닝에서 Input되는 데이터는 최초 정규화를 거치게 된다.
|
|
71
|
+ // 정규분포
|
|
72
|
+ // ex. 우리나라 여성 평균 외모 점수(만약 그런 것이 존재한다면)가 50점이라면, 당연히 50점 부근인 사람이 가장 많을 것이다. 80점 이상인 사람은 그만큼 적을 것이다.
|
|
73
|
+ // 그런데도 불구하고 어떻게든 80점 이상의 여자와 결혼하려고 한다면 그만큼 결혼확률이 낮아지는 것이다.
|
|
74
|
+ // 종 모양
|
|
75
|
+ // 사용 이유:
|
|
76
|
+ // 1. 표준화 된 입력을 통해 Gradient Descent 및 Bayesian estimation을 보다 편리하게 수행
|
|
77
|
+ // 2. 0~1로 표준화
|
|
78
|
+ // Gradient Descent : 깊은 골짜기를 찾고 싶을 때에는 가장 가파른 내리막 방향으로 산을 내려가면 될 것. 미분을 통해 계산
|
|
79
|
+ // Bayesian estimation : 자신이 생각한 시나리오나 지식이 실제 정보를 통해 Update되고 설득되어지는 알고리즘.
|
|
80
|
+ DataNormalization normalizer = new NormalizerStandardize();
|
|
81
|
+ normalizer.fit(allData);
|
|
82
|
+ normalizer.transform(allData);
|
|
83
|
+
|
|
84
|
+ // 데이터 분할 : 교육용 65%(0.65)와 테스트 용 나머지 35%(0.35)
|
|
85
|
+ SplitTestAndTrain testAndTrain = allData.splitTestAndTrain(0.65);
|
|
86
|
+ DataSet trainingData = testAndTrain.getTrain();
|
|
87
|
+ DataSet testingData = testAndTrain.getTest();
|
|
88
|
+
|
|
89
|
+ irisNNetwork(trainingData, testingData);
|
|
90
|
+
|
|
91
|
+ } catch (Exception e) {
|
|
92
|
+ Thread.dumpStack();
|
|
93
|
+ new Exception("Stack trace").printStackTrace();
|
|
94
|
+ System.out.println("Error: " + e.getLocalizedMessage());
|
|
95
|
+ }
|
|
96
|
+ }
|
|
97
|
+
|
|
98
|
+ private static void irisNNetwork(DataSet trainingData, DataSet testData) {
|
|
99
|
+
|
|
100
|
+ // 뉴런 신경망
|
|
101
|
+ // https://miro.medium.com/max/625/1*VBRB-_ukJfaZ3HHN1CgJCg.png
|
|
102
|
+ // https://youtu.be/bfmFfD2RIcg
|
|
103
|
+ MultiLayerConfiguration configuration = new NeuralNetConfiguration.Builder()
|
|
104
|
+ .activation(Activation.TANH)
|
|
105
|
+ .weightInit(WeightInit.XAVIER)
|
|
106
|
+ .updater(new Nesterovs(0.1, 0.9))
|
|
107
|
+ .l2(0.0001)
|
|
108
|
+ .list()
|
|
109
|
+ .layer(0, new DenseLayer.Builder().nIn(FEATURES_COUNT).nOut(3).build())
|
|
110
|
+ .layer(1, new DenseLayer.Builder().nIn(3).nOut(3).build())
|
|
111
|
+ .layer(2, new OutputLayer.Builder(
|
|
112
|
+ LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD).activation(Activation.SOFTMAX)
|
|
113
|
+ .nIn(3).nOut(CLASSES_COUNT).build())
|
|
114
|
+// .backprop(true).pretrain(false)
|
|
115
|
+ .build();
|
|
116
|
+
|
|
117
|
+ // 신경망 모델 생성 (훈련 데이터을 통해)
|
|
118
|
+ MultiLayerNetwork model = new MultiLayerNetwork(configuration);
|
|
119
|
+ model.init();
|
|
120
|
+ model.fit(trainingData);
|
|
121
|
+
|
|
122
|
+ // 훈련된 모델을 테스트 데이터를 통해 평가
|
|
123
|
+ INDArray output = model.output(testData.getFeatures());
|
|
124
|
+ Evaluation eval = new Evaluation(3);
|
|
125
|
+ eval.eval(testData.getLabels(), output);
|
|
126
|
+ System.out.println(eval.stats());
|
|
127
|
+
|
|
128
|
+// Accuracy: 정확도
|
|
129
|
+// Precision: 정밀도
|
|
130
|
+// Recall: 재현율
|
|
131
|
+// F1 Score: 테스트의 정밀도와 재현율에서 계산되는 정확도를 나타내는 척도
|
|
132
|
+ }
|
|
133
|
+}
|