Skip to content

Commit

Permalink
numobs and getobs support for Tables.jl's tables (#124)
Browse files Browse the repository at this point in the history
* support tables

* use SimpleTraits.jl

* more tests

* docstring

* more docs
  • Loading branch information
CarloLucibello committed Oct 23, 2022
1 parent e247fb5 commit a85c098
Show file tree
Hide file tree
Showing 5 changed files with 102 additions and 12 deletions.
11 changes: 9 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,34 +1,41 @@
name = "MLUtils"
uuid = "f1d291b0-491e-4a28-83b9-f70985020b54"
authors = ["Carlo Lucibello <[email protected]> and contributors"]
version = "0.2.11"
version = "0.2.12"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
DataAPI = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a"
DelimitedFiles = "8bb1440f-4735-579b-a4ab-409b98df4dab"
FLoops = "cc61a311-1640-44b5-9fba-1b764f453329"
FoldsThreads = "9c68100b-dfe1-47cf-94c8-95104e173443"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
ShowCases = "605ecd9f-84a6-4c9e-81e2-4798472b76a3"
SimpleTraits = "699a6c99-e7fa-54fc-8d76-47d257e15c1d"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
Transducers = "28d57a85-8fef-5791-bfe6-a80928e7c999"

[compat]
ChainRulesCore = "1.0"
DataAPI = "1.0"
DelimitedFiles = "1.0"
FLoops = "0.2"
FoldsThreads = "0.1"
SimpleTraits = "0.9"
ShowCases = "0.1"
StatsBase = "0.33"
Tables = "1.10"
Transducers = "0.4"
julia = "1.6"

[extras]
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["ChainRulesTestUtils", "SparseArrays", "Test", "Zygote"]
test = ["ChainRulesTestUtils", "DataFrames", "SparseArrays", "Test", "Zygote"]
7 changes: 7 additions & 0 deletions src/MLUtils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,19 @@ using FLoops.Transducers: Executor, ThreadedEx
using FoldsThreads: TaskPoolEx
import StatsBase: sample
using Transducers
using Tables
using DataAPI
using Base: @propagate_inbounds
using Random: AbstractRNG, shuffle!, GLOBAL_RNG, rand!, randn!
import ChainRulesCore: rrule
using ChainRulesCore: @non_differentiable, unthunk, AbstractZero,
NoTangent, ZeroTangent, ProjectTo

using SimpleTraits

@traitdef IsTable{X}
@traitimpl IsTable{X} <- Tables.istable(X)


include("observation.jl")
export numobs,
Expand Down
86 changes: 76 additions & 10 deletions src/observation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,30 +3,68 @@
Return the total number of observations contained in `data`.
If `data` does not have `numobs` defined, then this function
falls back to `length(data)`.
If `data` does not have `numobs` defined,
then in the case of `Tables.table(data) == true`
returns the number of rows, otherwise returns `length(data)`.
Authors of custom data containers should implement
`Base.length` for their type instead of `numobs`.
`numobs` should only be implemented for types where there is a
difference between `numobs` and `Base.length`
(such as multi-dimensional arrays).
See also [`getobs`](@ref)
`getobs` supports by default nested combinations of array, tuple,
named tuples, and dictionaries.
See also [`getobs`](@ref).
# Examples
```jldoctest
# named tuples
x = (a = [1, 2, 3], b = rand(6, 3))
numobs(x) == 3
# dictionaries
x = Dict(:a => [1, 2, 3], :b => rand(6, 3))
numobs(x) == 3
```
All internal containers must have the same number of observations:
```juliarepl
julia> x = (a = [1, 2, 3, 4], b = rand(6, 3));
julia> numobs(x)
ERROR: DimensionMismatch: All data containers must have the same number of observations.
Stacktrace:
[1] _check_numobs_error()
@ MLUtils ~/.julia/dev/MLUtils/src/observation.jl:163
[2] _check_numobs
@ ~/.julia/dev/MLUtils/src/observation.jl:130 [inlined]
[3] numobs(data::NamedTuple{(:a, :b), Tuple{Vector{Int64}, Matrix{Float64}}})
@ MLUtils ~/.julia/dev/MLUtils/src/observation.jl:177
[4] top-level scope
@ REPL[35]:1
```
"""
function numobs end

# Generic Fallbacks
numobs(data) = length(data)
@traitfn numobs(data::X) where {X; IsTable{X}} = DataAPI.nrow(data)
@traitfn numobs(data::X) where {X; !IsTable{X}} = length(data)


"""
getobs(data, [idx])
Return the observations corresponding to the observation-index `idx`.
Return the observations corresponding to the observation index `idx`.
Note that `idx` can be any type as long as `data` has defined
`getobs` for that type.
`getobs` for that type. If `idx` is not provided, then materialize
all observations in `data`.
If `data` does not have `getobs` defined,
then in the case of `Tables.table(data) == true`
returns the row(s) in position `idx`, otherwise returns `data[idx]`.
If `data` does not have `getobs` defined, then this function
falls back to `data[idx]`.
Authors of custom data containers should implement
`Base.getindex` for their type instead of `getobs`.
`getobs` should only be implemented for types where there is a
Expand All @@ -40,13 +78,37 @@ Every author behind some custom data container can make this
decision themselves.
The output should be consistent when `idx` is a scalar vs vector.
See also [`getobs!`](@ref) and [`numobs`](@ref)
`getobs` supports by default nested combinations of array, tuple,
named tuples, and dictionaries.
See also [`getobs!`](@ref) and [`numobs`](@ref).
# Examples
```jldoctest
# named tuples
x = (a = [1, 2, 3], b = rand(6, 3))
getobs(x, 2) == (a = 2, b = x.b[:, 2])
getobs(x, [1, 3]) == (a = [1, 3], b = x.b[:, [1, 3]])
# dictionaries
x = Dict(:a => [1, 2, 3], :b => rand(6, 3))
getobs(x, 2) == Dict(:a => 2, :b => x[:b][:, 2])
getobs(x, [1, 3]) == Dict(:a => [1, 3], :b => x[:b][:, [1, 3]])
```
"""
function getobs end

# Generic Fallbacks

getobs(data) = data
getobs(data, idx) = data[idx]

@traitfn getobs(data::X, idx) where {X; IsTable{X}} = Tables.subset(data, idx, viewhint=false)
@traitfn getobs(data::X, idx) where {X; !IsTable{X}} = data[idx]


"""
getobs!(buffer, data, idx)
Expand All @@ -61,6 +123,8 @@ method is provided for the type of `data`, then `buffer` will be
because the type of `data` may not lend itself to the concept
of `copy!`. Thus, supporting a custom `getobs!` is optional
and not required.
See also [`getobs`](@ref) and [`numobs`](@ref).
"""
function getobs! end
# getobs!(buffer, data) = getobs(data)
Expand Down Expand Up @@ -161,3 +225,5 @@ function getobs!(buffers, data::Dict, i)

return buffers
end


8 changes: 8 additions & 0 deletions test/observation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -190,4 +190,12 @@ end
@test getobs!((nothing,xbuf),(Xs,X), 3:4) == (getobs(Xs,3:4),xbuf)
@test xbuf == getobs(X,3:4)
end

@testset "tables" begin
df = DataFrame(a=[1,2,3], y=["a","b","c"])
@test getobs(df) == df
@test getobs(df, 1) == df[1,:]
@test getobs(df, 2:3) == df[2:3,:]
@test numobs(df) == 3
end
end
2 changes: 2 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using MLUtils
using MLUtils.Datasets
using MLUtils: RingBuffer, eachobsparallel
using MLUtils: flatten, stack, unstack # also exported by DataFrames.jl
using SparseArrays
using Random, Statistics
using Test
Expand All @@ -9,6 +10,7 @@ using FoldsThreads: TaskPoolEx
using ChainRulesTestUtils: test_rrule
using Zygote: ZygoteRuleConfig
using ChainRulesCore: rrule_via_ad
using DataFrames

showcompact(io, x) = show(IOContext(io, :compact => true), x)

Expand Down

0 comments on commit a85c098

Please sign in to comment.