From e9a7907a5305bc33a0855c2a511c1f488e5d5542 Mon Sep 17 00:00:00 2001 From: Tyler Thomas Date: Tue, 31 Dec 2024 11:22:58 -0700 Subject: [PATCH] Fix `MMI.predict`, ensure `fitresult` is a named tuple --- src/mlj_catboostclassifier.jl | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/src/mlj_catboostclassifier.jl b/src/mlj_catboostclassifier.jl index 7ad4d67..ed31def 100644 --- a/src/mlj_catboostclassifier.jl +++ b/src/mlj_catboostclassifier.jl @@ -93,7 +93,7 @@ function MMI.fit(mlj_model::CatBoostClassifier, verbosity::Int, data_pool, y_fir cache = (; mlj_model=deepcopy(mlj_model)) report = (feature_importances=feature_importance(model),) - fitresult = (model, y_first) + fitresult = (model=model, y_first=y_first) return (fitresult, cache, report) end @@ -102,18 +102,17 @@ MMI.fitted_params(::CatBoostClassifier, model) = (model=model,) MMI.reports_feature_importances(::Type{<:CatBoostClassifier}) = true function MMI.predict(mlj_model::CatBoostClassifier, fitresult, X_pool) - model, y_first = fitresult - if model === nothing + if fitresult.model === nothing # Always predict the single class n = pyconvert(Int, X_pool.shape[0]) classes = [fitresult.single_class] probs = ones(n, 1) - return MMI.UnivariateFinite(classes, probs; pool=y_first.pool) + return MMI.UnivariateFinite(classes, probs; pool=fitresult.y_first.pool) end - classes = pyconvert(Array, model.classes_.tolist()) - py_preds = predict_proba(model, X_pool) - preds = MMI.UnivariateFinite(classes, pyconvert(Array, py_preds); pool=y_first.pool) + classes = pyconvert(Array, fitresult.model.classes_.tolist()) + py_preds = predict_proba(fitresult.model, X_pool) + preds = MMI.UnivariateFinite(classes, pyconvert(Array, py_preds); pool=fitresult.y_first.pool) return preds end