Skip to content

Commit

Permalink
Pass the pool object to UnivariateFinite
Browse files Browse the repository at this point in the history
  • Loading branch information
tylerjthomas9 committed Dec 31, 2024
1 parent cece1f0 commit 7a52755
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions src/mlj_catboostclassifier.jl
Original file line number Diff line number Diff line change
Expand Up @@ -102,18 +102,18 @@ MMI.fitted_params(::CatBoostClassifier, model) = (model=model,)
MMI.reports_feature_importances(::Type{<:CatBoostClassifier}) = true

function MMI.predict(mlj_model::CatBoostClassifier, fitresult, X_pool)
if fitresult[1] === nothing
model, y_first = fitresult
if model === nothing
# Always predict the single class
n = pyconvert(Int, X_pool.shape[0])
classes = [fitresult.single_class]
probs = ones(n, 1)
return MMI.UnivariateFinite(classes, probs; pool=fitresult.y_first)
return MMI.UnivariateFinite(classes, probs; pool=y_first.pool)
end

model, y_first = fitresult
classes = pyconvert(Array, model.classes_.tolist())
py_preds = predict_proba(model, X_pool)
preds = MMI.UnivariateFinite(classes, pyconvert(Array, py_preds); pool=y_first)
preds = MMI.UnivariateFinite(classes, pyconvert(Array, py_preds); pool=y_first.pool)
return preds
end

Expand Down

0 comments on commit 7a52755

Please sign in to comment.