Merge pull request #25 from ablaom/dev
For a 0.1.8 release
ablaom authored Apr 13, 2021
2 parents a50a729 + 4122113 commit fe258fe
Showing 7 changed files with 154 additions and 47 deletions.
2 changes: 1 addition & 1 deletion Project.toml
@@ -1,7 +1,7 @@
name = "EarlyStopping"
uuid = "792122b4-ca99-40de-a6bc-6742525f08b6"
authors = ["Anthony D. Blaom <[email protected]>"]
-version = "0.1.7"
+version = "0.1.8"

[deps]
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
14 changes: 7 additions & 7 deletions README.md
@@ -39,7 +39,7 @@ two of these criteria (either triggering the stop):
```julia
using EarlyStopping

-stopper = EarlyStopper(Patience(2), NotANumber()) # multiple criteria
+stopper = EarlyStopper(Patience(2), InvalidValue()) # multiple criteria
done!(stopper, 0.123) # false
done!(stopper, 0.234) # false
done!(stopper, 0.345) # true
@@ -83,7 +83,7 @@ REPL. Here is a short summary:
criterion | description | notation in Prechelt
----------------------|-------------------------------------------------------|---------------------
`Never()` | Never stop |
-`NotANumber()` | Stop when `NaN` encountered |
+`InvalidValue()` | Stop when `NaN`, `Inf` or `-Inf` encountered |
`TimeLimit(t=0.5)` | Stop after `t` hours |
`NumberLimit(n=100)` | Stop after `n` loss updates (excl. "training losses") |
`NumberSinceBest(n=6)`| Stop after `n` loss updates (excl. "training losses") |
@@ -108,7 +108,7 @@ the out-of-sample update, as in the following example:

```julia
criterion = PQ(alpha=2.0, k=2)
-needs_in_and_out_of_sample(criterion) # true
+needs_training_losses(criterion) # true

stopper = EarlyStopper(criterion)

@@ -140,7 +140,7 @@ criteria (see below). If the iterator terminates without a stop, `0`
is returned.

```julia
-julia> stopping_time(NotANumber(), [10.0, 3.0, NaN, 4.0])
+julia> stopping_time(InvalidValue(), [10.0, 3.0, Inf, 4.0])
3

julia> stopping_time(Patience(3), [10.0, 3.0, 4.0, 5.0], verbosity=1)
@@ -177,7 +177,7 @@ To implement a new stopping criterion, one must:
- Optionally overload methods `message`.
- Optionally overload `update_training` and the trait
-`needs_in_and_out_of_sample`.
+`needs_training_losses`.
We demonstrate this with a simplified version of the
[code](/src/criteria.jl) for `Patience`:
@@ -242,9 +242,9 @@ the same signature as the `update` methods above. Refer to the `PQ`
If a stopping criterion requires one or more `update_training` calls
per `update` call to work, you should overload the trait
-`needs_in_and_out_of_sample` for that type, as in this example from
+`needs_training_losses` for that type, as in this example from
the source code:
```julia
-EarlyStopping.needs_in_and_out_of_sample(::Type{<:PQ}) = true
+EarlyStopping.needs_training_losses(::Type{<:PQ}) = true
```
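
Review note: the renamed trait follows the common Julia pattern of a type-level trait with a `false` fallback that individual criteria opt into. The block below is a standalone sketch of that pattern, not the package's code. `Criterion`, `PlainCriterion`, `TrainingAwareCriterion` and `needs_training` are hypothetical stand-ins for `StoppingCriterion`, an ordinary criterion, `PQ` and `needs_training_losses`.

```julia
# Standalone sketch of a type-level trait; all names are hypothetical.
abstract type Criterion end

struct PlainCriterion <: Criterion end          # uses out-of-sample losses only
struct TrainingAwareCriterion <: Criterion end  # also consumes training losses

# Fallback: by default a criterion does not need training losses.
needs_training(::Type) = false

# Opt in for one family of types, mirroring the `::Type{<:PQ}` method above.
needs_training(::Type{<:TrainingAwareCriterion}) = true

# Convenience method so the trait can also be queried on instances.
needs_training(c::Criterion) = needs_training(typeof(c))

@assert !needs_training(PlainCriterion())
@assert needs_training(TrainingAwareCriterion())
```

Defining the trait on types keeps the query independent of any particular instance's field values.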
23 changes: 18 additions & 5 deletions src/EarlyStopping.jl
@@ -5,12 +5,25 @@ using Statistics
import Base.+

export StoppingCriterion,
-Never, NotANumber, TimeLimit, GL, NumberSinceBest,
+Never,
+InvalidValue,
+NotANumber, # deprecated
+TimeLimit,
+GL,
+NumberSinceBest,
Patience,
-UP, PQ, NumberLimit,
-Threshold, Disjunction,
-criteria, stopping_time, EarlyStopper,
-done!, message, needs_training_losses
+UP,
+PQ,
+NumberLimit,
+Threshold,
+Disjunction,
+criteria,
+stopping_time,
+EarlyStopper,
+done!,
+message,
+needs_training_losses,
+needs_loss

include("api.jl")
include("criteria.jl")
85 changes: 71 additions & 14 deletions src/criteria.jl
@@ -6,6 +6,10 @@ const PRECHELT_REF = "[Prechelt, Lutz (1998): \"Early Stopping"*
const STOPPING_DOC = "An early stopping criterion for loss-reporting "*
"iterative algorithms. "

+const CUSTOM_ALTERNATIVE_DOC = "For a customizable loss-based stopping "*
+"criterion, use [`WithLossDo`](@ref) or [`WithTrainingLossesDo`](@ref) "*
+"with the `stop_if_true=true` option. "


## NEVER

@@ -22,31 +26,34 @@ See also [`NotANumber`](@ref), for stopping on encountering `NaN`.
struct Never <: StoppingCriterion end


-## NOT A NUMBER
+## OUT OF BOUNDS

"""
-NotANumber()
+InvalidValue()
$STOPPING_DOC
-Stop if a loss of `NaN` is encountered.
+Stop if a loss (or training loss) is `NaN`, `Inf` or `-Inf` (or, more
+precisely, if `isnan(loss)` or `isinf(loss)` is `true`).
-"""
-struct NotANumber <: StoppingCriterion end
+$CUSTOM_ALTERNATIVE_DOC
-# state = `true` when NaN has been encountered
+"""
+struct InvalidValue <: StoppingCriterion end

-update(::NotANumber, loss) = isnan(loss)
-update_training(::NotANumber, loss) = isnan(loss)
+# state = `true` when `NaN`, `Inf` or `-Inf` has been encountered

-update(::NotANumber, loss, state) = state || isnan(loss)
-update_training(::NotANumber, loss, state) = state || isnan(loss)
+update(::InvalidValue, loss, state=false) =
+state || isinf(loss) || isnan(loss)
+update_training(::InvalidValue, loss, state=false) =
+state || isinf(loss) || isnan(loss)

-done(::NotANumber, state) = state
+done(::InvalidValue, state) = state

-message(::NotANumber, state) = "Stopping early as NaN encountered. "
+message(::InvalidValue, state) = "Stopping early as `NaN`, "*
+"`Inf` or `-Inf` encountered. "

-needs_loss(::Type{<:NotANumber}) = true
+needs_loss(::Type{<:InvalidValue}) = true


## TIME LIMIT
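
Review note: the new `update` methods latch, so the state stays `true` once an invalid loss has been seen. The block below is a standalone sketch of that logic and does not load the package. `update_state` and `first_invalid` are hypothetical helper names, and the first assertion mirrors the README example `stopping_time(InvalidValue(), [10.0, 3.0, Inf, 4.0])`.

```julia
# Standalone sketch; does not use EarlyStopping.jl. Helper names are hypothetical.

# Latching state update: `true` once any loss is NaN, Inf or -Inf.
update_state(loss, state=false) = state || isinf(loss) || isnan(loss)

# Index of the first loss that triggers a stop, or 0 if none does.
function first_invalid(losses)
    state = false
    for (i, loss) in enumerate(losses)
        state = update_state(loss, state)
        state && return i
    end
    return 0
end

@assert first_invalid([10.0, 3.0, Inf, 4.0]) == 3
@assert first_invalid([10.0, 3.0, NaN, 4.0]) == 3
@assert first_invalid([10.0, 3.0, 4.0]) == 0
```
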
@@ -63,7 +70,7 @@ Any Julia built-in `Real` type can be used for `t`. Subtypes of
`Period` may also be used, as in `TimeLimit(t=Minute(30))`.
Internally, `t` is rounded to nearest millisecond.
``
``
"""
struct TimeLimit <: StoppingCriterion
t::Millisecond
@@ -278,6 +285,8 @@ A stop is triggered by `n` consecutive increases in the loss.
Denoted "_UP_s" in $PRECHELT_REF.
+$CUSTOM_ALTERNATIVE_DOC
"""
struct Patience <: StoppingCriterion
n::Int
@@ -320,6 +329,8 @@ $STOPPING_DOC
A stop is triggered when the number of calls to the control, since the
lowest value of the loss so far, is `n`.
+$CUSTOM_ALTERNATIVE_DOC
"""
struct NumberSinceBest <: StoppingCriterion
n::Int
@@ -394,6 +405,8 @@ $STOPPING_DOC
A stop is triggered as soon as the loss drops below `value`.
+$CUSTOM_ALTERNATIVE_DOC
"""
struct Threshold <: StoppingCriterion
value::Float64
@@ -408,3 +421,47 @@ update(criterion::Threshold, loss, ::Nothing) = loss
done(criterion::Threshold, state) = state < criterion.value

needs_loss(::Type{<:Threshold}) = true








+## NOT A NUMBER (deprecated)

+"""
+NotANumber()
+$STOPPING_DOC
+Stop if a loss of `NaN` is encountered.
+**Now deprecated** in favour of [`InvalidValue`](@ref).
+"""
+struct NotANumber <: StoppingCriterion
+function NotANumber()
+Base.depwarn("`NotANumber()` is deprecated. Use `InvalidValue()` "*
+"to trap `NaN`, `Inf` or `-Inf`. ", :NotANumber)
+return new()
+end
+end


+# state = `true` when NaN has been encountered

+update(::NotANumber, loss) = isnan(loss)
+update_training(::NotANumber, loss) = isnan(loss)

+update(::NotANumber, loss, state) = state || isnan(loss)
+update_training(::NotANumber, loss, state) = state || isnan(loss)

+done(::NotANumber, state) = state

+message(::NotANumber, state) = "Stopping early as NaN encountered. "

+needs_loss(::Type{<:NotANumber}) = true
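
Review note: the deprecation is issued from an inner constructor, so the old type still constructs and dispatches as before. A standalone sketch of the same pattern follows. `OldCriterion` and `NewCriterion` are hypothetical names, and the comment about warning visibility assumes a recent Julia version where deprecation warnings are hidden by default.

```julia
# Standalone sketch of constructor-level deprecation; type names are hypothetical.
abstract type Criterion end

struct NewCriterion <: Criterion end

struct OldCriterion <: Criterion
    function OldCriterion()
        # The warning is only printed when Julia runs with `--depwarn=yes`.
        Base.depwarn("`OldCriterion()` is deprecated. Use `NewCriterion()`. ",
                     :OldCriterion)
        return new()
    end
end

old = OldCriterion()   # still constructs, so existing user code keeps working
@assert old isa Criterion
```
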


66 changes: 51 additions & 15 deletions test/criteria.jl
@@ -9,36 +9,35 @@ losses = Float64[10, 8, 9, 10, 11, 12, 12, 13, 14, 15, 16, 17, 16]
@test !EarlyStopping.needs_training_losses(Never())
end

-@testset "NotANumber" begin
-@test stopping_time(NotANumber(), losses) == 0
+@testset "InvalidValue" begin
+@test stopping_time(InvalidValue(), losses) == 0
N = 5
losses2 = fill(123.4, N)
@test all(reverse(eachindex(losses2))) do j
losses2[j] = NaN
-stopping_time(NotANumber(), losses2) == j
+stopping_time(InvalidValue(), losses2) == j
end
-losses2 = Float64[1, 2, 3, 1, NaN, 3, 1, 2, 3]
+losses2 = Float64[1, 2, 3, 1, Inf, 3, 1, 2, 3]
is_training = Bool[1, 1, 0, 1, 1, 0, 1, 1, 0]
-@test stopping_time(NotANumber(), losses2, is_training) == 2
-losses2 = Float64[1, 2, 3, 1, 2, NaN, 1, 2, 3]
-@test stopping_time(NotANumber(), losses2, is_training) == 2
+@test stopping_time(InvalidValue(), losses2, is_training) == 2
+losses2 = Float64[1, 2, 3, 1, 2, -Inf, 1, 2, 3]
+@test stopping_time(InvalidValue(), losses2, is_training) == 2
losses2 = Float64[1, 2, 3, 1, 2, 3, NaN, 2, 3]
-@test stopping_time(NotANumber(), losses2, is_training) == 3
+@test stopping_time(InvalidValue(), losses2, is_training) == 3
losses2 = Float64[1, 2, 3, 1, 2, 3, 1, 2, 3]
-@test stopping_time(NotANumber(), losses2, is_training) == 0
+@test stopping_time(InvalidValue(), losses2, is_training) == 0
@test_logs((:info, r"loss updates: 0"),
(:info, r"state: true"),
(:info, r"loss updates: 1"),
(:info, r"state: true"),
-stopping_time(NotANumber(),
-[NaN, 1],
+stopping_time(InvalidValue(),
+[-Inf, 1],
[true, false],
verbosity=1))
-@test EarlyStopping.needs_loss(NotANumber())
-@test !EarlyStopping.needs_training_losses(NotANumber())
+@test EarlyStopping.needs_loss(InvalidValue())
+@test !EarlyStopping.needs_training_losses(InvalidValue())
end


struct SleepyIterator{T}
iter::T
t::Float64
@@ -200,11 +199,48 @@ end
end

@testset "robustness to first loss being a training loss" begin
-for C in subtypes(StoppingCriterion)
+criteria = filter(subtypes(StoppingCriterion)) do C
+C != NotANumber # as deprecated
+end
+for C in criteria
losses = float.(4:-1:1)
is_training = [true, true, false, false]
stopping_time(C(), losses, is_training)
end
end


+# # DEPRECATED

+@testset "NotANumber" begin
+@test_deprecated criterion = NotANumber()
+@test stopping_time(criterion, losses) == 0
+N = 5
+losses2 = fill(123.4, N)
+@test all(reverse(eachindex(losses2))) do j
+losses2[j] = NaN
+stopping_time(criterion, losses2) == j
+end
+losses2 = Float64[1, 2, 3, 1, NaN, 3, 1, 2, 3]
+is_training = Bool[1, 1, 0, 1, 1, 0, 1, 1, 0]
+@test stopping_time(criterion, losses2, is_training) == 2
+losses2 = Float64[1, 2, 3, 1, 2, NaN, 1, 2, 3]
+@test stopping_time(criterion, losses2, is_training) == 2
+losses2 = Float64[1, 2, 3, 1, 2, 3, NaN, 2, 3]
+@test stopping_time(criterion, losses2, is_training) == 3
+losses2 = Float64[1, 2, 3, 1, 2, 3, 1, 2, 3]
+@test stopping_time(criterion, losses2, is_training) == 0
+@test_logs((:info, r"loss updates: 0"),
+(:info, r"state: true"),
+(:info, r"loss updates: 1"),
+(:info, r"state: true"),
+stopping_time(criterion,
+[NaN, 1],
+[true, false],
+verbosity=1))
+@test EarlyStopping.needs_loss(criterion)
+@test !EarlyStopping.needs_training_losses(criterion)
+end


true
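
Review note: the robustness test now filters the deprecated type out of `subtypes(StoppingCriterion)` before looping, since constructing it would emit a deprecation warning. A standalone sketch of that filtering step is shown here. `Criterion`, `Alpha`, `Beta` and `Legacy` are hypothetical types, with `Legacy` standing in for `NotANumber`.

```julia
# Standalone sketch; the types below are hypothetical.
using InteractiveUtils: subtypes

abstract type Criterion end
struct Alpha <: Criterion end
struct Beta <: Criterion end
struct Legacy <: Criterion end   # stands in for a deprecated criterion

# Keep every direct subtype except the deprecated one.
active = filter(subtypes(Criterion)) do C
    C != Legacy
end

@assert !(Legacy in active)
@assert Set(active) == Set([Alpha, Beta])

# Each remaining type can then be instantiated with its zero-argument constructor.
instances = [C() for C in active]
@assert length(instances) == 2
```
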
7 changes: 4 additions & 3 deletions test/disjunction.jl
@@ -1,5 +1,5 @@
c1 = Patience(1)
-c2 = NotANumber()
+c2 = InvalidValue()
c3 = TimeLimit(t=100)

@test Disjunction(c1) == c1
@@ -20,15 +20,16 @@ show(d)
end

@testset "stoppping times" begin
-d2 = Patience(3) + NotANumber()
+d2 = Patience(3) + InvalidValue()
@test stopping_time(d2, [12.0, 10.0, 11.0, 12.0, 13.0, NaN]) == 5
@test stopping_time(d2, [NaN, 12.0, 10.0, 11.0, 12.0, 13.0]) == 1
end

@testset "message" begin
state = EarlyStopping.update(d, NaN)
@test EarlyStopping.message(d, state) ==
-"Stopping early as NaN encountered. "
+"Stopping early as `NaN`, "*
+"`Inf` or `-Inf` encountered. "
end

state = EarlyStopping.update(d, 1.0)
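
Review note: the "stopping times" tests above combine a patience rule with the invalid-value check, stopping on whichever fires first. The block below is a simplified standalone sketch of that disjunction, not the package's `Disjunction` implementation. `stop_index` and its `patience` keyword are hypothetical, and the two assertions reuse the loss sequences from the tests.

```julia
# Standalone sketch; `stop_index` is a hypothetical helper, not package API.
# Stops on the first invalid loss, or after `patience` consecutive increases.
function stop_index(losses; patience=3)
    n_increases = 0
    for i in eachindex(losses)
        loss = losses[i]
        # Invalid-value rule: NaN, Inf or -Inf stops immediately.
        (isnan(loss) || isinf(loss)) && return i
        # Patience rule: count consecutive increases over the previous loss.
        if i > firstindex(losses) && loss > losses[i - 1]
            n_increases += 1
        else
            n_increases = 0
        end
        n_increases >= patience && return i
    end
    return 0   # no stop triggered
end

@assert stop_index([12.0, 10.0, 11.0, 12.0, 13.0, NaN]) == 5
@assert stop_index([NaN, 12.0, 10.0, 11.0, 12.0, 13.0]) == 1
```
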
4 changes: 2 additions & 2 deletions test/object_oriented_api.jl
@@ -5,7 +5,7 @@ losses2 = [9.5, 9.3, 10, # 10.8 0 0 1
9.9, 9.5, 10, # 21.2 25.0 1.18 4
10.6, 10.4, 11] # 9.61 37.5 3.90 5

-stopper = EarlyStopper(PQ(alpha=3.8, k=2), NotANumber())
+stopper = EarlyStopper(PQ(alpha=3.8, k=2), InvalidValue())

@test !done!(stopper, losses2[1], training=true)
@test !done!(stopper, losses2[2], training=true)
@@ -33,7 +33,7 @@ message(stopper) == "Early stop triggered by "*
n. "

# verbose case:
-stopper = EarlyStopper(NotANumber(), verbosity=1)
+stopper = EarlyStopper(InvalidValue(), verbosity=1)

@test_logs (:info, r"training loss: 1.0") done!(stopper, 1.0, training=true)
@test_logs (:info, r"loss: 2.0") done!(stopper, 2.0)