From c570ab9bb5b6c96ffd85533206ed365bd7a3c727 Mon Sep 17 00:00:00 2001 From: "Anthony D. Blaom" Date: Thu, 11 Nov 2021 16:55:33 +1300 Subject: [PATCH 1/4] remove completely the needs_loss/needs_training_losses traits --- src/EarlyStopping.jl | 4 +--- src/api.jl | 7 ------- src/criteria.jl | 21 ++------------------- test/criteria.jl | 28 ---------------------------- 4 files changed, 3 insertions(+), 57 deletions(-) diff --git a/src/EarlyStopping.jl b/src/EarlyStopping.jl index dcbf62a..1e976aa 100644 --- a/src/EarlyStopping.jl +++ b/src/EarlyStopping.jl @@ -23,9 +23,7 @@ export StoppingCriterion, EarlyStopper, done!, reset!, - message, - needs_training_losses, - needs_loss + message include("api.jl") include("criteria.jl") diff --git a/src/api.jl b/src/api.jl index d5db511..36e0875 100644 --- a/src/api.jl +++ b/src/api.jl @@ -12,10 +12,3 @@ done(::StoppingCriterion, state) = false message(criterion::StoppingCriterion, state) = "Stop triggered by "* "$criterion stopping criterion. " - -needs_loss(::Type) = false -needs_training_losses(::Type) = false - -for trait in [:needs_loss, :needs_training_losses] - eval(:($trait(c) = $trait(typeof(c)))) -end diff --git a/src/criteria.jl b/src/criteria.jl index 15bfb68..db83236 100644 --- a/src/criteria.jl +++ b/src/criteria.jl @@ -50,8 +50,6 @@ done(::InvalidValue, state) = state !== nothing && state message(::InvalidValue, state) = "Stopping early as `NaN`, "* "`Inf` or `-Inf` encountered. " -needs_loss(::Type{<:InvalidValue}) = true - ## TIME LIMIT @@ -134,7 +132,6 @@ function done(criterion::GL, state) return gl > criterion.alpha end end -needs_loss(::Type{<:GL}) = true ## PQ @@ -244,9 +241,6 @@ function done(criterion::PQ, state) return PQ > criterion.alpha end -needs_loss(::Type{<:PQ}) = true -needs_training_losses(::Type{<:PQ}) = true - ## PATIENCE @@ -291,8 +285,6 @@ end done(criterion::Patience, state) = state === nothing ? false : state.n_increases == criterion.n -needs_loss(::Type{<:Patience}) = true - ## NUMBER SINCE BEST @@ -332,8 +324,6 @@ end done(criterion::NumberSinceBest, state) = state === nothing ? false : state.number_since_best == criterion.n -needs_loss(::Type{<:NumberSinceBest}) = true - # # NUMBER LIMIT @@ -386,8 +376,6 @@ update(::Threshold, loss, state) = loss done(criterion::Threshold, state) = state === nothing ? false : state < criterion.value -needs_loss(::Type{<:Threshold}) = true - """ Warmup(c::StoppingCriterion, n) @@ -417,9 +405,8 @@ update_training(c::Warmup, loss) = (1, update_training(c.criterion, loss)) # Handle update vs update_training update(c::Warmup, loss, state) = _update(update, c, loss, state) -update_training(c::Warmup, loss, state) = _update(update_training, c, loss, state) -needs_loss(::Type{<:Warmup{C}}) where C = needs_loss(C) -needs_training_losses(::Type{<:Warmup{C}}) where C = needs_training_losses(C) +update_training(c::Warmup, loss, state) = + _update(update_training, c, loss, state) # Dispatch update and update_training here function _update(f::Function, criterion::Warmup, loss, state) @@ -474,7 +461,3 @@ update_training(c::NotANumber, loss, state) = update(c, loss, state) done(::NotANumber, state) = state !== nothing && state message(::NotANumber, state) = "Stopping early as NaN encountered. " - -needs_loss(::Type{<:NotANumber}) = true - - diff --git a/test/criteria.jl b/test/criteria.jl index 2e2fcee..6900b34 100644 --- a/test/criteria.jl +++ b/test/criteria.jl @@ -5,8 +5,6 @@ losses = Float64[10, 8, 9, 10, 11, 12, 12, 13, 14, 15, 16, 17, 16] @testset "Never" begin @test stopping_time(Never(), losses) == 0 - @test !EarlyStopping.needs_loss(Never()) - @test !EarlyStopping.needs_training_losses(Never()) end @testset "InvalidValue" begin @@ -28,9 +26,6 @@ end losses2[n] = NaN @test stopping_time(InvalidValue(), losses2, is_training) == n_stop end - - @test EarlyStopping.needs_loss(InvalidValue()) - @test !EarlyStopping.needs_training_losses(InvalidValue()) end struct SleepyIterator{T} @@ -51,8 +46,6 @@ Base.iterate(iter::SleepyIterator, state) = @test stopping_time(TimeLimit(t=Millisecond(600)), sleepy_losses) == 7 # codecov: @test EarlyStopping.update_training(TimeLimit(), 42.0) <= now() - @test !EarlyStopping.needs_loss(TimeLimit()) - @test !EarlyStopping.needs_training_losses(TimeLimit()) end @testset "GL" begin @@ -74,9 +67,6 @@ end @test stopping_time(GL(alpha=90), losses) == 11 @test stopping_time(GL(alpha=110), losses) == 12 @test stopping_time(GL(alpha=1000), losses) == 0 - - @test EarlyStopping.needs_loss(GL()) - @test !EarlyStopping.needs_training_losses(GL()) end @testset "PQ" begin @@ -133,9 +123,6 @@ end @test stopping_time(PQ(alpha=11.6, k=2), losses2, is_training) == 6 @test stopping_time(PQ(alpha=15.1, k=2), losses2, is_training) == 8 @test stopping_time(PQ(alpha=15.3, k=2), losses2, is_training) == 0 - - @test EarlyStopping.needs_loss(PQ()) - @test EarlyStopping.needs_training_losses(PQ()) end @testset "Patience" begin @@ -146,9 +133,6 @@ end @test stopping_time(Patience(n=3), losses) == 5 @test stopping_time(Patience(n=2), losses) == 4 @test stopping_time(Patience(n=1), losses) == 3 - - @test EarlyStopping.needs_loss(Patience()) - @test !EarlyStopping.needs_training_losses(Patience()) end @testset "NumberSinceBest" begin @@ -163,9 +147,6 @@ end losses2 = Float64[10, 9, 8, 9, 10, 7, 10, 10, 10, 10] @test stopping_time(NumberSinceBest(n=2), losses2) == 5 @test stopping_time(NumberSinceBest(n=3), losses2) == 9 - - @test EarlyStopping.needs_loss(NumberSinceBest()) - @test !EarlyStopping.needs_training_losses(NumberSinceBest()) end @testset "NumberLimit" begin @@ -174,17 +155,11 @@ end for i in 1:length(losses) @test stopping_time(NumberLimit(i), losses) == i end - - @test !EarlyStopping.needs_loss(NumberLimit()) - @test !EarlyStopping.needs_training_losses(NumberLimit()) - end @testset "Threshold" begin @test Threshold().value == 0.0 stopping_time(Threshold(2.5), Float64[12, 32, 3, 2, 5, 7]) == 4 - @test EarlyStopping.needs_loss(Threshold()) - @test !EarlyStopping.needs_training_losses(Threshold()) end @testset "robustness to first loss being a training loss" begin @@ -253,9 +228,6 @@ end losses2 = copy(losses); losses2[n] = NaN @test stopping_time(criterion, losses2, is_training) == sum(!, is_training[1:n]) end - - @test EarlyStopping.needs_loss(criterion) - @test !EarlyStopping.needs_training_losses(criterion) end From f92095fba0547589eaa547da4dec3c01380fefda Mon Sep 17 00:00:00 2001 From: "Anthony D. Blaom" Date: Thu, 11 Nov 2021 17:19:22 +1300 Subject: [PATCH 2/4] update readme --- README.md | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/README.md b/README.md index 45950ab..0915363 100644 --- a/README.md +++ b/README.md @@ -177,7 +177,6 @@ To implement a new stopping criterion, one must: - Overload methods `update` and `done` for the new type. ```julia - struct NewCriteria <: StoppingCriterion # Put relevant fields here end @@ -196,7 +195,6 @@ update(c::NewCriteria, loss, state) = ... # Return true if NewCriteria should stop given `state`. # Always return false if `state === nothing` done(c::NewCriteria, state) = state === nothing ? false : .... - ``` Optionally, one may define the following: @@ -206,20 +204,18 @@ Optionally, one may define the following: `needs_training_losses`. ```julia - # Final message when NewCriteria triggers a stop message(c::NewCriteria, state) = ... # Methods for initializing/updating the state given a training loss update_training(c::NewCriteria, loss, ::Nothing) = ... update_training(c::NewCriteria, loss, state) = ... +``` -# Define `needs_training_losses` iff NewCriteria needs one or more -# training losses per out-of-sample loss to work. Otherwise, this may be -# omitted -needs_training_losses(::Type{<:NewCriteria}) = true +**Wrappers.** If your criterion wraps another criterion (as `Warmup` +does) then the `criterion` must be a field and must store the +criterion being wrapped. -``` ## New Criteria Example From 443d49cbaadd35301bd8d97cc0886c994c98872d Mon Sep 17 00:00:00 2001 From: "Anthony D. Blaom" Date: Fri, 12 Nov 2021 11:01:53 +1300 Subject: [PATCH 3/4] extend defintions of isnan and isinf to Nothing for InvalidValue --- src/criteria.jl | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/criteria.jl b/src/criteria.jl index db83236..28eab33 100644 --- a/src/criteria.jl +++ b/src/criteria.jl @@ -41,9 +41,14 @@ $CUSTOM_ALTERNATIVE_DOC """ struct InvalidValue <: StoppingCriterion end +_isinf(x) = isinf(x) +_isinf(::Nothing) = false +_isnan(x) = isnan(x) +_isnan(::Nothing) = false + # state = `true` when `NaN`, `Inf` or `-Inf` has been encountered update(::InvalidValue, loss, state=false) = - state !== nothing && state || isinf(loss) || isnan(loss) + state !== nothing && state || _isinf(loss) || _isnan(loss) update_training(c::InvalidValue, loss, state) = update(c, loss, state) done(::InvalidValue, state) = state !== nothing && state From 419db79a5f013aff00ebec6b321f6542e9aca4ca Mon Sep 17 00:00:00 2001 From: "Anthony D. Blaom" Date: Fri, 12 Nov 2021 11:02:30 +1300 Subject: [PATCH 4/4] bump 0.3.0 --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 7c04222..3e5a830 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "EarlyStopping" uuid = "792122b4-ca99-40de-a6bc-6742525f08b6" authors = ["Anthony D. Blaom "] -version = "0.2.0" +version = "0.3.0" [deps] Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"