From 9f01426f21548a83f53ba597f143119a08fd3d5e Mon Sep 17 00:00:00 2001 From: "Anthony D. Blaom" Date: Mon, 27 May 2024 13:38:35 +1200 Subject: [PATCH] Allow resampling=InSample() in TunedModel; revamp docstring --- src/MLJIteration.jl | 4 +- src/constructors.jl | 249 +++++++++++++++++++++++++------------------- test/core.jl | 24 ++++- 3 files changed, 164 insertions(+), 113 deletions(-) diff --git a/src/MLJIteration.jl b/src/MLJIteration.jl index c2dfc34..a4e6d3a 100644 --- a/src/MLJIteration.jl +++ b/src/MLJIteration.jl @@ -17,7 +17,7 @@ const CONTROLS = vcat(IterationControl.CONTROLS, :WithModelDo, :CycleLearningRate, :Save]) - +const CONTROLS_LIST = join(map(c->"$c()", CONTROLS), ", ", " and ") const TRAINING_CONTROLS = [:Step, ] # export all control types: @@ -42,6 +42,4 @@ include("traits.jl") include("ic_model.jl") include("core.jl") - - end # module diff --git a/src/constructors.jl b/src/constructors.jl index 52448c2..14b0676 100644 --- a/src/constructors.jl +++ b/src/constructors.jl @@ -1,5 +1,5 @@ const IterationResamplingTypes = - Union{Holdout,Nothing,MLJBase.TrainTestPairs} + Union{Holdout,InSample,Nothing,MLJBase.TrainTestPairs} ## TYPES AND CONSTRUCTOR @@ -72,96 +72,119 @@ err_bad_iteration_parameter(p) = ArgumentError("Model to be iterated does not have :($p) as an iteration parameter. ") """ - IteratedModel(model=nothing, - controls=$CONTROLS_DEFAULT, - retrain=false, - resampling=Holdout(), - measure=nothing, - weights=nothing, - class_weights=nothing, - operation=predict, - verbosity=1, - check_measure=true, - iteration_parameter=nothing, - cache=true) - -Wrap the specified `model <: Supervised` in the specified iteration -`controls`. Training a machine bound to the wrapper iterates a -corresonding machine bound to `model`. Here `model` should support -iteration. - -To list all controls, do `MLJIteration.CONTROLS`. Controls are -summarized at -[https://alan-turing-institute.github.io/MLJ.jl/dev/getting_started/](https://alan-turing-institute.github.io/MLJ.jl/dev/controlling_iterative_models/) -but query individual doc-strings for details and advanced options. For -creating your own controls, refer to the documentation just cited. - -To make out-of-sample losses available to the controls, the machine -bound to `model` is only trained on part of the data, as iteration -proceeds. See details on training below. Specify `retrain=true` -to ensure the model is retrained on *all* available data, using the -same number of iterations, once controlled iteration has stopped. - -Specify `resampling=nothing` if all data is to be used for controlled -iteration, with each out-of-sample loss replaced by the most recent -training loss, assuming this is made available by the model -(`supports_training_losses(model) == true`). Otherwise, `resampling` -must have type `Holdout` (eg, `Holdout(fraction_train=0.8, rng=123)`). - -Assuming `retrain=true` or `resampling=nothing`, -`iterated_model` behaves exactly like the original `model` but with -the iteration parameter automatically selected. If -`retrain=false` (default) and `resampling` is not `nothing`, then -`iterated_model` behaves like the original model trained on a subset -of the provided data. - -Controlled iteration can be continued with new `fit!` calls (warm -restart) by mutating a control, or by mutating the iteration parameter -of `model`, which is otherwise ignored. 
-
-
-### Training
-
-Given an instance `iterated_model` of `IteratedModel`, calling
-`fit!(mach)` on a machine `mach = machine(iterated_model, data...)`
-performs the following actions:
-
-- Assuming `resampling !== nothing`, the `data` is split into *train* and
-  *test* sets, according to the specified `resampling` strategy, which
-  must have type `Holdout`.
-
-- A clone of the wrapped model, `iterated_model.model`, is bound to
-  the train data in an internal machine, `train_mach`. If `resampling
-  === nothing`, all data is used instead. This machine is the object
-  to which controls are applied. For example, `Callback(fitted_params
-  |> print)` will print the value of `fitted_params(train_mach)`.
+    IteratedModel(model;
+        controls=...,
+        resampling=Holdout(),
+        measure=nothing,
+        retrain=false,
+        advanced_options...,
+    )
+
+Wrap the specified supervised `model` in the specified iteration `controls`. Here `model`
+should support iteration, which is true if `iteration_parameter(model)` is different from
+`nothing`.
+
+Available controls: $CONTROLS_LIST.
+
+!!! important
+
+    To make out-of-sample losses available to the controls, the wrapped `model` is only
+    trained on part of the data, as iteration proceeds. The user may want to force
+    retraining on all data after controlled iteration has finished by specifying
+    `retrain=true`. See also "Training", and the `retrain` option, under "Extended help"
+    below.
+
+# Extended help
+
+# Options
+
+- `controls=$CONTROLS_DEFAULT`: Controls are summarized at
+  [https://JuliaAI.github.io/MLJ.jl/dev/controlling_iterative_models/](https://JuliaAI.github.io/MLJ.jl/dev/controlling_iterative_models/)
+  but query individual doc-strings for details and advanced options. For creating your
+  own controls, refer to the documentation just cited.
+
+- `resampling=Holdout(fraction_train=0.7)`: The default resampling holds back 30% of data
+  for computing an out-of-sample estimate of performance (the "loss") for loss-based
+  controls, such as `WithLossDo`, and stopping criteria; specify `resampling=nothing` if
+  all data is to be used for controlled iteration, with each out-of-sample loss replaced
+  by the most recent training loss, assuming this is made available by the model
+  (`supports_training_losses(model) == true`). If the model does not report a training
+  loss, you can use `resampling=InSample()` instead, at additional performance cost; see
+  the sketch under "Example" below. Otherwise, `resampling` must have type `Holdout` or
+  be a vector with one element of the form `(train_indices, test_indices)`.
+
+- `measure=nothing`: StatisticalMeasures.jl compatible measure for estimating model
+  performance (the "loss", but the orientation is immaterial, i.e., this could be a
+  score). Inferred by default. Ignored if `resampling=nothing`.
+
+- `retrain=false`: If `retrain=true` or `resampling=nothing`, `iterated_model` behaves
+  exactly like the original `model` but with the iteration parameter automatically
+  selected ("learned"). That is, the model is retrained on *all* available data, using
+  the same number of iterations, once controlled iteration has stopped. This is typically
+  desired when wrapping the iterated model further, or when inserting it in a pipeline or
+  other composite model. If `retrain=false` (default) and `resampling isa Holdout`, then
+  `iterated_model` behaves like the original model trained on a subset of the provided
+  data.
+
+- `weights=nothing`: per-observation weights to be passed to `measure` where supported; if
+  unspecified, these are understood to be uniform.
+
+- `class_weights=nothing`: class-weights to be passed to `measure` where supported; if
+  unspecified, these are understood to be uniform.
+
+- `operation=nothing`: Operation, such as `predict` or `predict_mode`, for computing
+  target values, or proxy target values, for consumption by `measure`; automatically
+  inferred by default.
+
+- `check_measure=true`: Specify `false` to override checks on `measure` for compatibility
+  with the training data.
+
+- `iteration_parameter=nothing`: A symbol, such as `:epochs`, naming the iteration
+  parameter of `model`; inferred by default. Note that the actual value of the iteration
+  parameter in the supplied `model` is ignored; only the value of an internal clone is
+  mutated during training of the wrapped model.
+
+- `cache=true`: Whether or not model-specific representations of data are cached in
+  between iteration parameter increments; specify `cache=false` to prioritize memory over
+  speed.
+
+
+# Training
+
+Training an instance `iterated_model` of `IteratedModel` on some `data` (by binding to a
+machine and calling `fit!`, for example) performs the following actions:
+
+- Assuming `resampling !== nothing`, the `data` is split into *train* and *test* sets,
+  according to the specified `resampling` strategy.
+
+- A clone of the wrapped model, `model`, is bound to the train data in an internal
+  machine, `train_mach`. If `resampling === nothing`, all data is used instead. This
+  machine is the object to which controls are applied. For example,
+  `Callback(fitted_params |> print)` will print the value of `fitted_params(train_mach)`.
 
 - The iteration parameter of the clone is set to `0`.
 
-- The specified `controls` are repeatedly applied to `train_mach` in
-  sequence, until one of the controls triggers a stop. Loss-based
-  controls (eg, `Patience()`, `GL()`, `Threshold(0.001)`) use an
-  out-of-sample loss, obtained by applying `measure` to predictions
-  and the test target values. (Specifically, these predictions are
-  those returned by `operation(train_mach)`.) If `resampling ===
-  nothing` then the most recent training loss is used instead. Some
-  controls require *both* out-of-sample and training losses (eg,
-  `PQ()`).
+- The specified `controls` are repeatedly applied to `train_mach` in sequence, until one
+  of the controls triggers a stop. Loss-based controls (eg, `Patience()`, `GL()`,
+  `Threshold(0.001)`) use an out-of-sample loss, obtained by applying `measure` to
+  predictions and the test target values. (Specifically, these predictions are those
+  returned by `operation(train_mach)`.) If `resampling === nothing` then the most recent
+  training loss is used instead. Some controls require *both* out-of-sample and training
+  losses (eg, `PQ()`).
 
-- Once a stop has been triggered, a clone of `model` is bound to all
-  `data` in a machine called `mach_production` below, unless
-  `retrain == false` or `resampling === nothing`, in which case
-  `mach_production` coincides with `train_mach`.
+- Once a stop has been triggered, a clone of `model` is bound to all `data` in a machine
+  called `mach_production` below, unless `retrain == false` (the default) or
+  `resampling === nothing`, in which case `mach_production` coincides with `train_mach`.
 
-### Prediction
+# Prediction
 
-Calling `predict(mach, Xnew)` returns `predict(mach_production,
-Xnew)`. Similar similar statements hold for `predict_mean`,
-`predict_mode`, `predict_median`.
+With `mach` the machine trained on `data` as above, calling `predict(mach, Xnew)` returns
+`predict(mach_production, Xnew)`. Similar statements hold for `predict_mean`,
+`predict_mode`, `predict_median`.
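+
+# Example
+
+The following is a minimal sketch only. It assumes the EvoTrees.jl package is installed;
+any MLJ model with an iteration parameter could be substituted, and the controls and
+measure shown are purely illustrative:
+
+```julia
+using MLJ
+
+EvoTreeRegressor = @load EvoTreeRegressor pkg=EvoTrees
+X, y = make_regression(100, 3)  # synthetic data
+
+iterated_model = IteratedModel(
+    EvoTreeRegressor(),
+    resampling=InSample(),  # losses seen by the controls are training-data losses
+    measure=l2,
+    controls=[Step(2), Patience(3), NumberLimit(300)],
+)
+
+mach = machine(iterated_model, X, y)
+fit!(mach)
+predict(mach, rows=1:4)
+```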
 
-### Controls
+# Controls that mutate parameters
 
 A control is permitted to mutate the fields (hyper-parameters) of
 `train_mach.model` (the clone of `model`). For example, to mutate a
@@ -174,11 +197,25 @@ in that parameter, this will trigger retraining
 of `train_mach` from scratch, with a different training outcome,
 which is not recommended.
 
-### Warm restarts
+# Warm restarts
 
-If `iterated_model` is mutated and `fit!(mach)` is called again, then
-a warm restart is attempted if the only parameters to change are
-`model` or `controls` or both.
+In the following example, the second `fit!` call will not restart training of the internal
+`train_mach`, assuming `model` supports warm restarts:
+
+```julia
+iterated_model = IteratedModel(
+    model,
+    controls = [Step(1), NumberLimit(100)],
+)
+mach = machine(iterated_model, X, y)
+fit!(mach) # train for 100 iterations
+iterated_model.controls = [Step(1), NumberLimit(50)]
+fit!(mach) # train for an *extra* 50 iterations
+```
+
+More generally, if `iterated_model` is mutated and `fit!(mach)` is called again, then a
+warm restart is attempted if the only parameters to change are `model` or `controls` or
+both.
 
 Specifically, `train_mach.model` is mutated to match the current value of
 `iterated_model.model` and the iteration parameter of the latter is
@@ -195,7 +232,7 @@ function IteratedModel(args...;
                        measure=measures,
                        weights=nothing,
                        class_weights=nothing,
-                       operation=predict,
+                       operation=nothing,
                        retrain=false,
                        check_measure=true,
                        iteration_parameter=nothing,
@@ -211,30 +248,24 @@ function IteratedModel(args...;
         atom = model
     end
 
+    # arguments are splatted into the model constructors below, so order matters:
+    options = (
+        atom,
+        controls,
+        resampling,
+        measure,
+        weights,
+        class_weights,
+        operation,
+        retrain,
+        check_measure,
+        iteration_parameter,
+        cache,
+    )
+
     if atom isa Deterministic
-        iterated_model = DeterministicIteratedModel(atom,
-                                                    controls,
-                                                    resampling,
-                                                    measure,
-                                                    weights,
-                                                    class_weights,
-                                                    operation,
-                                                    retrain,
-                                                    check_measure,
-                                                    iteration_parameter,
-                                                    cache)
+        iterated_model = DeterministicIteratedModel(options...)
     elseif atom isa Probabilistic
-        iterated_model = ProbabilisticIteratedModel(atom,
-                                                    controls,
-                                                    resampling,
-                                                    measure,
-                                                    weights,
-                                                    class_weights,
-                                                    operation,
-                                                    retrain,
-                                                    check_measure,
-                                                    iteration_parameter,
-                                                    cache)
+        iterated_model = ProbabilisticIteratedModel(options...)
else throw(ERR_NOT_SUPERVISED) end diff --git a/test/core.jl b/test/core.jl index 2119da4..1f54945 100644 --- a/test/core.jl +++ b/test/core.jl @@ -6,6 +6,7 @@ using IterationControl using MLJBase using MLJModelInterface using StatisticalMeasures +using StableRNGs using ..DummyModel X, y = make_dummy(N=20) @@ -26,6 +27,7 @@ model = DummyIterativeModel(n=0) end IterationControl.loss(mach::Machine{<:DummyIterativeModel}) = last(training_losses(mach)) + IterationControl.train!(mach, controls..., verbosity=0) losses1 = report(mach).training_losses yhat1 = predict(mach, X) @@ -104,6 +106,26 @@ model = DummyIterativeModel(n=0) @test report(mach).n_iterations == 5 end +@testset "resampling = InSample()" begin + model = DummyIterativeModel(n=0, rng=StableRNG(123)) + controls=[Step(2), NumberLimit(10)] + + # using `resampling=nothing`: + imodel = IteratedModel(model=model, controls=controls, resampling=nothing) + mach = machine(imodel, X, y) + fit!(mach, verbosity=0) + y1 = predict(mach, rows=1:10) + + # using `resampling=InSample()`: + model = DummyIterativeModel(n=0, rng=StableRNG(123)) + imodel = IteratedModel(model=model, controls=controls, resampling=InSample()) + mach = machine(imodel, X, y) + fit!(mach, verbosity=0) + y2 = predict(mach, rows=1:10) + + @test y1 == y2 +end + @testset "integration: resampling=Holdout()" begin controls=[Step(2), Patience(4), TimeLimit(0.001)] @@ -269,7 +291,7 @@ function MLJBase.restore(::EphemeralRegressor, serialized_fitresult) end @testset "save and restore" begin - #https://github.com/alan-turing-institute/MLJ.jl/issues/1099 + #https://github.com/JuliaAI/MLJ.jl/issues/1099 X, y = (; x = rand(10)), fill(42.0, 3) controls = [Step(1), NumberLimit(2)] imodel = IteratedModel(