Skip to content

Commit

Permalink
Merge pull request #29 from awadell1/nothing_state
Browse files Browse the repository at this point in the history
Codify `state == nothing` is an uninitialized state
  • Loading branch information
ablaom authored Oct 27, 2021
2 parents 3acfbbb + 8277556 commit c4cfe9e
Show file tree
Hide file tree
Showing 14 changed files with 325 additions and 244 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "EarlyStopping"
uuid = "792122b4-ca99-40de-a6bc-6742525f08b6"
authors = ["Anthony D. Blaom <[email protected]>"]
version = "0.1.9"
version = "0.2.0"

[deps]
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
Expand Down
93 changes: 77 additions & 16 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,9 @@ For criteria tracking both an "out-of-sample" loss and a "training"
loss (eg, stopping criterion of type `PQ`), specify `training=true` if
the update is for training, as in

done!(stopper, 0.123, training=true)
``` julia
done!(stopper, 0.123, training=true)
```

In these cases, the out-of-sample update must always come after the
corresponding training update. Multiple training updates may precede
Expand Down Expand Up @@ -171,25 +173,66 @@ stopping_time(PQ(),
To implement a new stopping criterion, one must:
- Define a new `struct` for the criterion, which must subtype
`StoppingCriterion`.
`StoppingCriterion`.
- Overload methods `update` and `done` for the new type.
- Optionally overload methods `message`.
```julia

struct NewCriteria <: StoppingCriterion
# Put relevant fields here
end

# Provide a default constructor with all key-word arguments
NewCriteria(; kwargs...) = ...

# Return the initial state of the NewCriteria after
# receiving an out-of-sample loss
update(c::NewCriteria, loss, ::Nothing) = ...

- Optionally overload `update_training` and the trait
# Return an updated state for NewCriteria given a `loss`
# and the current `state`
update(c::NewCriteria, loss, state) = ...

# Return true if NewCriteria should stop given `state`.
# Always return false if `state === nothing`
done(c::NewCriteria, state) = state === nothing ? false : ....

```
Optionally, one may define the following:
- Overload the final message with `message`.
- Handle training losses by overloading `update_training` and the trait
`needs_training_losses`.
```julia

# Final message when NewCriteria triggers a stop
message(c::NewCriteria, state) = ...

# Methods for initializing/updating the state given a training loss
update_training(c::NewCriteria, loss, ::Nothing) = ...
update_training(c::NewCriteria, loss, state) = ...

# Define `needs_training_losses` iff NewCriteria needs one or more
# training losses per out-of-sample loss to work. Otherwise, this may be
# omitted
needs_training_losses(::Type{<:NewCriteria}) = true

```
## New Criteria Example
We demonstrate this with a simplified version of the
[code](/src/criteria.jl) for `Patience`:
[code](./src/criteria.jl) for `Patience`:
### Defining the new type
```julia
using EarlyStopping

mutable struct Patience <: StoppingCriterion
struct Patience <: StoppingCriterion
n::Int
end
Patience(; n=5) = Patience(n)
Expand All @@ -199,15 +242,20 @@ Patience(; n=5) = Patience(n)
All information to be "remembered" must passed around in an object
called `state` below, which is the return value of `update` (and
`update_training`). The `update` function has two methods - one for
initialization, without a `state` argument, and one for all subsequent
loss updates, which requires the `state` returned by the preceding
`update` (or `update_training`) call:
`update_training`). The `update` function has two methods:
- Initialization: `update(c::NewCriteria, loss, ::Nothing)`
- Subsequent Loss Updates: `update(c::NewCriteria, loss, state)`
Where `state` is the return of the previous call to `update` or `update_training`.
Notice, that `state === nothing` indicates an uninitialized criteria.
```julia
import EarlyStopping: update, done

update(criterion::Patience, loss) = (loss=loss, n_increases=0) # state
function update(criterion::Patience, loss, ::Nothing)
return (loss=loss, n_increases=0) # state
end

function update(criterion::Patience, loss, state)
old_loss, n = state
Expand All @@ -220,15 +268,18 @@ function update(criterion::Patience, loss, state)
end
```
The `done` method returns `true` or `false` depending on the `state`:
The `done` method returns `true` or `false` depending on the `state`, but
always returns `false` if `state === nothing`.
```julia
done(criterion::Patience, state) = state.n_increases == criterion.n
done(criterion::Patience, state) =
state === nothing ? false : state.n_increases == criterion.n

```
### Optional methods
The final message of an `EarlyStopper` is generated by a `message`
The final message of an `EarlyStopper` is generated using a `message`
method for `StoppingCriterion`. Here is the fallback (which does not
use `state`):
Expand All @@ -239,7 +290,7 @@ EarlyStopping.message(criteria::StoppingCriterion, state)
The optional `update_training` methods (two for each criterion) have
the same signature as the `update` methods above. Refer to the `PQ`
[code](/src/criteria.jl) for an example.
[code](./src/criteria.jl) for an example.
If a stopping criterion requires one or more `update_training` calls
per `update` call to work, you should overload the trait
Expand All @@ -249,3 +300,13 @@ the source code:
```julia
EarlyStopping.needs_training_losses(::Type{<:PQ}) = true
```
### Unit Testing
The following are provided to facilitate testing of new criteria:
- `stopping_time`: returns the stopping time for an iterator
`losses` using `criterion`.
- [`@test_criteria NewCriteria()`](./test/smoke.jl): Runs a suite of unit tests
against the provided `StoppingCriteria`. This macro is only part of the
test suite and is not part of the API.
3 changes: 2 additions & 1 deletion src/EarlyStopping.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,15 @@ export StoppingCriterion,
stopping_time,
EarlyStopper,
done!,
reset!,
message,
needs_training_losses,
needs_loss

include("api.jl")
include("criteria.jl")
include("disjunction.jl")
include("stopping_time.jl")
include("object_oriented_api.jl")
include("stopping_time.jl")

end # module
12 changes: 2 additions & 10 deletions src/api.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,8 @@ abstract type StoppingCriterion end


## FALL BACK METHODS

# initialization call is either:
update(::StoppingCriterion, loss) = nothing # state

# ... or:
update_training(::StoppingCriterion, loss) = nothing # state

# subsequent updating:
update(::StoppingCriterion, loss, state) = state
update_training(::StoppingCriterion, loss, state) = state
update(::StoppingCriterion, loss, state=nothing) = state
update_training(::StoppingCriterion, loss, state=nothing) = state

# returns whether it's time to stop:
done(::StoppingCriterion, state) = false
Expand Down
Loading

0 comments on commit c4cfe9e

Please sign in to comment.