Skip to content

Commit 236bab4

Browse files
committed
test experimental DJL support
1 parent 3bcb219 commit 236bab4

File tree

1 file changed

+59
-0
lines changed

1 file changed

+59
-0
lines changed
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
/*
2+
* SPDX-License-Identifier: Apache-2.0
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
import org.knowm.xchart.SwingWrapper
17+
import org.knowm.xchart.XYChartBuilder
18+
import weka.classifiers.AbstractClassifier
19+
import weka.core.Utils
20+
import weka.core.WekaPackageManager
21+
import weka.core.converters.CSVLoader
22+
23+
import static org.knowm.xchart.XYSeries.XYSeriesRenderStyle.Scatter
24+
25+
WekaPackageManager.loadPackages(true)
26+
27+
def file = getClass().classLoader.getResource('iris_data.csv').file as File
28+
println file
29+
def species = ['Iris-setosa', 'Iris-versicolor', 'Iris-virginica']
30+
def loader = new CSVLoader(file: file)
31+
def data = loader.dataSet
32+
data.classIndex = 4
33+
34+
def options = Utils.splitOptions("-S 1")
35+
AbstractClassifier classifier = Utils.forName(AbstractClassifier, "weka.classifiers.djl.DJLRegressor", options)
36+
37+
double[] actual = data.collect{ it.value(4) }
38+
double[] predicted = data.collect{ classifier.classifyInstance(it) }
39+
double[] petalW = data.collect{ it.value(2) }
40+
double[] petalL = data.collect{ it.value(3) }
41+
def indices = actual.indices
42+
43+
def chart = new XYChartBuilder().width(900).height(450).
44+
title("Species").xAxisTitle("Petal length").yAxisTitle("Petal width").build()
45+
species.eachWithIndex{ String name, int i ->
46+
def groups = indices.findAll{ predicted[it] == i }.groupBy{ actual[it] == i }
47+
Collection found = groups[true] ?: []
48+
Collection errors = groups[false] ?: []
49+
println "$name: ${found.size()} correct, ${errors.size()} incorrect"
50+
chart.addSeries("$name correct", petalW[found], petalL[found]).with {
51+
XYSeriesRenderStyle = Scatter
52+
}
53+
if (errors) {
54+
chart.addSeries("$name incorrect", petalW[errors], petalL[errors]).with {
55+
XYSeriesRenderStyle = Scatter
56+
}
57+
}
58+
}
59+
new SwingWrapper(chart).displayChart()

0 commit comments

Comments
 (0)