From 30706a563e2cb1d51958705097f13edbcc57eb42 Mon Sep 17 00:00:00 2001 From: Tyler Thomas Date: Thu, 21 Mar 2024 23:27:28 -0700 Subject: [PATCH 1/3] Fix MLJ Serialization --- src/mlj_serialization.jl | 63 +++++++++++++++++++++++++++++++--------- test/mlj_interface.jl | 4 +-- 2 files changed, 52 insertions(+), 15 deletions(-) diff --git a/src/mlj_serialization.jl b/src/mlj_serialization.jl index b4c4088..0a3d4c9 100644 --- a/src/mlj_serialization.jl +++ b/src/mlj_serialization.jl @@ -1,24 +1,41 @@ # Taken from https://github.com/JuliaAI/MLJXGBoostInterface.jl # It is likely also not the optimal method for serializing models, but it works + """ - _persistent(booster) + _persistent(model::CatBoostModels, fitresult) Private method. Return a persistent (ie, Julia-serializable) representation of the -CatBoost.jl model `booster`. +CatBoost.jl model `fitresult`. -Restore the model with [`booster`](@ref) +Restore the model with [`fitresult`](@ref) """ -function _persistent(booster) +function _persistent(::CatBoostRegressor, fitresult) ctb_file, io = mktemp() close(io) - booster.save_model(ctb_file) + fitresult.save_model(ctb_file) persistent_booster = read(ctb_file) rm(ctb_file) return persistent_booster end +function _persistent(::CatBoostClassifier, fitresult) + model, y_first = fitresult + if model === nothing + # Case 1: Single unique class + return (nothing, fitresult.single_class, y_first) + else + # Case 2: Multiple unique classes + ctb_file, io = mktemp() + close(io) + + model.save_model(ctb_file) + persistent_booster = read(ctb_file) + rm(ctb_file) + return (persistent_booster, y_first) + end +end """ _booster(persistent) @@ -28,24 +45,44 @@ Private method. Return the CatBoost.jl model which has `persistent` as its persistent (Julia-serializable) representation. See [`persistent`](@ref) method. """ -function _booster(persistent) +function _booster(::CatBoostRegressor, persistent) ctb_file, io = mktemp() write(io, persistent) close(io) - booster = catboost.CatBoost().load_model(ctb_file) + booster = catboost.CatBoostRegressor().load_model(ctb_file) rm(ctb_file) return booster end +function _booster(::CatBoostClassifier, persistent) + ctb_file, io = mktemp() + write(io, persistent) + close(io) + + booster = catboost.CatBoostClassifier().load_model(ctb_file) -function MMI.save(::CatBoostModels, fr; kw...) - (booster, a_target_element) = fr - return (_persistent(booster), a_target_element) + rm(ctb_file) + + return booster end -function MMI.restore(::CatBoostModels, fr) - (persistent, a_target_element) = fr - return (_booster(persistent), a_target_element) +function MMI.save(model::CatBoostModels, fitresult; kwargs...) + return _persistent(model, fitresult) end + +function MMI.restore(model::CatBoostRegressor, serializable_fitresult) + return _booster(model, serializable_fitresult) +end + +function MMI.restore(model::CatBoostClassifier, serializable_fitresult) + if serializable_fitresult[1] === nothing + # Case 1: Single unique class + return (model=nothing, single_class=serializable_fitresult[2], y_first=serializable_fitresult[3]) + else + # Case 2: Multiple unique classes + persistent_booster, y_first = serializable_fitresult + return (_booster(model, persistent_booster), y_first) + end +end \ No newline at end of file diff --git a/test/mlj_interface.jl b/test/mlj_interface.jl index bd962ba..fae3c08 100644 --- a/test/mlj_interface.jl +++ b/test/mlj_interface.jl @@ -22,7 +22,7 @@ preds = MLJBase.predict(mach, X) probs = MLJBase.predict_mode(mach, X) - serializable_fitresult = MLJBase.save(mach, mach.fitresult) + serializable_fitresult = MLJBase.save(mach, mach) restored_fitresult = MLJBase.restore(mach, serializable_fitresult) end @@ -36,7 +36,7 @@ MLJBase.fit!(mach) preds = MLJBase.predict(mach, X) - serializable_fitresult = MLJBase.save(mach, mach.fitresult) + serializable_fitresult = MLJBase.save(mach, mach) restored_fitresult = MLJBase.restore(mach, serializable_fitresult) end From 4936b57d12ade59dcfbf23128e137fab99307fa5 Mon Sep 17 00:00:00 2001 From: Tyler Thomas Date: Thu, 21 Mar 2024 23:28:00 -0700 Subject: [PATCH 2/3] reformat --- src/mlj_serialization.jl | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/mlj_serialization.jl b/src/mlj_serialization.jl index 0a3d4c9..b4f5914 100644 --- a/src/mlj_serialization.jl +++ b/src/mlj_serialization.jl @@ -79,10 +79,11 @@ end function MMI.restore(model::CatBoostClassifier, serializable_fitresult) if serializable_fitresult[1] === nothing # Case 1: Single unique class - return (model=nothing, single_class=serializable_fitresult[2], y_first=serializable_fitresult[3]) + return (model=nothing, single_class=serializable_fitresult[2], + y_first=serializable_fitresult[3]) else # Case 2: Multiple unique classes persistent_booster, y_first = serializable_fitresult return (_booster(model, persistent_booster), y_first) end -end \ No newline at end of file +end From e026b626d7e547e375119e7c791ac1acc885f2b1 Mon Sep 17 00:00:00 2001 From: Tyler Thomas Date: Thu, 21 Mar 2024 23:55:05 -0700 Subject: [PATCH 3/3] add test for single class classifiers, fix single class classifiers predict --- Project.toml | 2 +- src/mlj_catboostclassifier.jl | 9 +++++---- test/mlj_interface.jl | 17 +++++++++++++++++ 3 files changed, 23 insertions(+), 5 deletions(-) diff --git a/Project.toml b/Project.toml index e0c7ce8..cdd53f3 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "CatBoost" uuid = "e2e10f9a-a85d-4fa9-b6b2-639a32100a12" authors = ["Beacon Biosignals, Inc."] -version = "0.3.4" +version = "0.3.5" [deps] MLJModelInterface = "e80e1ace-859a-464e-9ed9-23947d8ae3ea" diff --git a/src/mlj_catboostclassifier.jl b/src/mlj_catboostclassifier.jl index 51e9f5f..baad80b 100644 --- a/src/mlj_catboostclassifier.jl +++ b/src/mlj_catboostclassifier.jl @@ -100,10 +100,11 @@ MMI.reports_feature_importances(::Type{<:CatBoostClassifier}) = true function MMI.predict(mlj_model::CatBoostClassifier, fitresult, X_pool) if fitresult[1] === nothing # Always predict the single class - n = nrow(X_pool) + n = pyconvert(Int, X_pool.shape[0]) classes = [fitresult.single_class] probs = ones(n, 1) - return MMI.UnivariateFinite(classes, probs; pool=fitresult.y_first) + pool = MMI.categorical([fitresult.y_first]) + return MMI.UnivariateFinite(classes, probs; pool=pool) end model, y_first = fitresult @@ -116,8 +117,8 @@ end function MMI.predict_mode(mlj_model::CatBoostClassifier, fitresult, X_pool) if fitresult[1] === nothing # Return probability 1 for the single class - n = nrow(X_pool) - return hcat(ones(n), zeros(n)) + n = pyconvert(Int, X_pool.shape[0]) + return fill(fitresult.y_first, n) end model, y_first = fitresult diff --git a/test/mlj_interface.jl b/test/mlj_interface.jl index fae3c08..0c374a1 100644 --- a/test/mlj_interface.jl +++ b/test/mlj_interface.jl @@ -26,6 +26,23 @@ restored_fitresult = MLJBase.restore(mach, serializable_fitresult) end + @testset "CatBoostClassifier - single class" begin + X = (; a=[1, 4, 5, 6], b=[4, 5, 6, 7]) + y = [0, 0, 0, 0] + + # MLJ Interface + model = CatBoostClassifier(; iterations=5) + mach = machine(model, X, y) + MLJBase.fit!(mach) + preds = MLJBase.predict(mach, X) + println(preds) + probs = MLJBase.predict_mode(mach, X) + println(probs) + + serializable_fitresult = MLJBase.save(mach, mach) + restored_fitresult = MLJBase.restore(mach, serializable_fitresult) + end + @testset "CatBoostRegressor" begin X = (; a=[1, 4, 5, 6], b=[4, 5, 6, 7]) y = [2.0, 4.0, 6.0, 7.0]