diff --git a/Project.toml b/Project.toml
index fdc5330..4ad91b8 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,7 +1,7 @@
 name = "MLJXGBoostInterface"
 uuid = "54119dfa-1dab-4055-a167-80440f4f7a91"
 authors = ["Anthony D. Blaom "]
-version = "0.3.3"
+version = "0.3.4"

 [deps]
 MLJModelInterface = "e80e1ace-859a-464e-9ed9-23947d8ae3ea"
diff --git a/src/MLJXGBoostInterface.jl b/src/MLJXGBoostInterface.jl
index eac3829..bf4c152 100644
--- a/src/MLJXGBoostInterface.jl
+++ b/src/MLJXGBoostInterface.jl
@@ -107,7 +107,7 @@ function kwargs(model, verbosity, obj)
     o = merge(o, (silent=(verbosity ≤ 0),))
     # watchlist is for log output, so override if it's default and verbosity ≤ 0
     wl = (verbosity ≤ 0 && isnothing(model.watchlist)) ? (;) : model.watchlist
-    o = merge(o, (watchlist=wl,))
+    isnothing(wl) || (o = merge(o, (watchlist=wl,)))
     merge(o, (objective=_fix_objective(obj),))
 end
@@ -174,16 +174,8 @@ function MMI.predict(model::XGBoostClassifier, fitresult, Xnew)
     nlevels = length(classes)
     npatterns = MMI.nrows(Xnew)

-    if nlevels == 2
-        true_class_probabilities = reshape(o, 1, npatterns)
-        false_class_probabilities = 1 .- true_class_probabilities
-        o = vcat(false_class_probabilities, true_class_probabilities)
-    end
-
-    prediction_probabilities = reshape(o, nlevels, npatterns)
-
-    # note we use adjoint of above:
-    MMI.UnivariateFinite(classes, prediction_probabilities')
+    # XGB can return a rank-1 array for binary classification
+    MMI.UnivariateFinite(classes, o, augment=ndims(o)==1)
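
For context, a minimal sketch of how the `augment` keyword used in the new `predict` behaves, assuming MLJBase re-exports `UnivariateFinite` and `pdf`; the class labels and probabilities below are made up:

    using MLJBase  # assumed to re-export UnivariateFinite and pdf

    # hypothetical rank-1 scores, one positive-class probability per row,
    # in the shape XGBoost returns for binary classification
    p_pos = [0.2, 0.9, 0.4]

    # augment=true fills in 1 .- p_pos as the probabilities of the first class
    d = UnivariateFinite(["neg", "pos"], p_pos, pool=missing, augment=true)

    pdf.(d, "pos")   # [0.2, 0.9, 0.4]
    pdf.(d, "neg")   # [0.8, 0.1, 0.6]

With `augment=true` the supplied probabilities are treated as belonging to the classes after the first, and the first class receives the complement, which is why the rank-1 XGBoost output can be passed through without the manual reshape/vcat that the removed lines performed.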