Skip to content

Commit

Permalink
removed num_class from regressor
Browse files Browse the repository at this point in the history
  • Loading branch information
kainkad committed Nov 29, 2024
1 parent 43c765d commit 46c786e
Show file tree
Hide file tree
Showing 4 changed files with 6 additions and 5 deletions.
3 changes: 1 addition & 2 deletions src/estimators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,6 @@ mutable struct LGBMRegression <: LGBMEstimator
predict_disable_shape_check::Bool

# Objective parameters
num_class::Int
is_unbalance::Bool
boost_from_average::Bool
reg_sqrt::Bool
Expand Down Expand Up @@ -332,7 +331,7 @@ function LGBMRegression(;
linear_tree, max_bin, max_bin_by_feature, min_data_in_bin, bin_construct_sample_cnt, data_random_seed,
is_enable_sparse, enable_bundle, use_missing, zero_as_missing, feature_pre_filter, pre_partition, categorical_feature,
start_iteration_predict, num_iteration_predict, predict_raw_score, predict_leaf_index, predict_contrib, predict_disable_shape_check,
1, is_unbalance, boost_from_average, reg_sqrt, alpha, fair_c, poisson_max_delta_step, tweedie_variance_power,
is_unbalance, boost_from_average, reg_sqrt, alpha, fair_c, poisson_max_delta_step, tweedie_variance_power,
metric, metric_freq, is_provide_training_metric, eval_at,
num_machines, local_listen_port, time_out, machine_list_filename, machines, gpu_platform_id, gpu_device_id, gpu_use_dp, num_gpu,
)
Expand Down
2 changes: 1 addition & 1 deletion src/predict.jl
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ function predict(

# This works the same one way or another because when n=1, (regression) reshaping is basically no-op
# except for adding the extra dim
prediction = transpose(reshape(prediction, estimator.num_class, :))
prediction = transpose(reshape(prediction, estimator isa LGBMRegression ? 1 : estimator.num_class, :))

return prediction

Expand Down
4 changes: 3 additions & 1 deletion src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,9 @@ the parameters or data of the estimator whose model was saved as `filename`.
"""
function loadmodel!(estimator::LGBMEstimator, filename::String)
estimator.booster = LGBM_BoosterCreateFromModelfile(filename)
estimator.num_class = LGBM_BoosterGetNumClasses(estimator.booster)
if !(estimator isa LGBMRegression)
estimator.num_class = LGBM_BoosterGetNumClasses(estimator.booster)
end
return nothing
end

Expand Down
2 changes: 1 addition & 1 deletion src/wrapper.jl
Original file line number Diff line number Diff line change
Expand Up @@ -616,7 +616,7 @@ end

function LGBM_BoosterGetPredict(bst::Booster, data_idx::Integer)
out_len = Ref{Int64}()
num_class = LGBM_BoosterGetNumClasses(bst)
num_class = bst isa LGBMRegression ? 1 : LGBM_BoosterGetNumClasses(bst)
num_data = LGBM_BoosterGetNumPredict(bst, data_idx)
out_results = Array{Cdouble}(undef, num_class * num_data)
@lightgbm(:LGBM_BoosterGetPredict,
Expand Down

0 comments on commit 46c786e

Please sign in to comment.