diff --git a/src/composition/models/stacking.jl b/src/composition/models/stacking.jl index 407a3ca0..4a760e24 100644 --- a/src/composition/models/stacking.jl +++ b/src/composition/models/stacking.jl @@ -396,7 +396,9 @@ function internal_stack_report( per_observation = Vector{Union{Missing, Vector{Any}}}(missing, n_measures), fitted_params_per_fold = [], report_per_fold = [], - train_test_pairs = tt_pairs + train_test_pairs = tt_pairs, + resampling = stack.resampling, + repeats = 1 ) for model in getfield(stack, :models) ] diff --git a/src/resampling.jl b/src/resampling.jl index bf774f3c..43483cc3 100644 --- a/src/resampling.jl +++ b/src/resampling.jl @@ -510,6 +510,10 @@ struct. - `train_test_rows`: a vector of tuples, each of the form `(train, test)`, where `train` and `test` are vectors of row (observation) indices for training and evaluation respectively. + +- `resampling`: the resampling strategy used to generate the train/test pairs. + +- `repeats`: the number of times the resampling strategy was repeated. """ struct PerformanceEvaluation{M, Measure, @@ -518,7 +522,8 @@ struct PerformanceEvaluation{M, PerFold, PerObservation, FittedParamsPerFold, - ReportPerFold} <: MLJType + ReportPerFold, + R} <: MLJType model::M measure::Measure measurement::Measurement @@ -528,6 +533,8 @@ struct PerformanceEvaluation{M, fitted_params_per_fold::FittedParamsPerFold report_per_fold::ReportPerFold train_test_rows::TrainTestPairs + resampling::R + repeats::Int end # pretty printing: @@ -575,7 +582,7 @@ function Base.show(io::IO, ::MIME"text/plain", e::PerformanceEvaluation) "with these fields:") println(io, " model, measure, operation, measurement, per_fold,\n"* " per_observation, fitted_params_per_fold,\n"* - " report_per_fold, train_test_rows") + " report_per_fold, train_test_rows, resampling, repeats") println(io, "Extract:") show_color = MLJBase.SHOW_COLOR[] color_off() @@ -1006,7 +1013,8 @@ function evaluate!(mach::Machine{<:Measurable}; _acceleration= _process_accel_settings(acceleration) evaluate!(mach, resampling, weights, class_weights, rows, verbosity, - repeats, _measures, _operations, _acceleration, force, logger) + repeats, _measures, _operations, _acceleration, force, logger, + resampling) end @@ -1183,9 +1191,10 @@ function measure_specific_weights(measure, weights, class_weights, test) end # Evaluation when `resampling` is a TrainTestPairs (CORE EVALUATOR): -function evaluate!(mach::Machine, resampling, weights, - class_weights, rows, verbosity, repeats, - measures, operations, acceleration, force, logger) +# `user_resampling` keyword argument is the user defined resampling strategy +function evaluate!(mach::Machine, resampling, weights, class_weights, rows, + verbosity, repeats, measures, operations, acceleration, + force, logger, user_resampling) # Note: `rows` and `repeats` are ignored here @@ -1298,7 +1307,9 @@ function evaluate!(mach::Machine, resampling, weights, per_observation, fitted_params_per_fold |> collect, report_per_fold |> collect, - resampling + resampling, + user_resampling, + repeats ) log_evaluation(logger, evaluation) @@ -1321,7 +1332,7 @@ function evaluate!(mach::Machine, resampling::ResamplingStrategy, [train_test_pairs(resampling, _rows, train_args...) for i in 1:repeats]... ) - return evaluate!( + evaluate!( mach, repeated_train_test_pairs, weights, @@ -1331,7 +1342,6 @@ function evaluate!(mach::Machine, resampling::ResamplingStrategy, repeats, args... ) - end # ==================================================================== @@ -1493,7 +1503,8 @@ function MLJModelInterface.fit(resampler::Resampler, verbosity::Int, args...) _operations, _acceleration, false, - resampler.logger + resampler.logger, + resampler.resampling ) fitresult = (machine = mach, evaluation = e) @@ -1557,7 +1568,8 @@ function MLJModelInterface.update( operations, acceleration, false, - resampler.logger + resampler.logger, + resampler.resampling ) report = (evaluation = e, ) fitresult = (machine=mach2, evaluation=e)