Skip to content

Commit 84a81dc

Browse files
committed
Converting the Scikit-learn model trained to the Core ML Format.
1 parent 4b7b849 commit 84a81dc

File tree

6 files changed

+47
-3
lines changed

6 files changed

+47
-3
lines changed
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
# Author/ Learner: Nguyen Truong Thinh
2+
# Contact me: [email protected] || +84393280504
3+
#
4+
# Use case: Create a Decision Tree classification model that can be used to convert into the Core ML Format
5+
# via CoreML Tool .
6+
# The model will be trained & converted on the popular UCI ML Iris flowers dataset.
7+
8+
import coremltools
9+
from iris_flowers_decision_tree_model import model_trained
10+
11+
# Export the model to Core ML format
12+
coreml_model = coremltools.converters.sklearn.convert(model_trained, ["sepal length (cm)", "sepal width (cm)", "petal length (cm)", "petal width (cm)"], "target")
13+
coreml_model.author = 'Nguyen Truong Thinh'
14+
coreml_model.short_description = 'A Decision Tree model trained on the Iris flowers dataset.'
15+
# Features descriptions
16+
coreml_model.input_description['sepal length (cm)'] = 'Sepal length in cm.'
17+
coreml_model.input_description['sepal width (cm)'] = 'Sepal width in cm.'
18+
coreml_model.input_description['petal length (cm)'] = 'Petal length in cm.'
19+
coreml_model.input_description['petal width (cm)'] = 'Petal width in cm.'
20+
# Description of target variable
21+
coreml_model.output_description['target'] = 'A categorical value value, 0 = Iris-Setosa, 1 = Iris-Versicolour, 3 = Iris-Virginica'
22+
coreml_model.save("iris_flowers_dtm.mlpackage")
23+
24+
25+

usecases/tabular_classifier/decision_tree_classifer/converting_skitlearn_model_coreml.py

Whitespace-only changes.

usecases/tabular_classifier/decision_tree_classifer/iris_flowers_decision_tree_model.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,15 @@
1+
# Author/ Learner: Nguyen Truong Thinh
2+
# Contact me: [email protected] || +84393280504
3+
#
4+
# Use case: Create a Decision Tree classification model that can be used to convert into the Core ML Format
5+
# via CoreML Tool .
6+
# The model will be trained & converted on the popular UCI ML Iris flowers dataset.
7+
18
import numpy as np
29
import matplotlib.pyplot as plt
310
import pandas as pd
411
import pydotplus
5-
12+
# Scikit-learn: 1.1.2
613
from sklearn import datasets
714
from sklearn.metrics import accuracy_score
815
from sklearn.model_selection import train_test_split
@@ -44,7 +51,7 @@
4451
# https:/scikit-learn.org/stable/modules;generated/sklearn.tree.DecisionTreeClassifier.html
4552
# Train a DTM
4653
model = DecisionTreeClassifier(random_state=17)
47-
model.fit(df_iris_features_train, df_iris_target_train.values.ravel())
54+
model_trained = model.fit(df_iris_features_train, df_iris_target_train.values.ravel())
4855
print(model.feature_importances_)
4956
# Get predictions from model, and compute accuracy
5057
predictions = model.predict(df_iris_features_test)
@@ -56,7 +63,7 @@
5663
export_graphviz(model, out_file=dot_data, filled=True, rounded=True,
5764
special_characters=True, feature_names=df_iris_features.columns)
5865
graph = pydotplus.graph_from_dot_data(dot_data.getvalue())
59-
graph.write_png("iris_flow_dtm.png")
66+
graph.write_png("iris_flowers_dtm.png")
6067
Image(graph.create_png())
6168

6269

963 Bytes
Binary file not shown.
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
{
2+
"fileFormatVersion": "1.0.0",
3+
"itemInfoEntries": {
4+
"7360340A-F6CF-46AC-99E4-3E2868896A7B": {
5+
"author": "com.apple.CoreML",
6+
"description": "CoreML Model Specification",
7+
"name": "model.mlmodel",
8+
"path": "com.apple.CoreML/model.mlmodel"
9+
}
10+
},
11+
"rootModelIdentifier": "7360340A-F6CF-46AC-99E4-3E2868896A7B"
12+
}
File renamed without changes.

0 commit comments

Comments
 (0)