Skip to content

Commit

Permalink
Merge pull request #8 from ablaom/dev
Browse files Browse the repository at this point in the history
For a 0.1.2 release
  • Loading branch information
ablaom authored Mar 1, 2021
2 parents 991db07 + 2b9a7c0 commit b571ad0
Show file tree
Hide file tree
Showing 6 changed files with 49 additions and 4 deletions.
5 changes: 3 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "EarlyStopping"
uuid = "792122b4-ca99-40de-a6bc-6742525f08b6"
authors = ["Anthony D. Blaom <[email protected]>"]
version = "0.1.1"
version = "0.1.2"

[deps]
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
Expand All @@ -11,7 +11,8 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
julia = "^1"

[extras]
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Test"]
test = ["InteractiveUtils", "Test"]
3 changes: 2 additions & 1 deletion src/EarlyStopping.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@ import Base.+

export StoppingCriterion,
Never, NotANumber, TimeLimit, GL, Patience, UP, PQ, NumberLimit,
Disjunction, criteria, stopping_time, EarlyStopper,
Threshold, Disjunction,
criteria, stopping_time, EarlyStopper,
done!, message, needs_in_and_out_of_sample

include("api.jl")
Expand Down
29 changes: 29 additions & 0 deletions src/criteria.jl
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,8 @@ GL(; alpha=2.0) = GL(alpha)

update(::GL, loss) = (loss=loss, min_loss=loss)
update(::GL, loss, state) = (loss=loss, min_loss=min(loss, state.min_loss))
# in case first loss consumed was a training loss:
update(criterion::GL, loss, ::Nothing) = update(criterion, loss)
done(criterion::GL, state) =
generalization_loss(state.loss, state.min_loss) > criterion.alpha

Expand Down Expand Up @@ -295,6 +297,8 @@ update(criterion::Patience, loss) = (loss=loss, n_increases=0)
end
return (loss=loss, n_increases=n)
end
# in case first loss consumed was a training loss:
update(criterion::Patience, loss, ::Nothing) = update(criterion, loss)

done(criterion::Patience, state) = state.n_increases == criterion.n

Expand Down Expand Up @@ -324,5 +328,30 @@ update(criterion::NumberLimit, loss) = 1
@inline function update(criterion::NumberLimit, loss, state)
return state+1
end
# in case first loss consumed was a training loss:
update(criterion::NumberLimit, loss, ::Nothing) = update(criterion, loss)

done(criterion::NumberLimit, state) = state == criterion.n


# ## THRESHOLD

"""
Threshold(; value=0.0)
$STOPPING_DOC
A stop is triggered as soon as the loss drops below `value`.
"""
mutable struct Threshold <: StoppingCriterion
value::Float64
end
Threshold(; value=0.0) = Threshold(value)

update(criterion::Threshold, loss) = loss
update(criterion::Threshold, loss, state) = loss
# in case first loss consumed was a training loss:
update(criterion::Threshold, loss, ::Nothing) = loss

done(criterion::Threshold, state) = state < criterion.value
1 change: 1 addition & 0 deletions src/disjunction.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ struct Disjunction{A,B} <: StoppingCriterion
end
end

Disjunction() = Never()
Disjunction(a) = a
Disjunction(a, b, c...) = Disjunction(Disjunction(a,b), c...)

Expand Down
13 changes: 13 additions & 0 deletions test/criteria.jl
Original file line number Diff line number Diff line change
Expand Up @@ -159,4 +159,17 @@ end
end
end

@testset "Threshold" begin
@test Threshold().value == 0.0
stopping_time(Threshold(2.5), Float64[12, 32, 3, 2, 5, 7]) == 4
end

@testset "robustness to first loss being a training loss" begin
for C in subtypes(StoppingCriterion)
losses = float.(4:-1:1)
is_training = [true, true, false, false]
stopping_time(C(), losses, is_training)
end
end

true
2 changes: 1 addition & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using EarlyStopping, Dates, Test
using EarlyStopping, Dates, Test, InteractiveUtils

@testset "criteria.jl" begin
include("criteria.jl")
Expand Down

0 comments on commit b571ad0

Please sign in to comment.