Skip to content

Commit

Permalink
fix cases where watchlist can be improperly passed as nothing (#29)
Browse files Browse the repository at this point in the history
* fix cases where watchlist can be improperly passed as nothing

* bump version

* fix disaster caused by fixing XGBoost predict output
  • Loading branch information
ExpandingMan authored Jan 9, 2023
1 parent 608d51d commit ab061e4
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 12 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "MLJXGBoostInterface"
uuid = "54119dfa-1dab-4055-a167-80440f4f7a91"
authors = ["Anthony D. Blaom <[email protected]>"]
version = "0.3.3"
version = "0.3.4"

[deps]
MLJModelInterface = "e80e1ace-859a-464e-9ed9-23947d8ae3ea"
Expand Down
14 changes: 3 additions & 11 deletions src/MLJXGBoostInterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ function kwargs(model, verbosity, obj)
o = merge(o, (silent=(verbosity 0),))
# watchlist is for log output, so override if it's default and verbosity ≤ 0
wl = (verbosity 0 && isnothing(model.watchlist)) ? (;) : model.watchlist
o = merge(o, (watchlist=wl,))
isnothing(wl) || (o = merge(o, (watchlist=wl,)))
merge(o, (objective=_fix_objective(obj),))
end

Expand Down Expand Up @@ -174,16 +174,8 @@ function MMI.predict(model::XGBoostClassifier, fitresult, Xnew)
nlevels = length(classes)
npatterns = MMI.nrows(Xnew)

if nlevels == 2
true_class_probabilities = reshape(o, 1, npatterns)
false_class_probabilities = 1 .- true_class_probabilities
o = vcat(false_class_probabilities, true_class_probabilities)
end

prediction_probabilities = reshape(o, nlevels, npatterns)

# note we use adjoint of above:
MMI.UnivariateFinite(classes, prediction_probabilities')
# XGB can return a rank-1 array for binary classification
MMI.UnivariateFinite(classes, o, augment=ndims(o)==1)
end


Expand Down

0 comments on commit ab061e4

Please sign in to comment.