diff --git a/Project.toml b/Project.toml index 5cbda20..dffd492 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "MLJXGBoostInterface" uuid = "54119dfa-1dab-4055-a167-80440f4f7a91" authors = ["Anthony D. Blaom "] -version = "0.3.6" +version = "0.3.7" [deps] MLJModelInterface = "e80e1ace-859a-464e-9ed9-23947d8ae3ea" diff --git a/src/MLJXGBoostInterface.jl b/src/MLJXGBoostInterface.jl index d125ad5..113cd6c 100644 --- a/src/MLJXGBoostInterface.jl +++ b/src/MLJXGBoostInterface.jl @@ -136,10 +136,11 @@ end function MMI.fit(model::XGBoostAbstractRegressor, verbosity::Integer, X, y) dm = DMatrix(MMI.matrix(X), float(y)) b = xgboost(dm; kwargs(model, verbosity, model.objective)...) - (b, nothing, (features=_feature_names(X, dm),)) + # first return value is a tuple for consistancy with classifier case + ((b, nothing), nothing, (features=_feature_names(X, dm),)) end -MMI.predict(model::XGBoostAbstractRegressor, fitresult, Xnew) = XGB.predict(fitresult, Xnew) +MMI.predict(model::XGBoostAbstractRegressor, (booster, _), Xnew) = XGB.predict(booster, Xnew) eval(modelexpr(:XGBoostCount, :XGBoostAbstractRegressor, "count:poisson", :validate_count_objective)) @@ -189,13 +190,9 @@ _save(fr; kw...) = XGB.save(fr, Vector{UInt8}; kw...) _restore(fr) = XGB.load(Booster, fr) -MMI.save(::XGBoostAbstractRegressor, fr; kw...) = _save(fr; kw...) +MMI.save(::XGTypes, fr; kw...) = (_save(fr[1]; kw...), fr[2]) -MMI.restore(::XGBoostAbstractRegressor, fr; kw...) = _restore(fr) - -MMI.save(::XGBoostClassifier, fr; kw...) = (_save(fr[1]; kw...), fr[2]) - -MMI.restore(::XGBoostClassifier, fr) = (_restore(fr[1]), fr[2]) +MMI.restore(::XGTypes, fr) = (_restore(fr[1]), fr[2]) MLJModelInterface.reports_feature_importances(::Type{<:XGBoostAbstractRegressor}) = true MLJModelInterface.reports_feature_importances(::Type{<:XGBoostAbstractClassifier}) = true