Skip to content

Commit

Permalink
Merge pull request #36 from ExpandingMan/importancefix
Browse files Browse the repository at this point in the history
fix feature_importances for regression
  • Loading branch information
ablaom authored Jan 26, 2023
2 parents a927313 + 22b95d8 commit bc3ecef
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 9 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "MLJXGBoostInterface"
uuid = "54119dfa-1dab-4055-a167-80440f4f7a91"
authors = ["Anthony D. Blaom <[email protected]>"]
version = "0.3.6"
version = "0.3.7"

[deps]
MLJModelInterface = "e80e1ace-859a-464e-9ed9-23947d8ae3ea"
Expand Down
13 changes: 5 additions & 8 deletions src/MLJXGBoostInterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -136,10 +136,11 @@ end
function MMI.fit(model::XGBoostAbstractRegressor, verbosity::Integer, X, y)
dm = DMatrix(MMI.matrix(X), float(y))
b = xgboost(dm; kwargs(model, verbosity, model.objective)...)
(b, nothing, (features=_feature_names(X, dm),))
# first return value is a tuple for consistancy with classifier case
((b, nothing), nothing, (features=_feature_names(X, dm),))
end

MMI.predict(model::XGBoostAbstractRegressor, fitresult, Xnew) = XGB.predict(fitresult, Xnew)
MMI.predict(model::XGBoostAbstractRegressor, (booster, _), Xnew) = XGB.predict(booster, Xnew)


eval(modelexpr(:XGBoostCount, :XGBoostAbstractRegressor, "count:poisson", :validate_count_objective))
Expand Down Expand Up @@ -189,13 +190,9 @@ _save(fr; kw...) = XGB.save(fr, Vector{UInt8}; kw...)

_restore(fr) = XGB.load(Booster, fr)

MMI.save(::XGBoostAbstractRegressor, fr; kw...) = _save(fr; kw...)
MMI.save(::XGTypes, fr; kw...) = (_save(fr[1]; kw...), fr[2])

MMI.restore(::XGBoostAbstractRegressor, fr; kw...) = _restore(fr)

MMI.save(::XGBoostClassifier, fr; kw...) = (_save(fr[1]; kw...), fr[2])

MMI.restore(::XGBoostClassifier, fr) = (_restore(fr[1]), fr[2])
MMI.restore(::XGTypes, fr) = (_restore(fr[1]), fr[2])

MLJModelInterface.reports_feature_importances(::Type{<:XGBoostAbstractRegressor}) = true
MLJModelInterface.reports_feature_importances(::Type{<:XGBoostAbstractClassifier}) = true
Expand Down

0 comments on commit bc3ecef

Please sign in to comment.