From 91dcb65b24d36ac8740fe0363989adec6542d016 Mon Sep 17 00:00:00 2001 From: ExpandingMan Date: Wed, 25 Jan 2023 17:20:48 -0500 Subject: [PATCH 1/4] fix feature_importances for regression --- Project.toml | 2 +- src/MLJXGBoostInterface.jl | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/Project.toml b/Project.toml index 5cbda20..dffd492 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "MLJXGBoostInterface" uuid = "54119dfa-1dab-4055-a167-80440f4f7a91" authors = ["Anthony D. Blaom "] -version = "0.3.6" +version = "0.3.7" [deps] MLJModelInterface = "e80e1ace-859a-464e-9ed9-23947d8ae3ea" diff --git a/src/MLJXGBoostInterface.jl b/src/MLJXGBoostInterface.jl index d125ad5..052fc15 100644 --- a/src/MLJXGBoostInterface.jl +++ b/src/MLJXGBoostInterface.jl @@ -114,7 +114,8 @@ function kwargs(model, verbosity, obj) merge(o, (objective=_fix_objective(obj),)) end -function MMI.feature_importances(model::XGTypes, (booster, _), (features,)) +function MMI.feature_importances(model::XGTypes, booster, (features,)) + booster isa Tuple && (booster = booster[1]) dict = XGB.importance(booster, model.importance_type) if length(last(first(dict))) > 1 [features[k] => zero(first(v)) for (k, v) in dict] From dbaed1fdbf1dc9d53a1818910d7e8ab2816bc496 Mon Sep 17 00:00:00 2001 From: ExpandingMan Date: Wed, 25 Jan 2023 17:30:07 -0500 Subject: [PATCH 2/4] make return values of different fit calls more consistent --- src/MLJXGBoostInterface.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/MLJXGBoostInterface.jl b/src/MLJXGBoostInterface.jl index 052fc15..0aa91d1 100644 --- a/src/MLJXGBoostInterface.jl +++ b/src/MLJXGBoostInterface.jl @@ -114,8 +114,7 @@ function kwargs(model, verbosity, obj) merge(o, (objective=_fix_objective(obj),)) end -function MMI.feature_importances(model::XGTypes, booster, (features,)) - booster isa Tuple && (booster = booster[1]) +function MMI.feature_importances(model::XGTypes, (booster, _), (features,)) dict = XGB.importance(booster, model.importance_type) if length(last(first(dict))) > 1 [features[k] => zero(first(v)) for (k, v) in dict] @@ -137,7 +136,8 @@ end function MMI.fit(model::XGBoostAbstractRegressor, verbosity::Integer, X, y) dm = DMatrix(MMI.matrix(X), float(y)) b = xgboost(dm; kwargs(model, verbosity, model.objective)...) - (b, nothing, (features=_feature_names(X, dm),)) + # first return value is a tuple for consistancy with classifier case + ((b, nothing), nothing, (features=_feature_names(X, dm),)) end MMI.predict(model::XGBoostAbstractRegressor, fitresult, Xnew) = XGB.predict(fitresult, Xnew) From 58ebb8e03fd17b3f853994aa77a7cc8cd6e70a31 Mon Sep 17 00:00:00 2001 From: ExpandingMan Date: Wed, 25 Jan 2023 17:43:44 -0500 Subject: [PATCH 3/4] fix predict for regressor --- src/MLJXGBoostInterface.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/MLJXGBoostInterface.jl b/src/MLJXGBoostInterface.jl index 0aa91d1..a59bb60 100644 --- a/src/MLJXGBoostInterface.jl +++ b/src/MLJXGBoostInterface.jl @@ -140,7 +140,7 @@ function MMI.fit(model::XGBoostAbstractRegressor, verbosity::Integer, X, y) ((b, nothing), nothing, (features=_feature_names(X, dm),)) end -MMI.predict(model::XGBoostAbstractRegressor, fitresult, Xnew) = XGB.predict(fitresult, Xnew) +MMI.predict(model::XGBoostAbstractRegressor, (booster, _), Xnew) = XGB.predict(booster, Xnew) eval(modelexpr(:XGBoostCount, :XGBoostAbstractRegressor, "count:poisson", :validate_count_objective)) From 22b95d80ba9ba039ae6786f1d3672743fea33566 Mon Sep 17 00:00:00 2001 From: ExpandingMan Date: Wed, 25 Jan 2023 17:46:35 -0500 Subject: [PATCH 4/4] fix serialization --- src/MLJXGBoostInterface.jl | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/src/MLJXGBoostInterface.jl b/src/MLJXGBoostInterface.jl index a59bb60..113cd6c 100644 --- a/src/MLJXGBoostInterface.jl +++ b/src/MLJXGBoostInterface.jl @@ -190,13 +190,9 @@ _save(fr; kw...) = XGB.save(fr, Vector{UInt8}; kw...) _restore(fr) = XGB.load(Booster, fr) -MMI.save(::XGBoostAbstractRegressor, fr; kw...) = _save(fr; kw...) +MMI.save(::XGTypes, fr; kw...) = (_save(fr[1]; kw...), fr[2]) -MMI.restore(::XGBoostAbstractRegressor, fr; kw...) = _restore(fr) - -MMI.save(::XGBoostClassifier, fr; kw...) = (_save(fr[1]; kw...), fr[2]) - -MMI.restore(::XGBoostClassifier, fr) = (_restore(fr[1]), fr[2]) +MMI.restore(::XGTypes, fr) = (_restore(fr[1]), fr[2]) MLJModelInterface.reports_feature_importances(::Type{<:XGBoostAbstractRegressor}) = true MLJModelInterface.reports_feature_importances(::Type{<:XGBoostAbstractClassifier}) = true