-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathExample1.java
More file actions
75 lines (55 loc) · 2.09 KB
/
Example1.java
File metadata and controls
75 lines (55 loc) · 2.09 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
package applications.ml;
import datastructs.maths.DenseMatrixSet;
import datastructs.maths.RowBuilder;
import datastructs.maths.Vector;
import datastructs.utils.RowType;
import maths.functions.distances.EuclideanVectorCalculator;
import ml.classifiers.KNNClassifier;
import utils.ClassificationVoter;
import java.util.ArrayList;
import java.util.List;
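/** Category: Machine Learning
 * ID: Example1
 * Description: Classification with the vanilla KNN algorithm
 * Taken From:
 * Details:
 * Builds a 12 x 2 data set of points, labels rows 0-5 as class 0 and the
 * remaining rows as class 1, trains a KNNClassifier that uses a
 * EuclideanVectorCalculator and a ClassificationVoter, and prints the
 * predicted class index for the points (3.1, 2.2) and (9.1, 6.2).
 */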
public class Example1 {

    /** Number of leading rows of the data set that are labelled as class 0. */
    private static final int CLASS_ZERO_ROWS = 6;

    /**
     * Runs the example: fills a small two-column data set, labels its rows,
     * trains a KNN classifier on it and prints the predicted class index of
     * two query points to standard output.
     *
     * @param args command-line arguments; not used
     */
    public static void main(String[] args) {

        // Training points, one row per point. Keeping them in a single table
        // lets the data-set dimensions be derived from the data instead of
        // being repeated as separate literals.
        final double[][] points = {
            {1.0, 3.0},
            {1.5, 2.0},
            {2.0, 1.0},
            {2.5, 4.0},
            {3.0, 1.5},
            {3.5, 2.5},
            {5.0, 5.0},
            {5.5, 4.0},
            {6.0, 6.0},
            {6.5, 4.5},
            {7.0, 1.5},
            {8.0, 2.5}
        };

        // Diamond instead of the raw type, so the assignment is no longer an
        // unchecked raw-to-generic conversion.
        DenseMatrixSet<Double> dataSet =
                new DenseMatrixSet<>(RowType.Type.DOUBLE_VECTOR, new RowBuilder());
        dataSet.create(points.length, points[0].length);
        for (int row = 0; row < points.length; ++row) {
            dataSet.set(row, points[row][0], points[row][1]);
        }

        // One label per row: the first CLASS_ZERO_ROWS rows are class 0,
        // every remaining row is class 1.
        List<Integer> labels = new ArrayList<>(dataSet.m());
        for (int i = 0; i < CLASS_ZERO_ROWS; ++i) {
            labels.add(0);
        }
        for (int i = CLASS_ZERO_ROWS; i < dataSet.m(); ++i) {
            labels.add(1);
        }

        // NOTE(review): the constructor arguments (2, false) are presumably the
        // number of neighbours and a boolean option - confirm against
        // KNNClassifier's constructor.
        KNNClassifier<Double, DenseMatrixSet<Double>,
                EuclideanVectorCalculator<Double>,
                ClassificationVoter> classifier = new KNNClassifier<>(2, false);
        classifier.setDistanceCalculator(new EuclideanVectorCalculator<Double>());
        classifier.setMajorityVoter(new ClassificationVoter());
        classifier.train(dataSet, labels);

        // Classify each query point and report the result; the loop replaces
        // two copies of the same predict/print sequence.
        final Vector[] queries = {new Vector(3.1, 2.2), new Vector(9.1, 6.2)};
        for (Vector r : queries) {
            Integer classIdx = classifier.predict(r);
            System.out.println("Point "+r+" has class index "+ classIdx);
        }
    }
}