-
Notifications
You must be signed in to change notification settings - Fork 10
/
convert_classifier.py
52 lines (41 loc) · 1.6 KB
/
convert_classifier.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import pickle
import joblib
import os
from pure_sklearn.map import convert_estimator
def load_classifier_SIMBA(path_to_sav):
    """Load a classifier that was saved with ``pickle``.

    Args:
        path_to_sav: Path to the pickled classifier file.

    Returns:
        The object stored in the file, exactly as ``pickle.load`` returns it.

    Raises:
        OSError: If the file cannot be opened.
        Exception: Whatever ``pickle.load`` raises for unreadable content
            (e.g. ``EOFError`` for an empty file).
    """
    # NOTE(security): unpickling can execute arbitrary code; only load files
    # from a trusted source.
    # The context manager closes the handle even when unpickling fails, which
    # the previous manual open()/close() pair did not.
    with open(path_to_sav, "rb") as file:
        return pickle.load(file)
def load_classifier_BSOID(path_to_sav):
    """Load a classifier file that was saved with ``joblib``.

    Args:
        path_to_sav: Path to the joblib file.

    Returns:
        The object stored in the file, exactly as ``joblib.load`` returns it.

    Raises:
        OSError: If the file cannot be opened.
    """
    # NOTE(security): joblib files are pickle-based; only load files from a
    # trusted source.
    # The context manager closes the handle even when loading fails, which
    # the previous manual open()/close() pair did not.
    with open(path_to_sav, "rb") as file:
        return joblib.load(file)
def convert_classifier(path, origin: str):
    """Convert a saved classifier to a pure-python estimator and save it.

    The result is written next to the input file as ``<name>_pure.sav``,
    where ``<name>`` is the input file name without its last extension.

    Args:
        path: Path to the saved classifier file.
        origin: Where the classifier comes from, ``'simba'`` or ``'bsoid'``
            (case-insensitive). It selects both the loader and the
            serializer (pickle for simba, joblib for bsoid).

    Raises:
        ValueError: If ``origin`` is neither ``'simba'`` nor ``'bsoid'``.
    """
    dir_path = os.path.dirname(path)
    # splitext removes only the last extension, so names with several dots
    # (or none at all) no longer break the two-value unpacking.
    filename, _ = os.path.splitext(os.path.basename(path))
    # os.path.join keeps the output beside the input; concatenating with "/"
    # pointed at the filesystem root when path had no directory part.
    output_path = os.path.join(dir_path, filename + "_pure.sav")
    source = origin.lower()

    print("Loading classifier...")
    if source == 'simba':
        clf = load_classifier_SIMBA(path)
        # convert to pure python estimator
        clf_pure_predict = convert_estimator(clf)
        with open(output_path, "wb") as f:
            pickle.dump(clf_pure_predict, f)
    elif source == 'bsoid':
        clf_pack = load_classifier_BSOID(path)
        # bsoid exported classifier has format [a, b, c, clf, d, e]
        clf_pure_predict = convert_estimator(clf_pack[3])
        # Swap the estimator in place so the rest of the pack is preserved.
        clf_pack[3] = clf_pure_predict
        with open(output_path, "wb") as f:
            joblib.dump(clf_pack, f)
    else:
        raise ValueError(f'{origin} is not a valid classifier origin.')

    print(f"Converted Classifier {filename}")
if __name__ == "__main__":
    # Converted BSOID classifiers are not integrated yet, although they can
    # already be converted here.
    # Replace the placeholder below with the path of the classifier file.
    classifier_path = "PATH_TO_CLASSIFIER"
    convert_classifier(
        classifier_path,
        origin="SIMBA",
    )