Skip to content

Commit

Permalink
Adding repeats and user-defined resampling to PerformanceEvaluation
Browse files Browse the repository at this point in the history
  • Loading branch information
pebeto committed Aug 18, 2023
1 parent 1387f64 commit 8b6b047
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 12 deletions.
4 changes: 3 additions & 1 deletion src/composition/models/stacking.jl
Original file line number Diff line number Diff line change
Expand Up @@ -396,7 +396,9 @@ function internal_stack_report(
per_observation = Vector{Union{Missing, Vector{Any}}}(missing, n_measures),
fitted_params_per_fold = [],
report_per_fold = [],
train_test_pairs = tt_pairs
train_test_pairs = tt_pairs,
resampling = stack.resampling,
repeats = 1
)
for model in getfield(stack, :models)
]
Expand Down
34 changes: 23 additions & 11 deletions src/resampling.jl
Original file line number Diff line number Diff line change
Expand Up @@ -510,6 +510,10 @@ struct.
- `train_test_rows`: a vector of tuples, each of the form `(train, test)`,
where `train` and `test` are vectors of row (observation) indices for
training and evaluation respectively.
- `resampling`: the resampling strategy used to generate the train/test pairs.
- `repeats`: the number of times the resampling strategy was repeated.
"""
struct PerformanceEvaluation{M,
Measure,
Expand All @@ -518,7 +522,8 @@ struct PerformanceEvaluation{M,
PerFold,
PerObservation,
FittedParamsPerFold,
ReportPerFold} <: MLJType
ReportPerFold,
R} <: MLJType
model::M
measure::Measure
measurement::Measurement
Expand All @@ -528,6 +533,8 @@ struct PerformanceEvaluation{M,
fitted_params_per_fold::FittedParamsPerFold
report_per_fold::ReportPerFold
train_test_rows::TrainTestPairs
resampling::R
repeats::Int
end

# pretty printing:
Expand Down Expand Up @@ -575,7 +582,7 @@ function Base.show(io::IO, ::MIME"text/plain", e::PerformanceEvaluation)
"with these fields:")
println(io, " model, measure, operation, measurement, per_fold,\n"*
" per_observation, fitted_params_per_fold,\n"*
" report_per_fold, train_test_rows")
" report_per_fold, train_test_rows, resampling, repeats")
println(io, "Extract:")
show_color = MLJBase.SHOW_COLOR[]
color_off()
Expand Down Expand Up @@ -1006,7 +1013,8 @@ function evaluate!(mach::Machine{<:Measurable};
_acceleration= _process_accel_settings(acceleration)

evaluate!(mach, resampling, weights, class_weights, rows, verbosity,
repeats, _measures, _operations, _acceleration, force, logger)
repeats, _measures, _operations, _acceleration, force, logger,
resampling)

end

Expand Down Expand Up @@ -1183,9 +1191,10 @@ function measure_specific_weights(measure, weights, class_weights, test)
end

# Evaluation when `resampling` is a TrainTestPairs (CORE EVALUATOR):
function evaluate!(mach::Machine, resampling, weights,
class_weights, rows, verbosity, repeats,
measures, operations, acceleration, force, logger)
# `user_resampling` (a positional argument) is the user-defined resampling strategy
function evaluate!(mach::Machine, resampling, weights, class_weights, rows,
verbosity, repeats, measures, operations, acceleration,
force, logger, user_resampling)

# Note: `rows` and `repeats` are ignored here, except that `repeats` is recorded in the returned `PerformanceEvaluation`

Expand Down Expand Up @@ -1298,7 +1307,9 @@ function evaluate!(mach::Machine, resampling, weights,
per_observation,
fitted_params_per_fold |> collect,
report_per_fold |> collect,
resampling
resampling,
user_resampling,
repeats
)
log_evaluation(logger, evaluation)

Expand All @@ -1321,7 +1332,7 @@ function evaluate!(mach::Machine, resampling::ResamplingStrategy,
[train_test_pairs(resampling, _rows, train_args...) for i in 1:repeats]...
)

return evaluate!(
evaluate!(
mach,
repeated_train_test_pairs,
weights,
Expand All @@ -1331,7 +1342,6 @@ function evaluate!(mach::Machine, resampling::ResamplingStrategy,
repeats,
args...
)

end

# ====================================================================
Expand Down Expand Up @@ -1493,7 +1503,8 @@ function MLJModelInterface.fit(resampler::Resampler, verbosity::Int, args...)
_operations,
_acceleration,
false,
resampler.logger
resampler.logger,
resampler.resampling
)

fitresult = (machine = mach, evaluation = e)
Expand Down Expand Up @@ -1557,7 +1568,8 @@ function MLJModelInterface.update(
operations,
acceleration,
false,
resampler.logger
resampler.logger,
resampler.resampling
)
report = (evaluation = e, )
fitresult = (machine=mach2, evaluation=e)
Expand Down

0 comments on commit 8b6b047

Please sign in to comment.