Skip to content

Commit

Permalink
Allow resampling=InSample() in TunedModel; revamp docstring
Browse files Browse the repository at this point in the history
  • Loading branch information
ablaom committed May 27, 2024
1 parent f8b294b commit 9f01426
Show file tree
Hide file tree
Showing 3 changed files with 164 additions and 113 deletions.
4 changes: 1 addition & 3 deletions src/MLJIteration.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ const CONTROLS = vcat(IterationControl.CONTROLS,
:WithModelDo,
:CycleLearningRate,
:Save])

const CONTROLS_LIST = join(map(c->"$c()", CONTROLS), ", ", " and ")
const TRAINING_CONTROLS = [:Step, ]

# export all control types:
Expand All @@ -42,6 +42,4 @@ include("traits.jl")
include("ic_model.jl")
include("core.jl")



end # module
249 changes: 140 additions & 109 deletions src/constructors.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
const IterationResamplingTypes =
Union{Holdout,Nothing,MLJBase.TrainTestPairs}
Union{Holdout,InSample,Nothing,MLJBase.TrainTestPairs}


## TYPES AND CONSTRUCTOR
Expand Down Expand Up @@ -72,96 +72,119 @@ err_bad_iteration_parameter(p) =
ArgumentError("Model to be iterated does not have :($p) as an iteration parameter. ")

"""
IteratedModel(model=nothing,
controls=$CONTROLS_DEFAULT,
retrain=false,
resampling=Holdout(),
measure=nothing,
weights=nothing,
class_weights=nothing,
operation=predict,
verbosity=1,
check_measure=true,
iteration_parameter=nothing,
cache=true)
Wrap the specified `model <: Supervised` in the specified iteration
`controls`. Training a machine bound to the wrapper iterates a
corresonding machine bound to `model`. Here `model` should support
iteration.
To list all controls, do `MLJIteration.CONTROLS`. Controls are
summarized at
[https://alan-turing-institute.github.io/MLJ.jl/dev/getting_started/](https://alan-turing-institute.github.io/MLJ.jl/dev/controlling_iterative_models/)
but query individual doc-strings for details and advanced options. For
creating your own controls, refer to the documentation just cited.
To make out-of-sample losses available to the controls, the machine
bound to `model` is only trained on part of the data, as iteration
proceeds. See details on training below. Specify `retrain=true`
to ensure the model is retrained on *all* available data, using the
same number of iterations, once controlled iteration has stopped.
Specify `resampling=nothing` if all data is to be used for controlled
iteration, with each out-of-sample loss replaced by the most recent
training loss, assuming this is made available by the model
(`supports_training_losses(model) == true`). Otherwise, `resampling`
must have type `Holdout` (eg, `Holdout(fraction_train=0.8, rng=123)`).
Assuming `retrain=true` or `resampling=nothing`,
`iterated_model` behaves exactly like the original `model` but with
the iteration parameter automatically selected. If
`retrain=false` (default) and `resampling` is not `nothing`, then
`iterated_model` behaves like the original model trained on a subset
of the provided data.
Controlled iteration can be continued with new `fit!` calls (warm
restart) by mutating a control, or by mutating the iteration parameter
of `model`, which is otherwise ignored.
### Training
Given an instance `iterated_model` of `IteratedModel`, calling
`fit!(mach)` on a machine `mach = machine(iterated_model, data...)`
performs the following actions:
- Assuming `resampling !== nothing`, the `data` is split into *train* and
*test* sets, according to the specified `resampling` strategy, which
must have type `Holdout`.
- A clone of the wrapped model, `iterated_model.model`, is bound to
the train data in an internal machine, `train_mach`. If `resampling
=== nothing`, all data is used instead. This machine is the object
to which controls are applied. For example, `Callback(fitted_params
|> print)` will print the value of `fitted_params(train_mach)`.
IteratedModel(model;
controls=...,
resampling=Holdout(),
measure=nothing,
retrain=false,
advanced_options...,
)
Wrap the specified supervised `model` in the specified iteration `controls`. Here `model`
should support iteration, which is true if (`iteration_parameter(model)` is different from
`nothing`.
Available controls: $CONTROLS_LIST.
!!! important
To make out-of-sample losses available to the controls, the wrapped `model` is only
trained on part of the data, as iteration proceeds. The user may want to force
retraining on all data after controlled iteration has finished by specifying
`retrain=true`. See also "Training", and the `retrain` option, under "Extended help"
below.
# Extended help
# Options
- `controls=$CONTROLS_DEFAULT`: Controls are summarized at
[https://JuliaAI.github.io/MLJ.jl/dev/getting_started/](https://JuliaAI.github.io/MLJ.jl/dev/controlling_iterative_models/)
but query individual doc-strings for details and advanced options. For creating your own
controls, refer to the documentation just cited.
- `resampling=Holdout(fraction_train=0.7)`: The default resampling holds back 30% of data
for computing an out-of-sample estimate of performance (the "loss") for controls such
as `WithLossDo` and stopping criterion; specify `resampling=nothing` if all data is to
be used for controlled iteration, with each out-of-sample loss replaced by the most
recent training loss, assuming this is made available by the model
(`supports_training_losses(model) == true`). If the model does not provide report a
training loss, you can use `resampling=InSample()` instead, with an additional
performance cost. Otherwise, `resampling` must have type `Holdout` or be a vector with
one element of the form `(train_indices, test_indices)`.
- `measure=nothing`: StatisticalMeasures.jl compatible measure for estimating model
performance (the "loss", but the orientation is immaterial - i.e., this could be a
score). Inferred by default. Ignored if `resampling=nothing`.
- `retrain=false`: If `retrain=true` or `resampling=nothing`, `iterated_model` behaves
exactly like the original `model` but with the iteration parameter automatically
selected ("learned"). That is, the model is retrained on *all* available data, using the
same number of iterations, once controlled iteration has stopped. This is typically
desired if wrapping the iterated model further, or when inserting in a pipeline or other
composite model. If `retrain=false` (default) and `resampling isa Holdout`, then
`iterated_model` behaves like the original model trained on a subset of the provided
data.
- `weights=nothing`: per-observation weights to be passed to `measure` where supported; if
unspecified, these are understood to be uniform.
- `class_weights=nothing`: class-weights to be passed to `measure` where supported; if
unspecified, these are understood to be uniform.
- `operation=nothing`: Operation, such as `predict` or `predict_mode`, for computing
target values, or proxy target values, for consumption by `measure`; automatically
inferred by default.
- `check_measure=true`: Specify `false` to override checks on `measure` for compatibility
with the training data.
- `iteration_parameter=nothing`: A symbol, such as `:epochs`, naming the iteration
parameter of `model`; inferred by default. Note that the actual value of the iteration
parameter in the supplied `model` is ignored; only the value of an internal clone is
mutated during training the wrapped model.
- `cache=true`: Whether or not model-specific representations of data are cached in
between iteration parameter increments; specify `cache=false` to prioritize memory over
speed.
# Training
Training an instance `iterated_model` of `IteratedModel` on some `data` (by binding to a
machine and calling `fit!`, for example) performs the following actions:
- Assuming `resampling !== nothing`, the `data` is split into *train* and *test* sets,
according to the specified `resampling` strategy.
- A clone of the wrapped model, `model` is bound to the train data in an internal machine,
`train_mach`. If `resampling === nothing`, all data is used instead. This machine is the
object to which controls are applied. For example, `Callback(fitted_params |> print)`
will print the value of `fitted_params(train_mach)`.
- The iteration parameter of the clone is set to `0`.
- The specified `controls` are repeatedly applied to `train_mach` in
sequence, until one of the controls triggers a stop. Loss-based
controls (eg, `Patience()`, `GL()`, `Threshold(0.001)`) use an
out-of-sample loss, obtained by applying `measure` to predictions
and the test target values. (Specifically, these predictions are
those returned by `operation(train_mach)`.) If `resampling ===
nothing` then the most recent training loss is used instead. Some
controls require *both* out-of-sample and training losses (eg,
`PQ()`).
- The specified `controls` are repeatedly applied to `train_mach` in sequence, until one
of the controls triggers a stop. Loss-based controls (eg, `Patience()`, `GL()`,
`Threshold(0.001)`) use an out-of-sample loss, obtained by applying `measure` to
predictions and the test target values. (Specifically, these predictions are those
returned by `operation(train_mach)`.) If `resampling === nothing` then the most recent
training loss is used instead. Some controls require *both* out-of-sample and training
losses (eg, `PQ()`).
- Once a stop has been triggered, a clone of `model` is bound to all
`data` in a machine called `mach_production` below, unless
`retrain == false` or `resampling === nothing`, in which case
`mach_production` coincides with `train_mach`.
- Once a stop has been triggered, a clone of `model` is bound to all `data` in a machine
called `mach_production` below, unless `retrain == false` (true by default) or
`resampling === nothing`, in which case `mach_production` coincides with `train_mach`.
### Prediction
# Prediction
Calling `predict(mach, Xnew)` returns `predict(mach_production,
Xnew)`. Similar similar statements hold for `predict_mean`,
`predict_mode`, `predict_median`.
Calling `predict(mach, Xnew)` in the example above returns `predict(mach_production,
Xnew)`. Similar similar statements hold for `predict_mean`, `predict_mode`,
`predict_median`.
### Controls
# Controls that mutate parameters
A control is permitted to mutate the fields (hyper-parameters) of
`train_mach.model` (the clone of `model`). For example, to mutate a
Expand All @@ -174,11 +197,25 @@ in that parameter, this will trigger retraining of `train_mach` from
scratch, with a different training outcome, which is not recommended.
### Warm restarts
# Warm restarts
If `iterated_model` is mutated and `fit!(mach)` is called again, then
a warm restart is attempted if the only parameters to change are
`model` or `controls` or both.
In the following example, the second `fit!` call will not restart training of the internal
`train_mach`, assuming `model` supports warm restarts:
```julia
iterated_model = IteratedModel(
model,
controls = [Step(1), NumberLimit(100)],
)
mach = machine(iterated_model, X, y)
fit!(mach) # train for 100 iterations
iterated_model.controls = [Step(1), NumberLimit(50)],
fit!(mach) # train for an *extra* 50 iterations
```
More generally, if `iterated_model` is mutated and `fit!(mach)` is called again, then a
warm restart is attempted if the only parameters to change are `model` or `controls` or
both.
Specifically, `train_mach.model` is mutated to match the current value
of `iterated_model.model` and the iteration parameter of the latter is
Expand All @@ -195,7 +232,7 @@ function IteratedModel(args...;
measure=measures,
weights=nothing,
class_weights=nothing,
operation=predict,
operation=nothing,
retrain=false,
check_measure=true,
iteration_parameter=nothing,
Expand All @@ -211,30 +248,24 @@ function IteratedModel(args...;
atom = model
end

options = (
atom,
controls,
resampling,
measure,
weights,
class_weights,
operation,
retrain,
check_measure,
iteration_parameter,
cache,
)

if atom isa Deterministic
iterated_model = DeterministicIteratedModel(atom,
controls,
resampling,
measure,
weights,
class_weights,
operation,
retrain,
check_measure,
iteration_parameter,
cache)
iterated_model = DeterministicIteratedModel(options...)
elseif atom isa Probabilistic
iterated_model = ProbabilisticIteratedModel(atom,
controls,
resampling,
measure,
weights,
class_weights,
operation,
retrain,
check_measure,
iteration_parameter,
cache)
iterated_model = ProbabilisticIteratedModel(options...)
else
throw(ERR_NOT_SUPERVISED)
end
Expand Down
24 changes: 23 additions & 1 deletion test/core.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ using IterationControl
using MLJBase
using MLJModelInterface
using StatisticalMeasures
using StableRNGs
using ..DummyModel

X, y = make_dummy(N=20)
Expand All @@ -26,6 +27,7 @@ model = DummyIterativeModel(n=0)
end
IterationControl.loss(mach::Machine{<:DummyIterativeModel}) =
last(training_losses(mach))

IterationControl.train!(mach, controls..., verbosity=0)
losses1 = report(mach).training_losses
yhat1 = predict(mach, X)
Expand Down Expand Up @@ -104,6 +106,26 @@ model = DummyIterativeModel(n=0)
@test report(mach).n_iterations == 5
end

@testset "resampling = InSample()" begin
model = DummyIterativeModel(n=0, rng=StableRNG(123))
controls=[Step(2), NumberLimit(10)]

# using `resampling=nothing`:
imodel = IteratedModel(model=model, controls=controls, resampling=nothing)
mach = machine(imodel, X, y)
fit!(mach, verbosity=0)
y1 = predict(mach, rows=1:10)

# using `resampling=InSample()`:
model = DummyIterativeModel(n=0, rng=StableRNG(123))
imodel = IteratedModel(model=model, controls=controls, resampling=InSample())
mach = machine(imodel, X, y)
fit!(mach, verbosity=0)
y2 = predict(mach, rows=1:10)

@test y1 == y2
end

@testset "integration: resampling=Holdout()" begin

controls=[Step(2), Patience(4), TimeLimit(0.001)]
Expand Down Expand Up @@ -269,7 +291,7 @@ function MLJBase.restore(::EphemeralRegressor, serialized_fitresult)
end

@testset "save and restore" begin
#https://github.com/alan-turing-institute/MLJ.jl/issues/1099
#https://github.com/JuliaAI/MLJ.jl/issues/1099
X, y = (; x = rand(10)), fill(42.0, 3)
controls = [Step(1), NumberLimit(2)]
imodel = IteratedModel(
Expand Down

0 comments on commit 9f01426

Please sign in to comment.