Skip to content

Commit

Permalink
Fix MMI.predict, ensure fitresult is a named tuple
Browse files Browse the repository at this point in the history
  • Loading branch information
tylerjthomas9 committed Dec 31, 2024
1 parent 7a52755 commit e9a7907
Showing 1 changed file with 6 additions and 7 deletions.
13 changes: 6 additions & 7 deletions src/mlj_catboostclassifier.jl
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ function MMI.fit(mlj_model::CatBoostClassifier, verbosity::Int, data_pool, y_fir

cache = (; mlj_model=deepcopy(mlj_model))
report = (feature_importances=feature_importance(model),)
fitresult = (model, y_first)
fitresult = (model=model, y_first=y_first)

return (fitresult, cache, report)
end
Expand All @@ -102,18 +102,17 @@ MMI.fitted_params(::CatBoostClassifier, model) = (model=model,)
MMI.reports_feature_importances(::Type{<:CatBoostClassifier}) = true

function MMI.predict(mlj_model::CatBoostClassifier, fitresult, X_pool)
model, y_first = fitresult
if model === nothing
if fitresult.model === nothing
# Always predict the single class
n = pyconvert(Int, X_pool.shape[0])
classes = [fitresult.single_class]
probs = ones(n, 1)
return MMI.UnivariateFinite(classes, probs; pool=y_first.pool)
return MMI.UnivariateFinite(classes, probs; pool=fitresult.y_first.pool)
end

classes = pyconvert(Array, model.classes_.tolist())
py_preds = predict_proba(model, X_pool)
preds = MMI.UnivariateFinite(classes, pyconvert(Array, py_preds); pool=y_first.pool)
classes = pyconvert(Array, fitresult.model.classes_.tolist())
py_preds = predict_proba(fitresult.model, X_pool)
preds = MMI.UnivariateFinite(classes, pyconvert(Array, py_preds); pool=fitresult.y_first.pool)
return preds
end

Expand Down

0 comments on commit e9a7907

Please sign in to comment.