
Commit

Merge pull request #14 from ablaom/dev
For a 0.1.4 release
ablaom authored Mar 3, 2021
2 parents d5d6c21 + 3ebc904 commit 952d984
Showing 5 changed files with 41 additions and 12 deletions.
2 changes: 1 addition & 1 deletion Project.toml
@@ -1,7 +1,7 @@
name = "EarlyStopping"
uuid = "792122b4-ca99-40de-a6bc-6742525f08b6"
authors = ["Anthony D. Blaom <[email protected]>"]
version = "0.1.3"
version = "0.1.4"

[deps]
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
2 changes: 1 addition & 1 deletion src/EarlyStopping.jl
@@ -8,7 +8,7 @@ export StoppingCriterion,
Never, NotANumber, TimeLimit, GL, Patience, UP, PQ, NumberLimit,
Threshold, Disjunction,
criteria, stopping_time, EarlyStopper,
done!, message, needs_in_and_out_of_sample
done!, message, needs_training_losses

include("api.jl")
include("criteria.jl")
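
The only change to the top-level module is in the export list: `needs_in_and_out_of_sample` is replaced by the new `needs_training_losses` trait. A minimal usage sketch, with trait values taken from the tests further down (`needs_loss` is shown qualified because it does not appear in the export list above):

using EarlyStopping

needs_training_losses(PQ())           # true: PQ needs training losses as well as out-of-sample losses
needs_training_losses(Patience())     # false
EarlyStopping.needs_loss(Patience())  # true: Patience still consumes out-of-sample losses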
11 changes: 5 additions & 6 deletions src/api.jl
@@ -21,10 +21,9 @@ done(::StoppingCriterion, state) = false
message(criterion::StoppingCriterion, state) = "Early stop triggered by "*
"$criterion stopping criterion. "

# to be overloaded if the criterion requires one or more
# `update_training` calls per `update` call, ie, it requires both
# "training" and "out-of-sample" losses to work:
needs_in_and_out_of_sample(::Type) = false
needs_loss(::Type) = false
needs_training_losses(::Type) = false

needs_in_and_out_of_sample(c::StoppingCriterion) =
needs_in_and_out_of_sample(typeof(c))
for trait in [:needs_loss, :needs_training_losses]
eval(:($trait(c::StoppingCriterion) = $trait(typeof(c))))
end
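
For readability, the `eval` loop above is equivalent to writing out one instance-to-type forwarding method per trait by hand:

# Hand-expanded form of the trait-forwarding loop in src/api.jl:
needs_loss(c::StoppingCriterion) = needs_loss(typeof(c))
needs_training_losses(c::StoppingCriterion) = needs_training_losses(typeof(c))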
14 changes: 13 additions & 1 deletion src/criteria.jl
@@ -46,6 +46,8 @@ done(::NotANumber, state) = state

message(::NotANumber, state) = "Stopping early as NaN encountered. "

needs_loss(::Type{<:NotANumber}) = true


## TIME LIMIT

@@ -129,6 +131,8 @@ update(criterion::GL, loss, ::Nothing) = update(criterion, loss)
done(criterion::GL, state) =
generalization_loss(state.loss, state.min_loss) > criterion.alpha

needs_loss(::Type{<:GL}) = true


## PQ

@@ -257,7 +261,8 @@ function done(criterion::PQ, state)
return PQ > criterion.alpha
end

needs_in_and_out_of_sample(::Type{<:PQ}) = true
needs_loss(::Type{<:PQ}) = true
needs_training_losses(::Type{<:PQ}) = true


## PATIENCE
@@ -302,6 +307,11 @@ update(criterion::Patience, loss, ::Nothing) = update(criterion, loss)

done(criterion::Patience, state) = state.n_increases == criterion.n

needs_loss(::Type{<:Patience}) = true


# # NUMBER LIMIT

"""
NumberLimit(; n=100)
@@ -355,3 +365,5 @@ update(criterion::Threshold, loss, state) = loss
update(criterion::Threshold, loss, ::Nothing) = loss

done(criterion::Threshold, state) = state < criterion.value

needs_loss(::Type{<:Threshold}) = true
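
Each built-in criterion opts into the new traits by overloading them on its type, as the hunks above show. A user-defined criterion would follow the same pattern; the `MyCriterion` type below is purely illustrative and not part of the package:

import EarlyStopping: needs_loss, needs_training_losses, StoppingCriterion

struct MyCriterion <: StoppingCriterion end   # hypothetical example type

# Opt in to receiving out-of-sample losses; the training-loss trait
# can be left alone, since `false` is already the default in src/api.jl.
needs_loss(::Type{<:MyCriterion}) = true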
24 changes: 21 additions & 3 deletions test/criteria.jl
@@ -5,6 +5,8 @@ losses = Float64[10, 8, 9, 10, 11, 12, 12, 13, 14, 15, 16, 17, 16]

@testset "Never" begin
@test stopping_time(Never(), losses) == 0
@test !EarlyStopping.needs_loss(Never())
@test !EarlyStopping.needs_training_losses(Never())
end

@testset "NotANumber" begin
@@ -32,6 +34,8 @@ end
[NaN, 1],
[true, false],
verbosity=1))
@test EarlyStopping.needs_loss(NotANumber())
@test !EarlyStopping.needs_training_losses(NotANumber())
end


@@ -53,11 +57,11 @@ Base.iterate(iter::SleepyIterator, state) =
@test stopping_time(TimeLimit(t=Millisecond(600)), sleepy_losses) == 7
# codecov:
@test EarlyStopping.update_training(TimeLimit(), 42.0) <= now()
@test !EarlyStopping.needs_loss(TimeLimit())
@test !EarlyStopping.needs_training_losses(TimeLimit())
end

@testset "GL" begin
@test !EarlyStopping.needs_in_and_out_of_sample(GL())

# constructor:
@test_throws ArgumentError GL(alpha=0)
@test GL(alpha=1).alpha === 1.0
@@ -76,6 +80,9 @@ end
@test stopping_time(GL(alpha=90), losses) == 11
@test stopping_time(GL(alpha=110), losses) == 12
@test stopping_time(GL(alpha=1000), losses) == 0

@test EarlyStopping.needs_loss(GL())
@test !EarlyStopping.needs_training_losses(GL())
end

@testset "PQ" begin
@@ -89,7 +96,6 @@ end
@test_throws ArgumentError PQ(k=1)

c = PQ(alpha=10, k=2)
@test EarlyStopping.needs_in_and_out_of_sample(c)

# first update must be training:
@test_throws Exception EarlyStopping.update(c, 1.0)
@@ -139,6 +145,9 @@ end
@test stopping_time(PQ(alpha=11.6, k=2), losses2, is_training) == 6
@test stopping_time(PQ(alpha=15.1, k=2), losses2, is_training) == 8
@test stopping_time(PQ(alpha=15.3, k=2), losses2, is_training) == 0

@test EarlyStopping.needs_loss(PQ())
@test EarlyStopping.needs_training_losses(PQ())
end

@testset "Patience" begin
@@ -149,6 +158,9 @@ end
@test stopping_time(Patience(n=3), losses) == 5
@test stopping_time(Patience(n=2), losses) == 4
@test stopping_time(Patience(n=1), losses) == 3

@test EarlyStopping.needs_loss(Patience())
@test !EarlyStopping.needs_training_losses(Patience())
end

@testset "NumberLimit" begin
@@ -157,11 +169,17 @@ end
for i in 1:length(losses)
@test stopping_time(NumberLimit(i), losses) == i
end

@test !EarlyStopping.needs_loss(NumberLimit())
@test !EarlyStopping.needs_training_losses(NumberLimit())

end

@testset "Threshold" begin
@test Threshold().value == 0.0
stopping_time(Threshold(2.5), Float64[12, 32, 3, 2, 5, 7]) == 4
@test EarlyStopping.needs_loss(Threshold())
@test !EarlyStopping.needs_training_losses(Threshold())
end

@testset "robustness to first loss being a training loss" begin
