Skip to content

Commit bc71e73

Browse files
Add SGDClassifier (#17)
* Add SGDClassifier * Fix complaining linter
1 parent 4505d10 commit bc71e73

File tree

1 file changed

+3
-0
lines changed

1 file changed

+3
-0
lines changed

pureskillgg_dsdk/ds_models/s3_scikit.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,9 @@ def _use_model(self, dataframe):
3434
if self._model_type == "MiniBatchKMeans":
3535
labels = self._loaded_model.predict(dataframe)
3636
return labels
37+
if self._model_type == "SGDClassifier":
38+
labels = self._loaded_model.predict_proba(dataframe)
39+
return labels
3740
raise Exception(f"Unknown model_type {self._model_type}")
3841

3942
def invoke(self, dataframe):

0 commit comments

Comments
 (0)