NOTE(review): this patch was re-flowed from a copy whose newlines and indentation had
been collapsed. Indentation (4 spaces) and blank context lines are reconstructed from
the hunk line counts; the `index` lines are carried over from the original patch and no
longer describe the revised post-images. Check with `git apply --check` before use.

diff --git a/Project.toml b/Project.toml
index e0c7ce8..cdd53f3 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,7 +1,7 @@
 name = "CatBoost"
 uuid = "e2e10f9a-a85d-4fa9-b6b2-639a32100a12"
 authors = ["Beacon Biosignals, Inc."]
-version = "0.3.4"
+version = "0.3.5"
 
 [deps]
 MLJModelInterface = "e80e1ace-859a-464e-9ed9-23947d8ae3ea"
diff --git a/src/mlj_catboostclassifier.jl b/src/mlj_catboostclassifier.jl
index 51e9f5f..baad80b 100644
--- a/src/mlj_catboostclassifier.jl
+++ b/src/mlj_catboostclassifier.jl
@@ -100,10 +100,12 @@ MMI.reports_feature_importances(::Type{<:CatBoostClassifier}) = true
 function MMI.predict(mlj_model::CatBoostClassifier, fitresult, X_pool)
     if fitresult[1] === nothing
         # Always predict the single class
-        n = nrow(X_pool)
+        # X_pool is a Python object here, so read the row count from its `shape`
+        n = pyconvert(Int, X_pool.shape[0])
         classes = [fitresult.single_class]
         probs = ones(n, 1)
-        return MMI.UnivariateFinite(classes, probs; pool=fitresult.y_first)
+        pool = MMI.categorical([fitresult.y_first])
+        return MMI.UnivariateFinite(classes, probs; pool=pool)
     end
 
     model, y_first = fitresult
@@ -116,8 +118,8 @@ end
 function MMI.predict_mode(mlj_model::CatBoostClassifier, fitresult, X_pool)
     if fitresult[1] === nothing
-        # Return probability 1 for the single class
-        n = nrow(X_pool)
-        return hcat(ones(n), zeros(n))
+        # Only one class was seen during training, so it is the mode of every row
+        n = pyconvert(Int, X_pool.shape[0])
+        return fill(fitresult.y_first, n)
     end
 
     model, y_first = fitresult
diff --git a/src/mlj_serialization.jl b/src/mlj_serialization.jl
index b4c4088..b4f5914 100644
--- a/src/mlj_serialization.jl
+++ b/src/mlj_serialization.jl
@@ -1,24 +1,57 @@
 # Taken from https://github.com/JuliaAI/MLJXGBoostInterface.jl
 # It is likely also not the optimal method for serializing models, but it works
+
+# Serialize the Python model object `py_model` to bytes through a temporary file.
+# The do-block form of `mktemp` removes the file afterwards, even if saving throws.
+function _model_to_bytes(py_model)
+    return mktemp() do ctb_file, io
+        # Close our own handle before the Python side writes to the path
+        close(io)
+        py_model.save_model(ctb_file)
+        return read(ctb_file)
+    end
+end
+
+# Write the bytes `persistent` to a temporary file and load them into the Python
+# model object `py_model`, returning the result of `load_model`. The do-block form
+# of `mktemp` removes the file afterwards, even if loading throws.
+function _bytes_to_model(py_model, persistent)
+    return mktemp() do ctb_file, io
+        write(io, persistent)
+        close(io)
+        return py_model.load_model(ctb_file)
+    end
+end
+
 """
-    _persistent(booster)
+    _persistent(model::CatBoostModels, fitresult)
 
 Private method.
 
 Return a persistent (ie, Julia-serializable) representation of the
-CatBoost.jl model `booster`.
+CatBoost.jl `fitresult` of `model`.
 
-Restore the model with [`booster`](@ref)
+For a `CatBoostRegressor` this is the byte content of the saved model. For a
+`CatBoostClassifier` it is `(bytes, y_first)`, or `(nothing, single_class, y_first)`
+when the fitresult holds no model because only one class was present.
+
+Restore the model with [`_booster`](@ref)
 """
-function _persistent(booster)
-    ctb_file, io = mktemp()
-    close(io)
-
-    booster.save_model(ctb_file)
-    persistent_booster = read(ctb_file)
-    rm(ctb_file)
-    return persistent_booster
+function _persistent(::CatBoostRegressor, fitresult)
+    return _model_to_bytes(fitresult)
+end
+function _persistent(::CatBoostClassifier, fitresult)
+    if fitresult[1] === nothing
+        # Case 1: Single unique class. This fitresult has three named fields, so read
+        # them by name; positional destructuring would bind `y_first` to the second
+        # field instead.
+        return (nothing, fitresult.single_class, fitresult.y_first)
+    else
+        # Case 2: Multiple unique classes
+        model, y_first = fitresult
+        return (_model_to_bytes(model), y_first)
+    end
 end
 
 """
-    _booster(persistent)
+    _booster(model::CatBoostModels, persistent)
@@ -28,24 +61,33 @@ Private method.
 Return the CatBoost.jl model which has `persistent` as its persistent
-(Julia-serializable) representation. See [`persistent`](@ref) method.
+(Julia-serializable) representation. See [`_persistent`](@ref) method.
 """
-function _booster(persistent)
-    ctb_file, io = mktemp()
-    write(io, persistent)
-    close(io)
-
-    booster = catboost.CatBoost().load_model(ctb_file)
-
-    rm(ctb_file)
-
-    return booster
+function _booster(::CatBoostRegressor, persistent)
+    return _bytes_to_model(catboost.CatBoostRegressor(), persistent)
+end
+function _booster(::CatBoostClassifier, persistent)
+    return _bytes_to_model(catboost.CatBoostClassifier(), persistent)
 end
 
-function MMI.save(::CatBoostModels, fr; kw...)
-    (booster, a_target_element) = fr
-    return (_persistent(booster), a_target_element)
+# Convert `fitresult` into a Julia-serializable form via the model-specific `_persistent`
+function MMI.save(model::CatBoostModels, fitresult; kwargs...)
+    return _persistent(model, fitresult)
+end
+
+# Rebuild the regressor fitresult from the bytes produced by `_persistent`
+function MMI.restore(model::CatBoostRegressor, serializable_fitresult)
+    return _booster(model, serializable_fitresult)
 end
 
-function MMI.restore(::CatBoostModels, fr)
-    (persistent, a_target_element) = fr
-    return (_booster(persistent), a_target_element)
+# Rebuild the classifier fitresult, mirroring the two layouts written by
+# `_persistent(::CatBoostClassifier, fitresult)`
+function MMI.restore(model::CatBoostClassifier, serializable_fitresult)
+    if serializable_fitresult[1] === nothing
+        # Case 1: Single unique class
+        return (model=nothing, single_class=serializable_fitresult[2],
+                y_first=serializable_fitresult[3])
+    else
+        # Case 2: Multiple unique classes
+        persistent_booster, y_first = serializable_fitresult
+        return (_booster(model, persistent_booster), y_first)
+    end
 end
diff --git a/test/mlj_interface.jl b/test/mlj_interface.jl
index bd962ba..0c374a1 100644
--- a/test/mlj_interface.jl
+++ b/test/mlj_interface.jl
@@ -22,7 +22,25 @@
         preds = MLJBase.predict(mach, X)
         probs = MLJBase.predict_mode(mach, X)
 
-        serializable_fitresult = MLJBase.save(mach, mach.fitresult)
-        restored_fitresult = MLJBase.restore(mach, serializable_fitresult)
+        # Pass the model (not the machine) so the CatBoost methods are dispatched to
+        serializable_fitresult = MLJBase.save(mach.model, mach.fitresult)
+        restored_fitresult = MLJBase.restore(mach.model, serializable_fitresult)
+    end
+
+    @testset "CatBoostClassifier - single class" begin
+        X = (; a=[1, 4, 5, 6], b=[4, 5, 6, 7])
+        y = [0, 0, 0, 0]
+
+        # MLJ Interface
+        model = CatBoostClassifier(; iterations=5)
+        mach = machine(model, X, y)
+        MLJBase.fit!(mach)
+        preds = MLJBase.predict(mach, X)
+        @test length(preds) == length(y)
+        probs = MLJBase.predict_mode(mach, X)
+        @test length(probs) == length(y)
+
+        serializable_fitresult = MLJBase.save(mach.model, mach.fitresult)
+        restored_fitresult = MLJBase.restore(mach.model, serializable_fitresult)
     end
 
@@ -36,7 +54,7 @@
         MLJBase.fit!(mach)
         preds = MLJBase.predict(mach, X)
 
-        serializable_fitresult = MLJBase.save(mach, mach.fitresult)
-        restored_fitresult = MLJBase.restore(mach, serializable_fitresult)
+        serializable_fitresult = MLJBase.save(mach.model, mach.fitresult)
+        restored_fitresult = MLJBase.restore(mach.model, serializable_fitresult)
     end
 