Skip to content

Commit

Permalink
Merge pull request #27 from JuliaAI/dev
Browse files Browse the repository at this point in the history
For a 0.1.9 release
  • Loading branch information
ablaom authored Oct 18, 2021
2 parents fe258fe + 3acfbbb commit d2b5a0a
Show file tree
Hide file tree
Showing 5 changed files with 87 additions and 4 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "EarlyStopping"
uuid = "792122b4-ca99-40de-a6bc-6742525f08b6"
authors = ["Anthony D. Blaom <[email protected]>"]
version = "0.1.8"
version = "0.1.9"

[deps]
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
Expand Down
7 changes: 4 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@

| Linux | Coverage |
| :-----------: | :------: |
| [![Build status](https://github.com/ablaom/EarlyStopping.jl/workflows/CI/badge.svg)](https://github.com/ablaom/EarlyStopping.jl/actions)| [![codecov.io](http://codecov.io/github/ablaom/EarlyStopping.jl/coverage.svg?branch=master)](http://codecov.io/github/ablaom/EarlyStopping.jl?branch=master) |
| [![Build status](https://github.com/JuliaAI/EarlyStopping.jl/workflows/CI/badge.svg)](https://github.com/JuliaAI/EarlyStopping.jl/actions)| [![codecov.io](http://codecov.io/github/JuliaAI/EarlyStopping.jl/coverage.svg?branch=master)](http://codecov.io/github/JuliaAI/EarlyStopping.jl?branch=master) |

A small package for applying early stopping criteria to
loss-generating iterative algorithms, with a view
to training and optimizing machine learning models.

The basis of [IterationControl.jl](https://github.com/ablaom/IterationControl.jl),
The basis of [IterationControl.jl](https://github.com/JuliaAI/IterationControl.jl),
a package externally controlling iterative algorithms.

Includes the stopping criteria surveyed in [Prechelt, Lutz
Expand Down Expand Up @@ -87,11 +87,12 @@ criterion | description |
`TimeLimit(t=0.5)` | Stop after `t` hours |
`NumberLimit(n=100)` | Stop after `n` loss updates (excl. "training losses") |
`NumberSinceBest(n=6)`| Stop after `n` loss updates (excl. "training losses") |
`Threshold(value=0.0)`| Stop when `loss < value` |
`Threshold(value=0.0)`| Stop when `loss < value` |
`GL(alpha=2.0)` | Stop after "Generalization Loss" exceeds `alpha` | ``GL_α``
`PQ(alpha=0.75, k=5)` | Stop after "Progress-modified GL" exceeds `alpha` | ``PQ_α``
`Patience(n=5)` | Stop after `n` consecutive loss increases | ``UP_s``
`Disjunction(c...)` | Stop when any of the criteria `c` apply |
`Warmup(c; n=1)` | Wait for `n` loss updates before checking criteria `c`|


## Criteria tracking both training and out-of-sample losses
Expand Down
1 change: 1 addition & 0 deletions src/EarlyStopping.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ export StoppingCriterion,
NumberLimit,
Threshold,
Disjunction,
Warmup,
criteria,
stopping_time,
EarlyStopper,
Expand Down
50 changes: 50 additions & 0 deletions src/criteria.jl
Original file line number Diff line number Diff line change
Expand Up @@ -423,10 +423,60 @@ done(criterion::Threshold, state) = state < criterion.value
needs_loss(::Type{<:Threshold}) = true


"""
Warmup(c::StoppingCriterion, n)
Wait for `n` updates before checking stopping criterion `c`
"""
struct Warmup{C} <: StoppingCriterion where {C <: StoppingCriterion}
criterion::C
n::Int
function Warmup(criterion::C, n::N) where {C <: StoppingCriterion, N <: Integer}
n > 0 || throw(ArgumentError("`n` must be positive. "))
new{C}(criterion, Int(n))
end
end

# Constructors for Warmup
Warmup() = Warmup(InvalidValue()) # Default for testing
Warmup(c; n = 1) = Warmup(c, n) # Provide kwargs interface

# Initialize inner state for type-stability, and record first observation
update(c::Warmup, loss, ::Nothing) = update(c, loss)
update(criterion::Warmup, loss) = (1, update(criterion.criterion, loss))

# Catch uninitialized state
update_training(c::Warmup, loss, ::Nothing) = update_training(c, loss)
update_training(c::Warmup, loss) = (1, update_training(c.criterion, loss))

# Handle update vs update_training
update(c::Warmup, loss, state) = _update(update, c, loss, state)
update_training(c::Warmup, loss, state) = _update(update_training, c, loss, state)
needs_loss(::Type{<:Warmup{C}}) where C = needs_loss(C)
needs_training_losses(::Type{<:Warmup{C}}) where C = needs_training_losses(C)

# Dispatch update and update_training here
function _update(f::Function, criterion::Warmup, loss, state)
n, inner = state
n += 1
if n <= criterion.n
# Skip inner criterion
return n, inner
elseif n == criterion.n+1
# First step of inner criterion
return n, f(criterion.criterion, loss)
else
# Step inner criterion
return n, f(criterion.criterion, loss, inner)
end
end

function done(criterion::Warmup, state)
# Only check if inner criterion is done after n updates
return state[1] <= criterion.n ? false : done(criterion.criterion, state[2])
end

message(c::Warmup, state) = message(c.criterion, state[2])


## NOT A NUMBER (deprecated)
Expand Down
31 changes: 31 additions & 0 deletions test/criteria.jl
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,37 @@ end
end
end

@testset "Warmup" begin
@test_throws ArgumentError Warmup(Patience(), 0)
for n in 1:(length(losses)-1)
@test stopping_time(Warmup(NumberLimit(1), n), losses) == n+1
end

# Test message
@testset "message" begin
stopper = Warmup(Patience(2); n = 2)
stopper_ref = Warmup(Patience(2), 2)
state, state_ref = nothing, nothing
for loss = losses
state = EarlyStopping.update(stopper, loss, state)
state_ref = EarlyStopping.update(stopper_ref, loss, state_ref)
@test message(stopper, state) == message(stopper_ref, state_ref)
end
end

@testset "training" begin
stopper = Warmup(PQ(), 3)
is_training = @show map(x -> x%3 > 0, 1:length(losses))

# Feed 2 training losses + 1 non-training to criteria with/without
stop_time = stopping_time(stopper, losses, is_training)
ref_stop_time = stopping_time(stopper.criterion, losses[3:end], is_training)

# PQ only counts training loss updates
@test round(stop_time/3, RoundUp) == ref_stop_time
end
end


# # DEPRECATED

Expand Down

0 comments on commit d2b5a0a

Please sign in to comment.