Skip to content
This repository has been archived by the owner on May 23, 2022. It is now read-only.

Commit

Permalink
Switch to positional interface
Browse files Browse the repository at this point in the history
  • Loading branch information
darsnack committed Oct 20, 2021
1 parent b89c1cb commit b15e1f4
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 188 deletions.
9 changes: 4 additions & 5 deletions src/LearnBase.jl
Original file line number Diff line number Diff line change
@@ -1,17 +1,16 @@
module LearnBase

import StatsBase
using StatsBase: nobs
import StatsBase: nobs

# AGGREGATION MODES
include("aggmode.jl")

# VIEW AND ITERATORS
include("iteration.jl")

# OBSERVATION DIMENSIONS
include("observation.jl")

# LABEL ENCONDINGS
include("labels.jl")

# LEARNING COSTS (e.g. loss & penalty)
include("costs.jl")

Expand Down
82 changes: 0 additions & 82 deletions src/iteration.jl

This file was deleted.

106 changes: 13 additions & 93 deletions src/observation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ default_obsdim(A::AbstractArray) = ndims(A)
default_obsdim(tup::Tuple) = map(default_obsdim, tup)

"""
getobs(data, idx; obsdim = default_obsdim(data))
getobs(data, idx, obsdim = default_obsdim(data))
Return the observations corresponding to the observation-index `idx`.
Note that `idx` can be of type `Int` or `AbstractVector`.
Expand All @@ -24,7 +24,7 @@ The returned observation(s) should be in the form intended to
be passed as-is to some learning algorithm. There is no strict
interface requirement on how this "actual data" must look like.
Every author behind some custom data container can make this
decision himself/herself. We do, however, expect it to be consistent
decision themselves. We do, however, expect it to be consistent
for `idx` being an integer, as well as `idx` being an abstract
vector, respectively.
Expand All @@ -33,10 +33,11 @@ to indicate which dimension of `data` denotes the observations.
See [`default_obsdim`](@ref) for defining a default dimension.
"""
function getobs end
getobs(data, idx; obsdim) = getobs(data, idx)
getobs(data, idx) = data[idx]
getobs(data, idx, obsdim) = getobs(data, idx)

"""
getobs!(buffer, data, idx; obsdim = default_obsdim(obsdim))
getobs!(buffer, data, idx, obsdim = default_obsdim(obsdim))
Inplace version of `getobs(data, idx; obsdim)`. If this method
is defined for the type of `data`, then `buffer` should be used
Expand All @@ -54,7 +55,7 @@ to indicate which dimension of `data` denotes the observations.
See [`default_obsdim`](@ref) for defining a default dimension.
"""
function getobs! end
getobs!(buffer, data, idx; obsdim = default_obsdim(data)) = getobs(data, idx; obsdim = obsdim)
getobs!(buffer, data, idx, obsdim = default_obsdim(data)) = getobs(data, idx, obsdim)

# --------------------------------------------------------------------

Expand Down Expand Up @@ -94,95 +95,14 @@ function targets end

# --------------------------------------------------------------------

"""
abstract DataView{TElem, TData} <: AbstractVector{TElem}
Baseclass for all vector-like views of some data structure.
This allow for example to see some design matrix as a vector of
individual observation-vectors instead of one matrix.
see `MLDataPattern.ObsView` and `MLDataPattern.BatchView` for examples.
"""
abstract type DataView{TElem, TData} <: AbstractVector{TElem} end

"""
abstract AbstractObsView{TElem, TData} <: DataView{TElem, TData}
Baseclass for all vector-like views of some data structure,
that views it as some form or vector of observations.
see `MLDataPattern.ObsView` for a concrete example.
"""
abstract type AbstractObsView{TElem, TData} <: DataView{TElem, TData} end

"""
abstract AbstractBatchView{TElem, TData} <: DataView{TElem, TData}
Baseclass for all vector-like views of some data structure,
that views it as some form or vector of equally sized batches.
see `MLDataPattern.BatchView` for a concrete example.
"""
abstract type AbstractBatchView{TElem, TData} <: DataView{TElem, TData} end

# --------------------------------------------------------------------

"""
abstract DataIterator{TElem,TData}
Baseclass for all types that iterate over a `data` source
in some manner. The total number of observations may or may
not be known or defined and in general there is no contract that
`getobs` or `nobs` has to be supported by the type of `data`.
Furthermore, `length` should be used to query how many elements
the iterator can provide, while `nobs` may return the underlying
true amount of observations available (if known).
see `MLDataPattern.RandomObs`, `MLDataPattern.RandomBatches`
"""
abstract type DataIterator{TElem,TData} end

"""
abstract ObsIterator{TElem,TData} <: DataIterator{TElem,TData}
Baseclass for all types that iterate over some data source
one observation at a time.
```julia
using MLDataPattern
@assert typeof(RandomObs(X)) <: ObsIterator
for x in RandomObs(X)
# ...
end
```
abstract type AbstractDataContainer end

see `MLDataPattern.RandomObs`
"""
abstract type ObsIterator{TElem,TData} <: DataIterator{TElem,TData} end

"""
abstract BatchIterator{TElem,TData} <: DataIterator{TElem,TData}
Baseclass for all types that iterate over of some data source one
batch at a time.
```julia
@assert typeof(RandomBatches(X, size=10)) <: BatchIterator
for x in RandomBatches(X, size=10)
@assert nobs(x) == 10
# ...
end
```
see `MLDataPattern.RandomBatches`
"""
abstract type BatchIterator{TElem,TData} <: DataIterator{TElem,TData} end
Base.getindex(x::AbstractDataContainer, i) = getobs(x, i, default_obsdim(x))
Base.iterate(x::AbstractDataContainer, state = 1) = getobs(x, state, default_obsdim(x)), state + 1

# --------------------------------------------------------------------

# just for dispatch for those who care to
const AbstractDataIterator{E,T} = Union{DataIterator{E,T}, DataView{E,T}}
const AbstractObsIterator{E,T} = Union{ObsIterator{E,T}, AbstractObsView{E,T}}
const AbstractBatchIterator{E,T} = Union{BatchIterator{E,T},AbstractBatchView{E,T}}
# Might need this distinction later
# e.g. shuffleobs can be anywhere in pipeline but
# eachbatch is usually at the end
abstract type AbstractDataIterator <: AbstractDataContainer end
14 changes: 6 additions & 8 deletions test/observation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@ using LearnBase: getobs, nobs, default_obsdim
@test @inferred getobs(a, 1) == a[:,1]
@test @inferred getobs(a, 2) == a[:,2]
@test @inferred getobs(a, 1:2) == a[:,1:2]
@test @inferred getobs(a, 1, obsdim=1) == a[1,:]
@test @inferred getobs(a, 2, obsdim=1) == a[2,:]
@test @inferred getobs(a, 2, obsdim=nothing) a[:,2]
@test @inferred getobs(a, 1, 1) == a[1,:]
@test @inferred getobs(a, 2, 1) == a[2,:]
@test @inferred getobs(a, 2, nothing) a[:,2]
end

@testset "tuple" begin
Expand All @@ -36,7 +36,7 @@ using LearnBase: getobs, nobs, default_obsdim
else
o = getobs(dataset, 1:2)
end

@test o[1] == X[:,1:2]
@test o[2] == Y[1:2]
end
Expand All @@ -53,7 +53,7 @@ using LearnBase: getobs, nobs, default_obsdim
end
@test o.x == X[:,2]
@test o.y == Y[2]

if VERSION >= v"1.6"
o = @inferred getobs(dataset, 1:2)
else
Expand All @@ -67,7 +67,7 @@ using LearnBase: getobs, nobs, default_obsdim
X, Y = rand(2, 3), rand(3)
dataset = Dict("X" => X, "Y" => Y)
@test nobs(dataset) == 3

# o = @inferred getobs(dataset, 2) # not inferred
o = getobs(dataset, 2)
@test o["X"] == X[:,2]
Expand All @@ -78,5 +78,3 @@ using LearnBase: getobs, nobs, default_obsdim
@test o["Y"] == Y[1:2]
end
end


0 comments on commit b15e1f4

Please sign in to comment.