From ab061e4b80069cd7884ee23e4d7bbe4834bdd7e9 Mon Sep 17 00:00:00 2001 From: ExpandingMan Date: Mon, 9 Jan 2023 13:27:00 -0500 Subject: [PATCH] fix cases where watchlist can be improperly passed as nothing (#29) * fix cases where watchlist can be improperly passed as nothing * bump version * fix disaster caused by fixing XGBoost predict output --- Project.toml | 2 +- src/MLJXGBoostInterface.jl | 14 +++----------- 2 files changed, 4 insertions(+), 12 deletions(-) diff --git a/Project.toml b/Project.toml index fdc5330..4ad91b8 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "MLJXGBoostInterface" uuid = "54119dfa-1dab-4055-a167-80440f4f7a91" authors = ["Anthony D. Blaom "] -version = "0.3.3" +version = "0.3.4" [deps] MLJModelInterface = "e80e1ace-859a-464e-9ed9-23947d8ae3ea" diff --git a/src/MLJXGBoostInterface.jl b/src/MLJXGBoostInterface.jl index eac3829..bf4c152 100644 --- a/src/MLJXGBoostInterface.jl +++ b/src/MLJXGBoostInterface.jl @@ -107,7 +107,7 @@ function kwargs(model, verbosity, obj) o = merge(o, (silent=(verbosity ≤ 0),)) # watchlist is for log output, so override if it's default and verbosity ≤ 0 wl = (verbosity ≤ 0 && isnothing(model.watchlist)) ? (;) : model.watchlist - o = merge(o, (watchlist=wl,)) + isnothing(wl) || (o = merge(o, (watchlist=wl,))) merge(o, (objective=_fix_objective(obj),)) end @@ -174,16 +174,8 @@ function MMI.predict(model::XGBoostClassifier, fitresult, Xnew) nlevels = length(classes) npatterns = MMI.nrows(Xnew) - if nlevels == 2 - true_class_probabilities = reshape(o, 1, npatterns) - false_class_probabilities = 1 .- true_class_probabilities - o = vcat(false_class_probabilities, true_class_probabilities) - end - - prediction_probabilities = reshape(o, nlevels, npatterns) - - # note we use adjoint of above: - MMI.UnivariateFinite(classes, prediction_probabilities') + # XGB can return a rank-1 array for binary classification + MMI.UnivariateFinite(classes, o, augment=ndims(o)==1) end