Skip to content

Commit

Permalink
Merge pull request #56 from Moelf/support_weight
Browse files Browse the repository at this point in the history
support MLJ training per data sample weight
  • Loading branch information
ablaom authored Nov 30, 2024
2 parents 402861a + 0ec95a7 commit b1f5c91
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 6 deletions.
22 changes: 16 additions & 6 deletions src/MLJXGBoostInterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -145,8 +145,12 @@ function _feature_names(X, dmatrix)
end
end

function MMI.fit(model::XGBoostAbstractRegressor, verbosity::Integer, X, y)
dm = DMatrix(MMI.matrix(X), float(y))
function MMI.fit(model::XGBoostAbstractRegressor, verbosity::Integer, X, y, weight=nothing)
dm = if isnothing(weight)
DMatrix(MMI.matrix(X), float(y))
else
DMatrix(MMI.matrix(X), float(y); weight = weight)
end
b = xgboost(dm; kwargs(model, verbosity, model.objective)...)
# first return value is a tuple for consistancy with classifier case
((b, nothing), nothing, (features=_feature_names(X, dm),))
Expand All @@ -162,7 +166,7 @@ eval(modelexpr(:XGBoostClassifier, :XGBoostAbstractClassifier, "automatic", :val

function MMI.fit(model::XGBoostClassifier,
verbosity, # must be here even if unsupported in pkg
X, y,
X, y, weight=nothing
)
a_target_element = y[1] # a CategoricalValue or CategoricalString
nclass = length(MMI.classes(a_target_element))
Expand All @@ -172,7 +176,11 @@ function MMI.fit(model::XGBoostClassifier,
num_class = nclass == 2 ? (;) : (num_class=nclass,)

# libxgboost wants float labels
dm = DMatrix(MMI.matrix(X), float(MMI.int(y) .- 1))
dm = if isnothing(weight)
DMatrix(MMI.matrix(X), float(MMI.int(y) .- 1))
else
DMatrix(MMI.matrix(X), float(MMI.int(y) .- 1); weight = weight)
end

b = xgboost(dm; kwargs(model, verbosity, objective)..., num_class...)
fr = (b, a_target_element)
Expand Down Expand Up @@ -202,8 +210,8 @@ MMI.save(::XGTypes, fr; kw...) = (_save(fr[1]; kw...), fr[2])

MMI.restore(::XGTypes, fr) = (_restore(fr[1]), fr[2])

MLJModelInterface.reports_feature_importances(::Type{<:XGBoostAbstractRegressor}) = true
MLJModelInterface.reports_feature_importances(::Type{<:XGBoostAbstractClassifier}) = true
MMI.reports_feature_importances(::Type{<:XGBoostAbstractRegressor}) = true
MMI.reports_feature_importances(::Type{<:XGBoostAbstractClassifier}) = true


MMI.package_name(::Type{<:XGTypes}) = "XGBoost"
Expand All @@ -225,6 +233,8 @@ MMI.load_path(::Type{<:XGBoostClassifier}) = "$PKG.XGBoostClassifier"
MMI.input_scitype(::Type{<:XGBoostClassifier}) = Table(Continuous)
MMI.target_scitype(::Type{<:XGBoostClassifier}) = AbstractVector{<:Finite}
MMI.human_name(::Type{<:XGBoostClassifier}) = "eXtreme Gradient Boosting Classifier"
MMI.supports_weights(::Type{<:XGBoostRegressor}) = true
MMI.supports_weights(::Type{<:XGBoostClassifier}) = true


include("docstrings.jl")
Expand Down
7 changes: 7 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,13 @@ end
fit!(mach, verbosity=0)
yhat = predict(mach, Xtable)

weight = rand(length(ycount))
mach_withweight = machine(count_regressor, Xtable, ycount, weight)
fit!(mach_withweight, verbosity=0)
yhat_withweight = predict(mach_withweight, Xtable)

@test !(yhat yhat_withweight)

# serialize:
io = IOBuffer()
MLJBase.save(io, mach)
Expand Down

0 comments on commit b1f5c91

Please sign in to comment.