clarify need for obs to be involutive
ablaom committed Oct 30, 2024
1 parent 6f436ef commit 2b11e6b
Showing 2 changed files with 37 additions and 13 deletions.
23 changes: 17 additions & 6 deletions docs/src/anatomy_of_an_implementation.md
@@ -420,10 +420,21 @@ LearnAPI.fit(learner::Ridge, data; kwargs...) =

### The `obs` contract

Providing `fit` signatures matching the output of `obs`, is the first part of the `obs`
contract. The second part is this: *The output of `obs` must implement the interface
specified by the trait* [`LearnAPI.data_interface(learner)`](@ref). Assuming this is
[`LearnAPI.RandomAccess()`](@ref) (the default) it usually suffices to overload
Providing `fit` signatures matching the output of [`obs`](@ref) is the first part of the
`obs` contract. Since `obs(learner, data)` should evidently support all `data` that
`fit(learner, data)` supports, we must be able to apply `obs(learner, _)` to its own
output (`observations` below). This leads to the additional "no-op" declaration

```@example anatomy2
LearnAPI.obs(::Ridge, observations::RidgeFitObs) = observations
```

In other words, we ensure that `obs(learner, _)` is
[involutive](https://en.wikipedia.org/wiki/Involution_(mathematics)).

The second part of the `obs` contract is this: *The output of `obs` must implement the
interface specified by the trait* [`LearnAPI.data_interface(learner)`](@ref). Assuming
this is [`LearnAPI.RandomAccess()`](@ref) (the default), it usually suffices to overload
`Base.getindex` and `Base.length`:

```@example anatomy2
@@ -432,11 +443,11 @@ Base.getindex(data::RidgeFitObs, I) =
Base.length(data::RidgeFitObs) = length(data.y)
```
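
As a rough illustration of what overloading `Base.getindex` and `Base.length` buys, here is a minimal, self-contained sketch. The type `ColumnObs` and its fields are hypothetical stand-ins, not the `RidgeFitObs` type from this commit, and the index `I` is assumed to be a range or vector of indices rather than a single integer.

```julia
# Hypothetical container standing in for the output of `obs`.
# The field layout is an assumption made for this sketch only.
struct ColumnObs{T}
    A::Matrix{T}   # features, one observation per column
    y::Vector{T}   # target, one entry per observation
end

# Indexing with a range or vector of indices returns a container of the
# same type, restricted to the selected observations.
Base.getindex(data::ColumnObs, I) = ColumnObs(data.A[:, I], data.y[I])

# The number of observations is the length of the target vector.
Base.length(data::ColumnObs) = length(data.y)

observations = ColumnObs([1.0 2.0 3.0; 4.0 5.0 6.0], [10.0, 20.0, 30.0])
subsample = observations[2:3]   # observations 2 and 3 only
```

Because a subsample has the same type as the full container, it can be passed wherever the full container is accepted, which is what makes cheap resampling possible.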

We can do something similar for `predict`, but there's no need for a new type in this
case:
We do something similar for `predict`, but there's no need for a new type in this case:

```@example anatomy2
LearnAPI.obs(::RidgeFitted, Xnew) = Tables.matrix(Xnew)'
LearnAPI.obs(::RidgeFitted, observations::AbstractArray) = observations # involutivity
LearnAPI.predict(model::RidgeFitted, ::Point, observations::AbstractMatrix) =
observations'*model.coefficients
27 changes: 20 additions & 7 deletions src/obs.jl
@@ -54,8 +54,19 @@ For each supported form of `data` in `fit(learner, data)`, it must be true that
fit(learner, observations)` is equivalent to `model = fit(learner, data)`, whenever
`observations = obs(learner, data)`. For each supported form of `data` in calls
`predict(model, ..., data)` and `transform(model, data)`, where implemented, the calls
`predict(model, ..., observations)` and `transform(model, observations)` are supported
alternatives, whenever `observations = obs(model, data)`.
`predict(model, ..., observations)` and `transform(model, observations)` must be supported
alternatives with the same output, whenever `observations = obs(model, data)`.
Implicit in the above requirements is that `obs(learner, _)` and `obs(model, _)` are
involutive, meaning both the following hold:
```julia
obs(learner, obs(learner, data)) == obs(learner, data)
obs(model, obs(model, data)) == obs(model, data)
```
If one overloads `obs`, one typically needs additional overloadings to guarantee
involutivity.
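The following sketch shows one way the extra overloading might look. It is self-contained and does not use LearnAPI.jl: `ToyLearner`, `ToyFitObs` and `toy_obs` are hypothetical names standing in for a learner, its observation type and `obs`, and the named-tuple input format is an assumption.
```julia
# Hypothetical stand-ins; none of these names come from LearnAPI.jl.
struct ToyLearner end

struct ToyFitObs
    A::Matrix{Float64}   # features, one observation per column
    y::Vector{Float64}   # target
end

# Convert user-facing data (assumed here to be a named tuple with fields
# `X` and `y`) into the learner-specific representation.
toy_obs(::ToyLearner, data::NamedTuple) = ToyFitObs(permutedims(data.X), data.y)

# The additional "no-op" method: already-converted data is returned as is.
toy_obs(::ToyLearner, observations::ToyFitObs) = observations

learner = ToyLearner()
data = (X = [1.0 2.0; 3.0 4.0; 5.0 6.0], y = [1.0, 2.0, 3.0])
once = toy_obs(learner, data)
twice = toy_obs(learner, once)   # hits the no-op method
```
Without the second method, the call `toy_obs(learner, once)` would throw a `MethodError`, since no method would accept a `ToyFitObs` argument.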
The fallback for `obs` is `obs(model_or_learner, data) = data`, and the fallback for
`LearnAPI.data_interface(learner)` is `LearnAPI.RandomAccess()`. For details refer to
@@ -67,14 +78,16 @@ to be overloaded. However, the user will get no performance benefits by using `o
that case.
When overloading `obs(learner, data)` to output new model-specific representations of
data, it may be necessary to also overload [`LearnAPI.features`](@ref),
[`LearnAPI.target`](@ref) (supervised learners), and/or [`LearnAPI.weights`](@ref) (if
weights are supported), for extracting relevant parts of the representation.
data, it may be necessary to also overload [`LearnAPI.features(learner,
observations)`](@ref), [`LearnAPI.target(learner, observations)`](@ref) (supervised
learners), and/or [`LearnAPI.weights(learner, observations)`](@ref) (if weights are
supported), for each kind of output `observations` of `obs(learner, data)`.
## Sample implementation
Refer to the "Anatomy of an Implementation" section of the LearnAPI.jl
[manual](https://juliaai.github.io/LearnAPI.jl/dev/).
Refer to the ["Anatomy of an
Implementation"](https://juliaai.github.io/LearnAPI.jl/dev/anatomy_of_an_implementation/#Providing-an-advanced-data-interface)
section of the LearnAPI.jl manual.
"""
