Skip to content

Commit

Permalink
Make sure y_first is a CategoricalValue
Browse files Browse the repository at this point in the history
  • Loading branch information
tylerjthomas9 committed Dec 25, 2024
1 parent a9e89ae commit cece1f0
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 3 deletions.
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,15 @@ authors = ["Beacon Biosignals, Inc."]
version = "0.3.6"

[deps]
CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597"
MLJModelInterface = "e80e1ace-859a-464e-9ed9-23947d8ae3ea"
OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
PythonCall = "6099a3de-0909-46bc-b1f4-468b9a2dfc0d"
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"

[compat]
Aqua = "0.8.4"
CategoricalArrays = "0.10.8"
DataFrames = "1.6"
MLJBase = "1"
MLJModelInterface = "1.7"
Expand Down
1 change: 1 addition & 0 deletions src/MLJCatBoostInterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ using Tables
using MLJModelInterface: MLJModelInterface
const MMI = MLJModelInterface
using MLJModelInterface: Table, Continuous, Count, Finite, OrderedFactor, Multiclass
using CategoricalArrays: CategoricalArray, CategoricalValue
const PKG = "CatBoost"

"""
Expand Down
9 changes: 6 additions & 3 deletions src/mlj_catboostclassifier.jl
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,12 @@ function model_init(mlj_model::CatBoostClassifier; kw...)
return catboost.CatBoostClassifier(; mlj_to_kwargs(mlj_model)..., kw...)
end

prepare_single_class(y_first) = CategoricalArray([y_first])[1]
prepare_single_class(y_first::CategoricalValue) = y_first

function MMI.fit(mlj_model::CatBoostClassifier, verbosity::Int, data_pool, y_first)
# Check if y_first has only one unique value
# Check if the data pool has only one unique value
y_first = prepare_single_class(y_first)
unique_classes = pyconvert(Vector, numpy.unique(data_pool.get_label()))
if length(unique_classes) == 1
# Skip training and store the single class
Expand Down Expand Up @@ -103,8 +107,7 @@ function MMI.predict(mlj_model::CatBoostClassifier, fitresult, X_pool)
n = pyconvert(Int, X_pool.shape[0])
classes = [fitresult.single_class]
probs = ones(n, 1)
pool = MMI.categorical([fitresult.y_first])
return MMI.UnivariateFinite(classes, probs; pool=pool)
return MMI.UnivariateFinite(classes, probs; pool=fitresult.y_first)
end

model, y_first = fitresult
Expand Down

0 comments on commit cece1f0

Please sign in to comment.