Merge pull request #33 from JuliaAI/dev
For a 0.3.0 release
ablaom authored Nov 11, 2021
2 parents b8e981b + c17346c commit 4f3da13
Showing 6 changed files with 14 additions and 67 deletions.
2 changes: 1 addition & 1 deletion Project.toml
@@ -1,7 +1,7 @@
name = "EarlyStopping"
uuid = "792122b4-ca99-40de-a6bc-6742525f08b6"
authors = ["Anthony D. Blaom <[email protected]>"]
version = "0.2.0"
version = "0.3.0"

[deps]
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
12 changes: 4 additions & 8 deletions README.md
@@ -177,7 +177,6 @@ To implement a new stopping criterion, one must:
- Overload methods `update` and `done` for the new type.
```julia

struct NewCriteria <: StoppingCriterion
# Put relevant fields here
end
@@ -196,7 +195,6 @@ update(c::NewCriteria, loss, state) = ...
# Return true if NewCriteria should stop given `state`.
# Always return false if `state === nothing`
done(c::NewCriteria, state) = state === nothing ? false : ....

```
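
As a rough sketch only — `LossCap` is a hypothetical criterion invented here for illustration and is not part of the package — the template above might be instantiated like this, assuming `update`, `done` and `message` are extended by importing them from `EarlyStopping`:

```julia
using EarlyStopping
import EarlyStopping: update, done, message

# Hypothetical criterion (illustration only): stop as soon as an
# out-of-sample loss exceeds `cap`.
struct LossCap <: StoppingCriterion
    cap::Float64
end

# The state is just the most recent out-of-sample loss. The default
# `state=nothing` mirrors the `InvalidValue` pattern in src/criteria.jl.
update(c::LossCap, loss, state=nothing) = loss

# Never stop while the state is `nothing`; otherwise stop once the
# latest loss exceeds the cap.
done(c::LossCap, state) = state isa Number && state > c.cap

message(c::LossCap, state) = "Stopping early: loss exceeded $(c.cap). "
```

In the spirit of the tests further down this diff, one would then expect `stopping_time(LossCap(10.0), [3.0, 5.0, 12.0]) == 3`.
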
Optionally, one may define the following:
@@ -206,20 +204,18 @@ Optionally, one may define the following:
`needs_training_losses`.
```julia

# Final message when NewCriteria triggers a stop
message(c::NewCriteria, state) = ...

# Methods for initializing/updating the state given a training loss
update_training(c::NewCriteria, loss, ::Nothing) = ...
update_training(c::NewCriteria, loss, state) = ...
```
# Define `needs_training_losses` iff NewCriteria needs one or more
# training losses per out-of-sample loss to work. Otherwise, this may be
# omitted
needs_training_losses(::Type{<:NewCriteria}) = true
**Wrappers.** If your criterion wraps another criterion (as `Warmup`
does) then the `criterion` must be a field and must store the
criterion being wrapped.
```
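
A rough sketch of the wrapper convention just described — `Counted` is invented for illustration and is not part of the package; it assumes the wrapped criterion can be updated with the same two- and three-argument forms that `Warmup` relies on in `src/criteria.jl` further down this diff:

```julia
using EarlyStopping
import EarlyStopping: update, done, message

# Hypothetical wrapper (illustration only): behaves like the wrapped
# criterion but also counts the out-of-sample losses seen. As required
# above, the wrapped criterion is stored in the `criterion` field.
struct Counted{C<:StoppingCriterion} <: StoppingCriterion
    criterion::C
end

# State: (number of losses seen so far, state of the wrapped criterion).
# Initialization mirrors the two-argument pattern used by `Warmup`.
update(c::Counted, loss) = (1, update(c.criterion, loss))
update(c::Counted, loss, state) =
    (state[1] + 1, update(c.criterion, loss, state[2]))

done(c::Counted, state) =
    state === nothing ? false : done(c.criterion, state[2])

message(c::Counted, state) =
    message(c.criterion, state[2])*"($(state[1]) losses inspected.) "
```
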
## New Criteria Example
4 changes: 1 addition & 3 deletions src/EarlyStopping.jl
@@ -23,9 +23,7 @@ export StoppingCriterion,
EarlyStopper,
done!,
reset!,
message,
needs_training_losses,
needs_loss
message

include("api.jl")
include("criteria.jl")
7 changes: 0 additions & 7 deletions src/api.jl
@@ -12,10 +12,3 @@ done(::StoppingCriterion, state) = false

message(criterion::StoppingCriterion, state) = "Stop triggered by "*
"$criterion stopping criterion. "

needs_loss(::Type) = false
needs_training_losses(::Type) = false

for trait in [:needs_loss, :needs_training_losses]
eval(:($trait(c) = $trait(typeof(c))))
end
28 changes: 8 additions & 20 deletions src/criteria.jl
@@ -41,17 +41,20 @@ $CUSTOM_ALTERNATIVE_DOC
"""
struct InvalidValue <: StoppingCriterion end

_isinf(x) = isinf(x)
_isinf(::Nothing) = false
_isnan(x) = isnan(x)
_isnan(::Nothing) = false

# state = `true` when `NaN`, `Inf` or `-Inf` has been encountered
update(::InvalidValue, loss, state=false) =
state !== nothing && state || isinf(loss) || isnan(loss)
state !== nothing && state || _isinf(loss) || _isnan(loss)
update_training(c::InvalidValue, loss, state) = update(c, loss, state)
done(::InvalidValue, state) = state !== nothing && state

message(::InvalidValue, state) = "Stopping early as `NaN`, "*
"`Inf` or `-Inf` encountered. "

needs_loss(::Type{<:InvalidValue}) = true


## TIME LIMIT

@@ -134,7 +137,6 @@ function done(criterion::GL, state)
return gl > criterion.alpha
end
end
needs_loss(::Type{<:GL}) = true


## PQ
@@ -244,9 +246,6 @@ function done(criterion::PQ, state)
return PQ > criterion.alpha
end

needs_loss(::Type{<:PQ}) = true
needs_training_losses(::Type{<:PQ}) = true


## PATIENCE

@@ -291,8 +290,6 @@ end
done(criterion::Patience, state) =
state === nothing ? false : state.n_increases == criterion.n

needs_loss(::Type{<:Patience}) = true


## NUMBER SINCE BEST

@@ -332,8 +329,6 @@ end
done(criterion::NumberSinceBest, state) =
state === nothing ? false : state.number_since_best == criterion.n

needs_loss(::Type{<:NumberSinceBest}) = true


# # NUMBER LIMIT

@@ -386,8 +381,6 @@ update(::Threshold, loss, state) = loss
done(criterion::Threshold, state) =
state === nothing ? false : state < criterion.value

needs_loss(::Type{<:Threshold}) = true


"""
Warmup(c::StoppingCriterion, n)
@@ -417,9 +410,8 @@ update_training(c::Warmup, loss) = (1, update_training(c.criterion, loss))

# Handle update vs update_training
update(c::Warmup, loss, state) = _update(update, c, loss, state)
update_training(c::Warmup, loss, state) = _update(update_training, c, loss, state)
needs_loss(::Type{<:Warmup{C}}) where C = needs_loss(C)
needs_training_losses(::Type{<:Warmup{C}}) where C = needs_training_losses(C)
update_training(c::Warmup, loss, state) =
_update(update_training, c, loss, state)

# Dispatch update and update_training here
function _update(f::Function, criterion::Warmup, loss, state)
@@ -474,7 +466,3 @@ update_training(c::NotANumber, loss, state) = update(c, loss, state)
done(::NotANumber, state) = state !== nothing && state

message(::NotANumber, state) = "Stopping early as NaN encountered. "

needs_loss(::Type{<:NotANumber}) = true


28 changes: 0 additions & 28 deletions test/criteria.jl
@@ -5,8 +5,6 @@ losses = Float64[10, 8, 9, 10, 11, 12, 12, 13, 14, 15, 16, 17, 16]

@testset "Never" begin
@test stopping_time(Never(), losses) == 0
@test !EarlyStopping.needs_loss(Never())
@test !EarlyStopping.needs_training_losses(Never())
end

@testset "InvalidValue" begin
@@ -28,9 +26,6 @@ end
losses2[n] = NaN
@test stopping_time(InvalidValue(), losses2, is_training) == n_stop
end

@test EarlyStopping.needs_loss(InvalidValue())
@test !EarlyStopping.needs_training_losses(InvalidValue())
end

struct SleepyIterator{T}
@@ -51,8 +46,6 @@ Base.iterate(iter::SleepyIterator, state) =
@test stopping_time(TimeLimit(t=Millisecond(600)), sleepy_losses) == 7
# codecov:
@test EarlyStopping.update_training(TimeLimit(), 42.0) <= now()
@test !EarlyStopping.needs_loss(TimeLimit())
@test !EarlyStopping.needs_training_losses(TimeLimit())
end

@testset "GL" begin
@@ -74,9 +67,6 @@ end
@test stopping_time(GL(alpha=90), losses) == 11
@test stopping_time(GL(alpha=110), losses) == 12
@test stopping_time(GL(alpha=1000), losses) == 0

@test EarlyStopping.needs_loss(GL())
@test !EarlyStopping.needs_training_losses(GL())
end

@testset "PQ" begin
@@ -133,9 +123,6 @@ end
@test stopping_time(PQ(alpha=11.6, k=2), losses2, is_training) == 6
@test stopping_time(PQ(alpha=15.1, k=2), losses2, is_training) == 8
@test stopping_time(PQ(alpha=15.3, k=2), losses2, is_training) == 0

@test EarlyStopping.needs_loss(PQ())
@test EarlyStopping.needs_training_losses(PQ())
end

@testset "Patience" begin
@@ -146,9 +133,6 @@ end
@test stopping_time(Patience(n=3), losses) == 5
@test stopping_time(Patience(n=2), losses) == 4
@test stopping_time(Patience(n=1), losses) == 3

@test EarlyStopping.needs_loss(Patience())
@test !EarlyStopping.needs_training_losses(Patience())
end

@testset "NumberSinceBest" begin
@@ -163,9 +147,6 @@ end
losses2 = Float64[10, 9, 8, 9, 10, 7, 10, 10, 10, 10]
@test stopping_time(NumberSinceBest(n=2), losses2) == 5
@test stopping_time(NumberSinceBest(n=3), losses2) == 9

@test EarlyStopping.needs_loss(NumberSinceBest())
@test !EarlyStopping.needs_training_losses(NumberSinceBest())
end

@testset "NumberLimit" begin
@@ -174,17 +155,11 @@ end
for i in 1:length(losses)
@test stopping_time(NumberLimit(i), losses) == i
end

@test !EarlyStopping.needs_loss(NumberLimit())
@test !EarlyStopping.needs_training_losses(NumberLimit())

end

@testset "Threshold" begin
@test Threshold().value == 0.0
stopping_time(Threshold(2.5), Float64[12, 32, 3, 2, 5, 7]) == 4
@test EarlyStopping.needs_loss(Threshold())
@test !EarlyStopping.needs_training_losses(Threshold())
end

@testset "robustness to first loss being a training loss" begin
@@ -253,9 +228,6 @@ end
losses2 = copy(losses); losses2[n] = NaN
@test stopping_time(criterion, losses2, is_training) == sum(!, is_training[1:n])
end

@test EarlyStopping.needs_loss(criterion)
@test !EarlyStopping.needs_training_losses(criterion)
end


