diff --git a/Project.toml b/Project.toml index 8ae14ba..7c04222 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "EarlyStopping" uuid = "792122b4-ca99-40de-a6bc-6742525f08b6" authors = ["Anthony D. Blaom "] -version = "0.1.9" +version = "0.2.0" [deps] Dates = "ade2ca70-3891-5945-98fb-dc099432e06a" diff --git a/README.md b/README.md index da016a4..45950ab 100644 --- a/README.md +++ b/README.md @@ -101,7 +101,9 @@ For criteria tracking both an "out-of-sample" loss and a "training" loss (eg, stopping criterion of type `PQ`), specify `training=true` if the update is for training, as in - done!(stopper, 0.123, training=true) +``` julia +done!(stopper, 0.123, training=true) +``` In these cases, the out-of-sample update must always come after the corresponding training update. Multiple training updates may precede @@ -171,17 +173,58 @@ stopping_time(PQ(), To implement a new stopping criterion, one must: - Define a new `struct` for the criterion, which must subtype -`StoppingCriterion`. - + `StoppingCriterion`. - Overload methods `update` and `done` for the new type. -- Optionally overload methods `message`. +```julia + +struct NewCriteria <: StoppingCriterion + # Put relevant fields here +end + +# Provide a default constructor with all key-word arguments +NewCriteria(; kwargs...) = ... + +# Return the initial state of the NewCriteria after +# receiving an out-of-sample loss +update(c::NewCriteria, loss, ::Nothing) = ... -- Optionally overload `update_training` and the trait +# Return an updated state for NewCriteria given a `loss` +# and the current `state` +update(c::NewCriteria, loss, state) = ... + +# Return true if NewCriteria should stop given `state`. +# Always return false if `state === nothing` +done(c::NewCriteria, state) = state === nothing ? false : .... + +``` + +Optionally, one may define the following: + +- Overload the final message with `message`. 
+- Handle training losses by overloading `update_training` and the trait `needs_training_losses`. +```julia + +# Final message when NewCriteria triggers a stop +message(c::NewCriteria, state) = ... + +# Methods for initializing/updating the state given a training loss +update_training(c::NewCriteria, loss, ::Nothing) = ... +update_training(c::NewCriteria, loss, state) = ... + +# Define `needs_training_losses` iff NewCriteria needs one or more +# training losses per out-of-sample loss to work. Otherwise, this may be +# omitted +needs_training_losses(::Type{<:NewCriteria}) = true + +``` + +## New Criteria Example + We demonstrate this with a simplified version of the -[code](/src/criteria.jl) for `Patience`: +[code](./src/criteria.jl) for `Patience`: ### Defining the new type @@ -189,7 +232,7 @@ We demonstrate this with a simplified version of the ```julia using EarlyStopping -mutable struct Patience <: StoppingCriterion +struct Patience <: StoppingCriterion n::Int end Patience(; n=5) = Patience(n) @@ -199,15 +242,20 @@ Patience(; n=5) = Patience(n) All information to be "remembered" must passed around in an object called `state` below, which is the return value of `update` (and -`update_training`). The `update` function has two methods - one for -initialization, without a `state` argument, and one for all subsequent -loss updates, which requires the `state` returned by the preceding -`update` (or `update_training`) call: +`update_training`). The `update` function has two methods: + +- Initialization: `update(c::NewCriteria, loss, ::Nothing)` +- Subsequent Loss Updates: `update(c::NewCriteria, loss, state)` + +Where `state` is the return of the previous call to `update` or `update_training`. +Notice that `state === nothing` indicates an uninitialized criterion. 
```julia import EarlyStopping: update, done -update(criterion::Patience, loss) = (loss=loss, n_increases=0) # state +function update(criterion::Patience, loss, ::Nothing) + return (loss=loss, n_increases=0) # state +end function update(criterion::Patience, loss, state) old_loss, n = state @@ -220,15 +268,18 @@ function update(criterion::Patience, loss, state) end ``` -The `done` method returns `true` or `false` depending on the `state`: +The `done` method returns `true` or `false` depending on the `state`, but +always returns `false` if `state === nothing`. ```julia -done(criterion::Patience, state) = state.n_increases == criterion.n +done(criterion::Patience, state) = + state === nothing ? false : state.n_increases == criterion.n + ``` ### Optional methods -The final message of an `EarlyStopper` is generated by a `message` +The final message of an `EarlyStopper` is generated using a `message` method for `StoppingCriterion`. Here is the fallback (which does not use `state`): @@ -239,7 +290,7 @@ EarlyStopping.message(criteria::StoppingCriterion, state) The optional `update_training` methods (two for each criterion) have the same signature as the `update` methods above. Refer to the `PQ` -[code](/src/criteria.jl) for an example. +[code](./src/criteria.jl) for an example. If a stopping criterion requires one or more `update_training` calls per `update` call to work, you should overload the trait @@ -249,3 +300,13 @@ the source code: ```julia EarlyStopping.needs_training_losses(::Type{<:PQ}) = true ``` + +### Unit Testing + +The following are provided to facilitate testing of new criteria: + +- `stopping_time`: returns the stopping time for an iterator + `losses` using `criterion`. +- [`@test_criteria NewCriteria()`](./test/smoke.jl): Runs a suite of unit tests + against the provided `StoppingCriteria`. This macro is only part of the + test suite and is not part of the API. 
diff --git a/src/EarlyStopping.jl b/src/EarlyStopping.jl index a7598e3..dcbf62a 100644 --- a/src/EarlyStopping.jl +++ b/src/EarlyStopping.jl @@ -22,6 +22,7 @@ export StoppingCriterion, stopping_time, EarlyStopper, done!, + reset!, message, needs_training_losses, needs_loss @@ -29,7 +30,7 @@ export StoppingCriterion, include("api.jl") include("criteria.jl") include("disjunction.jl") -include("stopping_time.jl") include("object_oriented_api.jl") +include("stopping_time.jl") end # module diff --git a/src/api.jl b/src/api.jl index 4294557..d5db511 100644 --- a/src/api.jl +++ b/src/api.jl @@ -4,16 +4,8 @@ abstract type StoppingCriterion end ## FALL BACK METHODS - -# initialization call is either: -update(::StoppingCriterion, loss) = nothing # state - -# ... or: -update_training(::StoppingCriterion, loss) = nothing # state - -# subsequent updating: -update(::StoppingCriterion, loss, state) = state -update_training(::StoppingCriterion, loss, state) = state +update(::StoppingCriterion, loss, state=nothing) = state +update_training(::StoppingCriterion, loss, state=nothing) = state # returns whether it's time to stop: done(::StoppingCriterion, state) = false diff --git a/src/criteria.jl b/src/criteria.jl index 753b13b..15bfb68 100644 --- a/src/criteria.jl +++ b/src/criteria.jl @@ -42,13 +42,10 @@ $CUSTOM_ALTERNATIVE_DOC struct InvalidValue <: StoppingCriterion end # state = `true` when `NaN`, `Inf` or `-Inf` has been encountered - update(::InvalidValue, loss, state=false) = - state || isinf(loss) || isnan(loss) -update_training(::InvalidValue, loss, state=false) = - state || isinf(loss) || isnan(loss) - -done(::InvalidValue, state) = state + state !== nothing && state || isinf(loss) || isnan(loss) +update_training(c::InvalidValue, loss, state) = update(c, loss, state) +done(::InvalidValue, state) = state !== nothing && state message(::InvalidValue, state) = "Stopping early as `NaN`, "* "`Inf` or `-Inf` encountered. 
" @@ -86,13 +83,10 @@ TimeLimit(t) = TimeLimit(round(Int, 3_600_000*t) |> Millisecond) TimeLimit(; t =Minute(30)) = TimeLimit(t) # state = time at initialization - -update(::TimeLimit, loss) = now() -update_training(::TimeLimit, loss) = now() -update(::TimeLimit, loss, state) = state -done(criterion::TimeLimit, state) = begin - criterion.t < now() - state -end +update(::TimeLimit, loss, ::Nothing) = now() +update_training(::TimeLimit, loss, ::Nothing) = now() +done(criterion::TimeLimit, state) = + state === nothing ? false : criterion.t < now() - state ## GL @@ -130,14 +124,16 @@ struct GL <: StoppingCriterion end end GL(; alpha=2.0) = GL(alpha) - -update(::GL, loss) = (loss=loss, min_loss=loss) +update(::GL, loss, ::Nothing) = (loss=loss, min_loss=loss) update(::GL, loss, state) = (loss=loss, min_loss=min(loss, state.min_loss)) -# in case first loss consumed was a training loss: -update(criterion::GL, loss, ::Nothing) = update(criterion, loss) -done(criterion::GL, state) = - generalization_loss(state.loss, state.min_loss) > criterion.alpha - +function done(criterion::GL, state) + if state === nothing + return false + else + gl = generalization_loss(state.loss, state.min_loss) + return gl > criterion.alpha + end +end needs_loss(::Type{<:GL}) = true @@ -166,38 +162,35 @@ _min(x, y) = min(x, y) A stopping criterion for training iterative supervised learners. A stop is triggered when Prechelt's progress-modified generalization -loss exceeds the threshold `alpha`, or if the training progress drops -below `tol`. Here `k` is the maximum number of training (in-sample) -losses to be used to estimate the training progress. - -**Context and explanation of terminology.** The progress-modified loss -is defined in the following scenario: Estimates, ``E_1, E_2, ..., -E_t``, of the out-of-sample loss of an iterative supervised learner -are being computed, but not necessarily at every iteration. 
However, -training losses for every iteration *are* being made available -(usually as a by-product of training) which can be used to quantify -recent training progress, as follows. - -Fix a time ``j``, corresponding to some out-of-sample loss ``E_j``, -and let ``F_1`` be the corresponding training loss, ``F_2`` the -training loss in the previous interation of the model, ``F_3``, the -training loss two iterations previously, and so on. Let ``K`` denote -the number of model iterations since the last out-of-sample loss -``E_{j-1}`` was computed, or `k`, whichever is the smaller. Then -the *training progress* at time ``j`` is defined by +loss exceeds the threshold ``PQ_T > alpha``, or if the training progress drops +below ``P_j ≤ tol``. Here `k` is the number of training (in-sample) losses used to +estimate the training progress. + +## Context and explanation of terminology + +The *training progress* at time ``j`` is defined by `` P_j = 1000 |M - m|/|m| `` -where ``M`` is the mean of the training losses ``F_1, F_2, \\ldots , -F_K`` and ``m`` the minimum value of those losses. +where ``M`` is the mean of the last `k` training losses ``F_1, F_2, …, F_k`` +and ``m`` is the minimum value of those losses. -The *progress-modified generalization loss* at time ``t`` is given by +The *progress-modified generalization loss* at time ``t`` is then given by `` PQ_t = GL_t / P_t`` where ``GL_t`` is the generalization loss at time ``t``; see [`GL`](@ref). +PQ will stop when the following are true: + +1) At least `k` training samples have been collected via + `done!(c::PQ, loss; training = true)` or `update_training(c::PQ, loss, state)` +2) The last update was an out-of-sample update. + (`done!(::PQ, loss; training=true)` is always false) +3) The progress-modified generalization loss exceeds the threshold + ``PQ_t > alpha`` **OR** the training progress stalls ``P_j ≤ tol``. + Reference: $PRECHELT_REF. 
""" @@ -226,41 +219,24 @@ struct PQState{T} min_loss::Union{Nothing,T} end -function update_training(criterion::PQ, loss) - training_losses = [loss, ] - return PQState(training_losses, - true, - nothing, - nothing) -end - -update(::PQ, loss) = error("First loss reported to the PQ early stopping "* - "algorithm must be a training loss. ") +update_training(::PQ, loss, ::Nothing) = PQState([loss, ], true, nothing, nothing) +update(::PQ, loss::T, ::Nothing) where T = PQState(T[], false, loss, loss) function update_training(criterion::PQ, loss, state) - training_losses = if state.waiting_for_out_of_sample - prepend(state.training_losses, loss, criterion.k) - else - [loss, ] - end - return PQState(training_losses, - true, - state.loss, - state.min_loss) + training_losses = prepend(state.training_losses, loss, criterion.k) + return PQState(training_losses, true, state.loss, state.min_loss) end function update(::PQ, loss, state) - length(state.training_losses) > 1 || - error("The PQ stopping criterion requires at least two training "* - "losses between out-of-sample loss updates. 
") - return PQState(state.training_losses, - false, - loss, - _min(loss, state.min_loss)) + min_loss = _min(loss, state.min_loss) + return PQState(state.training_losses, false, loss, min_loss) end function done(criterion::PQ, state) + state === nothing && return false + state.loss === nothing && return false state.waiting_for_out_of_sample && return false + length(state.training_losses) < criterion.k && return false GL = generalization_loss(state.loss, state.min_loss) P = progress(state.training_losses) P > criterion.tol || return true @@ -301,8 +277,8 @@ Patience(; n=5) = Patience(n) # Prechelt alias: const UP = Patience -update(criterion::Patience, loss) = (loss=loss, n_increases=0) -@inline function update(criterion::Patience, loss, state) +update(::Patience, loss, ::Nothing) = (loss=loss, n_increases=0) +@inline function update(::Patience, loss, state) old_loss, n = state if loss > old_loss n += 1 @@ -311,10 +287,9 @@ update(criterion::Patience, loss) = (loss=loss, n_increases=0) end return (loss=loss, n_increases=n) end -# in case first loss consumed was a training loss: -update(criterion::Patience, loss, ::Nothing) = update(criterion, loss) -done(criterion::Patience, state) = state.n_increases == criterion.n +done(criterion::Patience, state) = + state === nothing ? 
false : state.n_increases == criterion.n needs_loss(::Type{<:Patience}) = true @@ -342,8 +317,8 @@ struct NumberSinceBest <: StoppingCriterion end NumberSinceBest(; n=6) = NumberSinceBest(n) -update(criterion::NumberSinceBest, loss) = (best=loss, number_since_best=0) -@inline function update(criterion::NumberSinceBest, loss, state) +update(::NumberSinceBest, loss, ::Nothing) = (best=loss, number_since_best=0) +@inline function update(::NumberSinceBest, loss, state) best, number_since_best = state if loss < best best = loss @@ -354,10 +329,8 @@ update(criterion::NumberSinceBest, loss) = (best=loss, number_since_best=0) return (best=best, number_since_best=number_since_best) end -# in case first loss consumed was a training loss: -update(criterion::NumberSinceBest, loss, ::Nothing) = update(criterion, loss) - -done(criterion::NumberSinceBest, state) = state.number_since_best == criterion.n +done(criterion::NumberSinceBest, state) = + state === nothing ? false : state.number_since_best == criterion.n needs_loss(::Type{<:NumberSinceBest}) = true @@ -386,14 +359,10 @@ struct NumberLimit <: StoppingCriterion end NumberLimit(; n=100) = NumberLimit(n) -update(criterion::NumberLimit, loss) = 1 -@inline function update(criterion::NumberLimit, loss, state) - return state+1 -end -# in case first loss consumed was a training loss: -update(criterion::NumberLimit, loss, ::Nothing) = update(criterion, loss) - -done(criterion::NumberLimit, state) = state == criterion.n +update(criterion::NumberLimit, loss, ::Nothing) = 1 +update(::NumberLimit, loss, state) = state+1 +done(criterion::NumberLimit, state) = + state === nothing ? 
false : state >= criterion.n # ## THRESHOLD @@ -413,12 +382,9 @@ struct Threshold <: StoppingCriterion end Threshold(; value=0.0) = Threshold(value) -update(criterion::Threshold, loss) = loss -update(criterion::Threshold, loss, state) = loss -# in case first loss consumed was a training loss: -update(criterion::Threshold, loss, ::Nothing) = loss - -done(criterion::Threshold, state) = state < criterion.value +update(::Threshold, loss, state) = loss +done(criterion::Threshold, state) = + state === nothing ? false : state < criterion.value needs_loss(::Type{<:Threshold}) = true @@ -473,6 +439,7 @@ end function done(criterion::Warmup, state) # Only check if inner criterion is done after n updates + state === nothing && return false return state[1] <= criterion.n ? false : done(criterion.criterion, state[2]) end @@ -501,14 +468,10 @@ end # state = `true` when NaN has been encountered +update(::NotANumber, loss, state) = state !== nothing && state || isnan(loss) +update_training(c::NotANumber, loss, state) = update(c, loss, state) -update(::NotANumber, loss) = isnan(loss) -update_training(::NotANumber, loss) = isnan(loss) - -update(::NotANumber, loss, state) = state || isnan(loss) -update_training(::NotANumber, loss, state) = state || isnan(loss) - -done(::NotANumber, state) = state +done(::NotANumber, state) = state !== nothing && state message(::NotANumber, state) = "Stopping early as NaN encountered. " diff --git a/src/disjunction.jl b/src/disjunction.jl index 43a5289..4e740bf 100644 --- a/src/disjunction.jl +++ b/src/disjunction.jl @@ -26,7 +26,7 @@ Disjunction(a, b, c...) = Disjunction(Disjunction(a,b), c...) 
for f in [:update, :update_training] eval(quote - $f(d::Disjunction, loss) = + $f(d::Disjunction, loss, ::Nothing) = (a = $f(d.a, loss), b = $f(d.b, loss)) $f(d::Disjunction, loss, state) = (a = $f(d.a, loss, state.a), @@ -71,7 +71,7 @@ _done(criterion, state, old_done) = old_done || done(criterion, state) _done(d::Disjunction, state, old_done) = _done(d.a, state.a, _done(d.b, state.b, old_done)) -done(d::Disjunction, state) = _done(d, state, false) +done(d::Disjunction, state) = state === nothing ? false : _done(d, state, false) ## RECURSION TO BUILD MESSAGE diff --git a/src/object_oriented_api.jl b/src/object_oriented_api.jl index 8c5743b..f0b4cf7 100644 --- a/src/object_oriented_api.jl +++ b/src/object_oriented_api.jl @@ -23,11 +23,11 @@ For criteria tracking both an "out-of-sample" loss and a "training" loss (eg, stopping criterion of type `PQ`), specify `training=true` if the update is for training, as in - done!(stopper, 0.123, training=true) + done!(stopper, 0.123; training=true) -In these cases, the out-of-sample update must always come after the -corresponding training update. Multiple training updates may preceed -the out-of-sample update. +Zero or more training updates may precede each out-of-sample update. + +The state of the stopper can be reset or restored to a prior state using `reset!` """ mutable struct EarlyStopper{S} @@ -35,45 +35,41 @@ mutable struct EarlyStopper{S} verbosity::Int state EarlyStopper(criterion::S; verbosity=0) where S = - new{S}(criterion, verbosity) + new{S}(criterion, verbosity, nothing) end EarlyStopper(criteria...; kwargs...) = EarlyStopper(sum(criteria); kwargs...) 
+# Dispatch message, done, update and update_training to wrapped criterion for f in [:message, :done] - eval(quote - $f(stopper::EarlyStopper) = $f(stopper.criterion, stopper.state) - end) + @eval $f(stopper::EarlyStopper) = $f(stopper.criterion, stopper.state) end - -# defines 2 private functons `done_after_update!(stopper, loss)` and -# `done_after_training_update!(stopper, loss)`: for f in [:update, :update_training] - newf = Symbol(string("done_after_", f, "!")) - eval(quote - function $newf(stopper::EarlyStopper, loss) - if isdefined(stopper, :state) - stopper.state = $f(stopper.criterion, loss, stopper.state) - else - stopper.state = $f(stopper.criterion, loss) - end - return done(stopper) - end - end) + @eval $f(stopper::EarlyStopper, loss) = + $f(stopper.criterion, loss, stopper.state) end """ - done!(stopper::EarlyStopper) + done!(stopper::EarlyStopper, loss; training = false) """ function done!(stopper::EarlyStopper, loss; training=false) - ret = if training - done_after_update_training!(stopper, loss) + if training + stopper.state = update_training(stopper, loss) else - done_after_update!(stopper, loss) + stopper.state = update(stopper, loss) end if stopper.verbosity > 0 suffix = training ? "training " : "" loss_str = suffix*"loss" @info "$loss_str: $loss\t state: $(stopper.state)" end - return ret + return done(stopper) end + +""" + reset!(stopper::EarlyStopper) + reset!(stopper::EarlyStopper, state) + +Reset a stopper to its uninitialized state or to a particular state +""" +reset!(stopper::EarlyStopper) = stopper.state = nothing +reset!(stopper::EarlyStopper, state) = stopper.state = state diff --git a/src/stopping_time.jl b/src/stopping_time.jl index 3b6de39..7160445 100644 --- a/src/stopping_time.jl +++ b/src/stopping_time.jl @@ -1,8 +1,3 @@ -_update(::Val{false}, args...) = update(args...) -_update(::Val{true}, args...) = update_training(args...) 
-_getindex(is_training::Nothing, s) = false -_getindex(is_training, s) = is_training[s] - """ stopping_time(criterion, losses; verbosity=0) stopping_time(criterion, losses, is_training; verbosity=0) @@ -22,38 +17,31 @@ julia> stopping_time(NotANumber(), [10.0, 3.0, 5.0, 4.0]) 0 ``` """ -function stopping_time(criterion, losses, training; verbosity=0) +function stopping_time(criterion::EarlyStopper, losses, training; verbosity=0) - t_stop = 0 # meaning no stop t = 0 # counts regular `update` calls but ignores `update_training` calls - s = 0 # counter for iteration over `losses` - - is_training = training === nothing ? - nothing : collect(training) - global state - - for loss in losses - s += 1 - _is_training = _getindex(is_training, s) - state = if s == 1 - _update(Val(_is_training), criterion, loss) - else - _update(Val(_is_training), criterion, loss, state) + for (loss, training) in zip(losses, training) + if !training + t += 1 # Increment count of out-of-sample updates end - _is_training || (t += 1) - verbosity < 1 || begin + + # Update criterion state + is_done = done!(criterion, loss; training = training) + + if verbosity > 0 @info "loss updates: $t" - @info "state: $state" - end - if !_is_training && done(criterion, state) - t_stop = t - break + @info "state: $(criterion.state)" end + is_done && return t end - return t_stop - + # No stopping + return 0 end +# If training is not provided -> Assume always out-of-sample stopping_time(criterion, losses; kwargs...) = - stopping_time(criterion, losses, nothing; kwargs...) + stopping_time(criterion, losses, Iterators.repeated(false); kwargs...) + +stopping_time(c::StoppingCriterion, args...; kwargs...) = + stopping_time(EarlyStopper(c), args...; kwargs...) 
diff --git a/test/api.jl b/test/api.jl index bf334a6..041432a 100644 --- a/test/api.jl +++ b/test/api.jl @@ -6,13 +6,15 @@ struct Dummy <: StoppingCriterion end @test update_training(Dummy(), 1.0) === nothing @test update_training(Dummy(), 1.0, 42.0) == 42.0 -update_training(c::Dummy, loss) = (training=loss, loss=nothing) -update_training(c::Dummy, loss, state) = (training=loss, loss=state.loss) -update(c::Dummy, loss, state) = (training=state.training, loss=loss) -done(c::Dummy, state) = state.training == state.loss +update_training(::Dummy, loss, ::Nothing) = (training=loss, loss=nothing) +update_training(::Dummy, loss, state) = (training=loss, loss=state.loss) +update(::Dummy, loss, state) = (training=state.training, loss=loss) +update(::Dummy, loss, ::Nothing) = (training=nothing, loss=loss) +done(::Dummy, state) = state === nothing ? false : state.training == state.loss stopper = EarlyStopper(Dummy()) @test !done!(stopper, 1.0, training=true) @test done!(stopper, 1.0) +@test_criteria Dummy() true diff --git a/test/criteria.jl b/test/criteria.jl index 466aebf..2e2fcee 100644 --- a/test/criteria.jl +++ b/test/criteria.jl @@ -17,23 +17,18 @@ end losses2[j] = NaN stopping_time(InvalidValue(), losses2) == j end - losses2 = Float64[1, 2, 3, 1, Inf, 3, 1, 2, 3] - is_training = Bool[1, 1, 0, 1, 1, 0, 1, 1, 0] - @test stopping_time(InvalidValue(), losses2, is_training) == 2 - losses2 = Float64[1, 2, 3, 1, 2, -Inf, 1, 2, 3] - @test stopping_time(InvalidValue(), losses2, is_training) == 2 - losses2 = Float64[1, 2, 3, 1, 2, 3, NaN, 2, 3] - @test stopping_time(InvalidValue(), losses2, is_training) == 3 - losses2 = Float64[1, 2, 3, 1, 2, 3, 1, 2, 3] - @test stopping_time(InvalidValue(), losses2, is_training) == 0 - @test_logs((:info, r"loss updates: 0"), - (:info, r"state: true"), - (:info, r"loss updates: 1"), - (:info, r"state: true"), - stopping_time(InvalidValue(), - [-Inf, 1], - [true, false], - verbosity=1)) + + is_training = map(x -> x%3 > 0, 1:length(losses)) + 
@test stopping_time(InvalidValue(), losses, is_training) == 0 + for n = 1:2:length(losses) + n_stop = sum(!, is_training[1:n]) + losses2 = copy(losses) + losses2[n] = Inf + @test stopping_time(InvalidValue(), losses2, is_training) == n_stop + losses2[n] = NaN + @test stopping_time(InvalidValue(), losses2, is_training) == n_stop + end + @test EarlyStopping.needs_loss(InvalidValue()) @test !EarlyStopping.needs_training_losses(InvalidValue()) end @@ -96,13 +91,7 @@ end c = PQ(alpha=10, k=2) - # first update must be training: - @test_throws Exception EarlyStopping.update(c, 1.0) - state = EarlyStopping.update_training(c, 10.0) - # at least two training updates before out-of-sample update: - @test_throws Exception EarlyStopping.update(c, state, 10.0) - state = EarlyStopping.update_training(c, 10.0, state) state = EarlyStopping.update(c, 10.0, state) @test EarlyStopping.done(c, state) # progress = 0 @@ -229,7 +218,7 @@ end @testset "training" begin stopper = Warmup(PQ(), 3) - is_training = @show map(x -> x%3 > 0, 1:length(losses)) + is_training = map(x -> x%3 > 0, 1:length(losses)) # Feed 2 training losses + 1 non-training to criteria with/without stop_time = stopping_time(stopper, losses, is_training) @@ -238,6 +227,12 @@ end # PQ only counts training loss updates @test round(stop_time/3, RoundUp) == ref_stop_time end + + @testset "integration" begin + @test_criteria Warmup(Patience()) + @test_criteria Warmup(NumberSinceBest()) + @test_criteria Warmup(Patience(3) + InvalidValue()) + end end @@ -252,23 +247,13 @@ end losses2[j] = NaN stopping_time(criterion, losses2) == j end - losses2 = Float64[1, 2, 3, 1, NaN, 3, 1, 2, 3] - is_training = Bool[1, 1, 0, 1, 1, 0, 1, 1, 0] - @test stopping_time(criterion, losses2, is_training) == 2 - losses2 = Float64[1, 2, 3, 1, 2, NaN, 1, 2, 3] - @test stopping_time(criterion, losses2, is_training) == 2 - losses2 = Float64[1, 2, 3, 1, 2, 3, NaN, 2, 3] - @test stopping_time(criterion, losses2, is_training) == 3 - losses2 = Float64[1, 2, 
3, 1, 2, 3, 1, 2, 3] - @test stopping_time(criterion, losses2, is_training) == 0 - @test_logs((:info, r"loss updates: 0"), - (:info, r"state: true"), - (:info, r"loss updates: 1"), - (:info, r"state: true"), - stopping_time(criterion, - [NaN, 1], - [true, false], - verbosity=1)) + is_training = map(x -> x%3 > 0, 1:length(losses)) + @test stopping_time(criterion, losses, is_training) == 0 + for n = 1:2:length(losses) + losses2 = copy(losses); losses2[n] = NaN + @test stopping_time(criterion, losses2, is_training) == sum(!, is_training[1:n]) + end + @test EarlyStopping.needs_loss(criterion) @test !EarlyStopping.needs_training_losses(criterion) end diff --git a/test/disjunction.jl b/test/disjunction.jl index b99e25b..7b6c11b 100644 --- a/test/disjunction.jl +++ b/test/disjunction.jl @@ -23,6 +23,8 @@ end d2 = Patience(3) + InvalidValue() @test stopping_time(d2, [12.0, 10.0, 11.0, 12.0, 13.0, NaN]) == 5 @test stopping_time(d2, [NaN, 12.0, 10.0, 11.0, 12.0, 13.0]) == 1 + @test_criteria NumberSinceBest() + PQ() + @test_criteria Patience(3) + InvalidValue() end @testset "message" begin diff --git a/test/object_oriented_api.jl b/test/object_oriented_api.jl index a08d65c..8e116db 100644 --- a/test/object_oriented_api.jl +++ b/test/object_oriented_api.jl @@ -28,6 +28,13 @@ stopper = EarlyStopper(PQ(alpha=3.8, k=2), InvalidValue()) @test !done!(stopper, losses2[15], training=true) @test done!(stopper, losses2[16]) +# Test reset +state = stopper.state +reset!(stopper) +@test !EarlyStopping.done(stopper) +reset!(stopper, state) +@test EarlyStopping.done(stopper) + message(stopper) == "Early stop triggered by "* "PQ(3.8, 2, 2.220446049250313e-16) stopping criterio n. 
" diff --git a/test/runtests.jl b/test/runtests.jl index 99416fe..2ca8936 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,18 +1,28 @@ using EarlyStopping, Dates, Test, InteractiveUtils -@testset "criteria.jl" begin - include("criteria.jl") -end +include("smoke.jl") -@testset "disjunction.jl" begin - include("disjunction.jl") -end +@testset "EarlyStopping" begin + @testset "smoke" begin + @testset "$C" for C in subtypes(StoppingCriterion) + @eval @test_criteria $(C()) + end + end -@testset "object_oriented_api.jl" begin - include("object_oriented_api.jl") -end + @testset "criteria.jl" begin + include("criteria.jl") + end + + @testset "disjunction.jl" begin + include("disjunction.jl") + end + + @testset "object_oriented_api.jl" begin + include("object_oriented_api.jl") + end -# to complete code coverage: -@testset "api.jl" begin - include("api.jl") + # to complete code coverage: + @testset "api.jl" begin + include("api.jl") + end end diff --git a/test/smoke.jl b/test/smoke.jl new file mode 100644 index 0000000..ce773d5 --- /dev/null +++ b/test/smoke.jl @@ -0,0 +1,74 @@ +""" + test_criteria StoppingCriterion() + +Runs a series of tests to check functionality of a StoppingCriterion + +- `update(::StoppingCriterion, loss, state=nothing)` is defined +- Checks that `done(::StoppingCriterion, nothing)` is `false` +- `done(::StoppingCriterion, state)` can handle state after `update` or `update_training` +- `message(::StoppingCriterion, state)` can handle state after `update` or `update_training` + +""" +macro test_criteria(criteria) + quote + @test $criteria isa StoppingCriterion + @testset "state is Nothing" begin + loss = rand() + using EarlyStopping: update, update_training + + # Check update for nothing state + c1 = update($criteria, loss) + c2 = update($criteria, loss, nothing) + @test compare_state(c1, c2) + + # Check update_training for nothing state + c1 = update_training($criteria, loss) + c2 = update_training($criteria, loss, nothing) + @test 
compare_state(c1, c2) + end + + # Check that done can be called after `update` or `update_training` + @testset "done" begin + loss = rand() + using EarlyStopping: done, update, update_training + @test done($criteria, nothing) == false + @test done($criteria, update($criteria, loss)) isa Bool + @test done($criteria, update_training($criteria, loss)) isa Bool + + # Training then out-of-sample + loss2 = rand() + s = update_training($criteria, loss2, update($criteria, loss)) + @test done($criteria, s) isa Bool + + # Out of sample then + @test done($criteria, update($criteria, loss2, update_training($criteria, loss))) isa Bool + end + + # Check that `message` can handle state after `update` or `update_training` + @testset "message" begin + loss = rand() + using EarlyStopping: update, update_training + @test message($criteria, update($criteria, loss)) isa String + @test message($criteria, update_training($criteria, loss)) isa String + end + end +end + +""" + compare_state(s1, s2) + +Helper method to check if StoppingCriterion states `s1` and `s2` are effectively +equivalent. +""" +compare_state(s1, s2) = s1 == s2 +compare_state(s1::DateTime, s2::DateTime) = abs(s1 - s2) < Millisecond(10) + +# Extend == for PQState +import Base: == +function ==(s1::EarlyStopping.PQState, s2::EarlyStopping.PQState) + length(s1.training_losses) == length(s2.training_losses) && + all(s1.training_losses .== s2.training_losses) && + s1.waiting_for_out_of_sample == s2.waiting_for_out_of_sample && + s1.loss == s2.loss && + s1.min_loss == s2.min_loss +end