Skip to content
This repository was archived by the owner on May 23, 2022. It is now read-only.

Commit b15e1f4

Browse files
committed
Switch to positional interface
1 parent b89c1cb commit b15e1f4

File tree

4 files changed

+23
-188
lines changed

4 files changed

+23
-188
lines changed

src/LearnBase.jl

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,16 @@
11
module LearnBase
22

3-
import StatsBase
4-
using StatsBase: nobs
3+
import StatsBase: nobs
54

65
# AGGREGATION MODES
76
include("aggmode.jl")
87

9-
# VIEW AND ITERATORS
10-
include("iteration.jl")
11-
128
# OBSERVATION DIMENSIONS
139
include("observation.jl")
1410

11+
# LABEL ENCODINGS
12+
include("labels.jl")
13+
1514
# LEARNING COSTS (e.g. loss & penalty)
1615
include("costs.jl")
1716

src/iteration.jl

Lines changed: 0 additions & 82 deletions
This file was deleted.

src/observation.jl

Lines changed: 13 additions & 93 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ default_obsdim(A::AbstractArray) = ndims(A)
1414
default_obsdim(tup::Tuple) = map(default_obsdim, tup)
1515

1616
"""
17-
getobs(data, idx; obsdim = default_obsdim(data))
17+
getobs(data, idx, obsdim = default_obsdim(data))
1818
1919
Return the observations corresponding to the observation-index `idx`.
2020
Note that `idx` can be of type `Int` or `AbstractVector`.
@@ -24,7 +24,7 @@ The returned observation(s) should be in the form intended to
2424
be passed as-is to some learning algorithm. There is no strict
2525
interface requirement on how this "actual data" must look like.
2626
Every author behind some custom data container can make this
27-
decision himself/herself. We do, however, expect it to be consistent
27+
decision themselves. We do, however, expect it to be consistent
2828
for `idx` being an integer, as well as `idx` being an abstract
2929
vector, respectively.
3030
@@ -33,10 +33,11 @@ to indicate which dimension of `data` denotes the observations.
3333
See [`default_obsdim`](@ref) for defining a default dimension.
3434
"""
3535
function getobs end
36-
getobs(data, idx; obsdim) = getobs(data, idx)
36+
getobs(data, idx) = data[idx]
37+
getobs(data, idx, obsdim) = getobs(data, idx)
3738

3839
"""
39-
getobs!(buffer, data, idx; obsdim = default_obsdim(obsdim))
40+
getobs!(buffer, data, idx, obsdim = default_obsdim(data))
4041
4142
In-place version of `getobs(data, idx, obsdim)`. If this method
4243
is defined for the type of `data`, then `buffer` should be used
@@ -54,7 +55,7 @@ to indicate which dimension of `data` denotes the observations.
5455
See [`default_obsdim`](@ref) for defining a default dimension.
5556
"""
5657
function getobs! end
57-
getobs!(buffer, data, idx; obsdim = default_obsdim(data)) = getobs(data, idx; obsdim = obsdim)
58+
getobs!(buffer, data, idx, obsdim = default_obsdim(data)) = getobs(data, idx, obsdim)
5859

5960
# --------------------------------------------------------------------
6061

@@ -94,95 +95,14 @@ function targets end
9495

9596
# --------------------------------------------------------------------
9697

97-
"""
98-
abstract DataView{TElem, TData} <: AbstractVector{TElem}
99-
100-
Baseclass for all vector-like views of some data structure.
101-
This allow for example to see some design matrix as a vector of
102-
individual observation-vectors instead of one matrix.
103-
104-
see `MLDataPattern.ObsView` and `MLDataPattern.BatchView` for examples.
105-
"""
106-
abstract type DataView{TElem, TData} <: AbstractVector{TElem} end
107-
108-
"""
109-
abstract AbstractObsView{TElem, TData} <: DataView{TElem, TData}
110-
111-
Baseclass for all vector-like views of some data structure,
112-
that views it as some form or vector of observations.
113-
114-
see `MLDataPattern.ObsView` for a concrete example.
115-
"""
116-
abstract type AbstractObsView{TElem, TData} <: DataView{TElem, TData} end
117-
118-
"""
119-
abstract AbstractBatchView{TElem, TData} <: DataView{TElem, TData}
120-
121-
Baseclass for all vector-like views of some data structure,
122-
that views it as some form or vector of equally sized batches.
123-
124-
see `MLDataPattern.BatchView` for a concrete example.
125-
"""
126-
abstract type AbstractBatchView{TElem, TData} <: DataView{TElem, TData} end
127-
128-
# --------------------------------------------------------------------
129-
130-
"""
131-
abstract DataIterator{TElem,TData}
132-
133-
Baseclass for all types that iterate over a `data` source
134-
in some manner. The total number of observations may or may
135-
not be known or defined and in general there is no contract that
136-
`getobs` or `nobs` has to be supported by the type of `data`.
137-
Furthermore, `length` should be used to query how many elements
138-
the iterator can provide, while `nobs` may return the underlying
139-
true amount of observations available (if known).
140-
141-
see `MLDataPattern.RandomObs`, `MLDataPattern.RandomBatches`
142-
"""
143-
abstract type DataIterator{TElem,TData} end
144-
145-
"""
146-
abstract ObsIterator{TElem,TData} <: DataIterator{TElem,TData}
147-
148-
Baseclass for all types that iterate over some data source
149-
one observation at a time.
150-
151-
```julia
152-
using MLDataPattern
153-
@assert typeof(RandomObs(X)) <: ObsIterator
154-
155-
for x in RandomObs(X)
156-
# ...
157-
end
158-
```
98+
abstract type AbstractDataContainer end
15999

160-
see `MLDataPattern.RandomObs`
161-
"""
162-
abstract type ObsIterator{TElem,TData} <: DataIterator{TElem,TData} end
163-
164-
"""
165-
abstract BatchIterator{TElem,TData} <: DataIterator{TElem,TData}
166-
167-
Baseclass for all types that iterate over of some data source one
168-
batch at a time.
169-
170-
```julia
171-
@assert typeof(RandomBatches(X, size=10)) <: BatchIterator
172-
173-
for x in RandomBatches(X, size=10)
174-
@assert nobs(x) == 10
175-
# ...
176-
end
177-
```
178-
179-
see `MLDataPattern.RandomBatches`
180-
"""
181-
abstract type BatchIterator{TElem,TData} <: DataIterator{TElem,TData} end
100+
Base.getindex(x::AbstractDataContainer, i) = getobs(x, i, default_obsdim(x))
101+
# Iterate over a data container one observation at a time. `state` is the
# index of the next observation to fetch. Returns `nothing` once `state`
# exceeds `nobs(x)`, as the iteration protocol requires, so that `for` loops
# and `collect` terminate instead of indexing past the last observation.
# NOTE(review): relies on `nobs` being implemented for the concrete container.
Base.iterate(x::AbstractDataContainer, state = 1) = state > nobs(x) ? nothing : (getobs(x, state, default_obsdim(x)), state + 1)
182102

183103
# --------------------------------------------------------------------
184104

185-
# just for dispatch for those who care to
186-
const AbstractDataIterator{E,T} = Union{DataIterator{E,T}, DataView{E,T}}
187-
const AbstractObsIterator{E,T} = Union{ObsIterator{E,T}, AbstractObsView{E,T}}
188-
const AbstractBatchIterator{E,T} = Union{BatchIterator{E,T},AbstractBatchView{E,T}}
105+
# Might need this distinction later
106+
# e.g. shuffleobs can be anywhere in pipeline but
107+
# eachbatch is usually at the end
108+
abstract type AbstractDataIterator <: AbstractDataContainer end

test/observation.jl

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,9 @@ using LearnBase: getobs, nobs, default_obsdim
1313
@test @inferred getobs(a, 1) == a[:,1]
1414
@test @inferred getobs(a, 2) == a[:,2]
1515
@test @inferred getobs(a, 1:2) == a[:,1:2]
16-
@test @inferred getobs(a, 1, obsdim=1) == a[1,:]
17-
@test @inferred getobs(a, 2, obsdim=1) == a[2,:]
18-
@test @inferred getobs(a, 2, obsdim=nothing) a[:,2]
16+
@test @inferred getobs(a, 1, 1) == a[1,:]
17+
@test @inferred getobs(a, 2, 1) == a[2,:]
18+
@test @inferred getobs(a, 2, nothing) == a[:,2]
1919
end
2020

2121
@testset "tuple" begin
@@ -36,7 +36,7 @@ using LearnBase: getobs, nobs, default_obsdim
3636
else
3737
o = getobs(dataset, 1:2)
3838
end
39-
39+
4040
@test o[1] == X[:,1:2]
4141
@test o[2] == Y[1:2]
4242
end
@@ -53,7 +53,7 @@ using LearnBase: getobs, nobs, default_obsdim
5353
end
5454
@test o.x == X[:,2]
5555
@test o.y == Y[2]
56-
56+
5757
if VERSION >= v"1.6"
5858
o = @inferred getobs(dataset, 1:2)
5959
else
@@ -67,7 +67,7 @@ using LearnBase: getobs, nobs, default_obsdim
6767
X, Y = rand(2, 3), rand(3)
6868
dataset = Dict("X" => X, "Y" => Y)
6969
@test nobs(dataset) == 3
70-
70+
7171
# o = @inferred getobs(dataset, 2) # not inferred
7272
o = getobs(dataset, 2)
7373
@test o["X"] == X[:,2]
@@ -78,5 +78,3 @@ using LearnBase: getobs, nobs, default_obsdim
7878
@test o["Y"] == Y[1:2]
7979
end
8080
end
81-
82-

0 commit comments

Comments
 (0)