diff --git a/Project.toml b/Project.toml index 1113998..b00f6a5 100644 --- a/Project.toml +++ b/Project.toml @@ -4,6 +4,7 @@ authors = ["Beacon Biosignals, Inc."] version = "0.3.6" [deps] +CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597" MLJModelInterface = "e80e1ace-859a-464e-9ed9-23947d8ae3ea" OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" PythonCall = "6099a3de-0909-46bc-b1f4-468b9a2dfc0d" @@ -11,6 +12,7 @@ Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" [compat] Aqua = "0.8.4" +CategoricalArrays = "0.10.8" DataFrames = "1.6" MLJBase = "1" MLJModelInterface = "1.7" diff --git a/src/MLJCatBoostInterface.jl b/src/MLJCatBoostInterface.jl index cbf4659..c3d7578 100644 --- a/src/MLJCatBoostInterface.jl +++ b/src/MLJCatBoostInterface.jl @@ -9,6 +9,7 @@ using Tables using MLJModelInterface: MLJModelInterface const MMI = MLJModelInterface using MLJModelInterface: Table, Continuous, Count, Finite, OrderedFactor, Multiclass +using CategoricalArrays: CategoricalArray, CategoricalValue const PKG = "CatBoost" """ diff --git a/src/mlj_catboostclassifier.jl b/src/mlj_catboostclassifier.jl index baad80b..0b21db7 100644 --- a/src/mlj_catboostclassifier.jl +++ b/src/mlj_catboostclassifier.jl @@ -72,8 +72,12 @@ function model_init(mlj_model::CatBoostClassifier; kw...) return catboost.CatBoostClassifier(; mlj_to_kwargs(mlj_model)..., kw...) 
end +prepare_single_class(y_first) = CategoricalArray([y_first])[1] +prepare_single_class(y_first::CategoricalValue) = y_first + function MMI.fit(mlj_model::CatBoostClassifier, verbosity::Int, data_pool, y_first) - # Check if y_first has only one unique value + # Check if the data pool's labels contain only one unique class + y_first = prepare_single_class(y_first) unique_classes = pyconvert(Vector, numpy.unique(data_pool.get_label())) if length(unique_classes) == 1 # Skip training and store the single class @@ -103,8 +107,7 @@ function MMI.predict(mlj_model::CatBoostClassifier, fitresult, X_pool) n = pyconvert(Int, X_pool.shape[0]) classes = [fitresult.single_class] probs = ones(n, 1) - pool = MMI.categorical([fitresult.y_first]) - return MMI.UnivariateFinite(classes, probs; pool=pool) + return MMI.UnivariateFinite(classes, probs; pool=fitresult.y_first) end model, y_first = fitresult