diff --git a/Project.toml b/Project.toml
index b819504..8ae14ba 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,7 +1,7 @@
 name = "EarlyStopping"
 uuid = "792122b4-ca99-40de-a6bc-6742525f08b6"
 authors = ["Anthony D. Blaom <anthony.blaom@gmail.com>"]
-version = "0.1.8"
+version = "0.1.9"
 
 [deps]
 Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
diff --git a/README.md b/README.md
index 6a766e6..da016a4 100644
--- a/README.md
+++ b/README.md
@@ -2,13 +2,13 @@
 
 | Linux | Coverage |
 | :-----------: | :------: |
-| [![Build status](https://github.com/ablaom/EarlyStopping.jl/workflows/CI/badge.svg)](https://github.com/ablaom/EarlyStopping.jl/actions)| [![codecov.io](http://codecov.io/github/ablaom/EarlyStopping.jl/coverage.svg?branch=master)](http://codecov.io/github/ablaom/EarlyStopping.jl?branch=master) |
+| [![Build status](https://github.com/JuliaAI/EarlyStopping.jl/workflows/CI/badge.svg)](https://github.com/JuliaAI/EarlyStopping.jl/actions)| [![codecov.io](http://codecov.io/github/JuliaAI/EarlyStopping.jl/coverage.svg?branch=master)](http://codecov.io/github/JuliaAI/EarlyStopping.jl?branch=master) |
 
 A small package for applying early stopping criteria to loss-generating
 iterative algorithms, with a view to training and optimizing machine
 learning models.
 
-The basis of [IterationControl.jl](https://github.com/ablaom/IterationControl.jl),
+The basis of [IterationControl.jl](https://github.com/JuliaAI/IterationControl.jl),
 a package externally controlling iterative algorithms.
 
 Includes the stopping criteria surveyed in [Prechelt, Lutz
@@ -87,11 +87,12 @@ criterion | description |
 `TimeLimit(t=0.5)` | Stop after `t` hours |
 `NumberLimit(n=100)` | Stop after `n` loss updates (excl. "training losses") |
 `NumberSinceBest(n=6)`| Stop after `n` loss updates (excl. "training losses") |
-`Threshold(value=0.0)`| Stop when `loss < value` | 
+`Threshold(value=0.0)`| Stop when `loss < value` |
 `GL(alpha=2.0)` | Stop after "Generalization Loss" exceeds `alpha` | ``GL_α``
 `PQ(alpha=0.75, k=5)` | Stop after "Progress-modified GL" exceeds `alpha` | ``PQ_α``
 `Patience(n=5)` | Stop after `n` consecutive loss increases | ``UP_s``
 `Disjunction(c...)` | Stop when any of the criteria `c` apply |
+`Warmup(c; n=1)` | Wait for `n` loss updates before checking criteria `c`|
 
 ## Criteria tracking both training and out-of-sample losses
 
diff --git a/src/EarlyStopping.jl b/src/EarlyStopping.jl
index 2c41007..a7598e3 100644
--- a/src/EarlyStopping.jl
+++ b/src/EarlyStopping.jl
@@ -17,6 +17,7 @@ export StoppingCriterion,
     NumberLimit,
     Threshold,
     Disjunction,
+    Warmup,
     criteria,
     stopping_time,
     EarlyStopper,
diff --git a/src/criteria.jl b/src/criteria.jl
index 16e0c33..753b13b 100644
--- a/src/criteria.jl
+++ b/src/criteria.jl
@@ -423,10 +423,60 @@ done(criterion::Threshold, state) = state < criterion.value
 
 needs_loss(::Type{<:Threshold}) = true
 
 
+"""
+    Warmup(c::StoppingCriterion, n)
+Wait for `n` updates before checking stopping criterion `c`
+"""
+struct Warmup{C} <: StoppingCriterion where {C <: StoppingCriterion}
+    criterion::C
+    n::Int
+    function Warmup(criterion::C, n::N) where {C <: StoppingCriterion, N <: Integer}
+        n > 0 || throw(ArgumentError("`n` must be positive. "))
+        new{C}(criterion, Int(n))
+    end
+end
+# Constructors for Warmup
+Warmup() = Warmup(InvalidValue()) # Default for testing
+Warmup(c; n = 1) = Warmup(c, n) # Provide kwargs interface
+
+# Initialize inner state for type-stability, and record first observation
+update(c::Warmup, loss, ::Nothing) = update(c, loss)
+update(criterion::Warmup, loss) = (1, update(criterion.criterion, loss))
+
+# Catch uninitialized state
+update_training(c::Warmup, loss, ::Nothing) = update_training(c, loss)
+update_training(c::Warmup, loss) = (1, update_training(c.criterion, loss))
+
+# Handle update vs update_training
+update(c::Warmup, loss, state) = _update(update, c, loss, state)
+update_training(c::Warmup, loss, state) = _update(update_training, c, loss, state)
+needs_loss(::Type{<:Warmup{C}}) where C = needs_loss(C)
+needs_training_losses(::Type{<:Warmup{C}}) where C = needs_training_losses(C)
+
+# Dispatch update and update_training here
+function _update(f::Function, criterion::Warmup, loss, state)
+    n, inner = state
+    n += 1
+    if n <= criterion.n
+        # Skip inner criterion
+        return n, inner
+    elseif n == criterion.n+1
+        # First step of inner criterion
+        return n, f(criterion.criterion, loss)
+    else
+        # Step inner criterion
+        return n, f(criterion.criterion, loss, inner)
+    end
+end
+function done(criterion::Warmup, state)
+    # Only check if inner criterion is done after n updates
+    return state[1] <= criterion.n ? false : done(criterion.criterion, state[2])
+end
+message(c::Warmup, state) = message(c.criterion, state[2])
 
 
 ## NOT A NUMBER (deprecated)
 
 
 
diff --git a/test/criteria.jl b/test/criteria.jl
index 5b88735..466aebf 100644
--- a/test/criteria.jl
+++ b/test/criteria.jl
@@ -209,6 +209,37 @@ end
 end
 end
 
+@testset "Warmup" begin
+    @test_throws ArgumentError Warmup(Patience(), 0)
+    for n in 1:(length(losses)-1)
+        @test stopping_time(Warmup(NumberLimit(1), n), losses) == n+1
+    end
+
+    # Test message
+    @testset "message" begin
+        stopper = Warmup(Patience(2); n = 2)
+        stopper_ref = Warmup(Patience(2), 2)
+        state, state_ref = nothing, nothing
+        for loss = losses
+            state = EarlyStopping.update(stopper, loss, state)
+            state_ref = EarlyStopping.update(stopper_ref, loss, state_ref)
+            @test message(stopper, state) == message(stopper_ref, state_ref)
+        end
+    end
+
+    @testset "training" begin
+        stopper = Warmup(PQ(), 3)
+        is_training = @show map(x -> x%3 > 0, 1:length(losses))
+
+        # Feed 2 training losses + 1 non-training to criteria with/without
+        stop_time = stopping_time(stopper, losses, is_training)
+        ref_stop_time = stopping_time(stopper.criterion, losses[3:end], is_training)
+
+        # PQ only counts training loss updates
+        @test round(stop_time/3, RoundUp) == ref_stop_time
+    end
+end
+
 # # DEPRECATED
 
 