numobs and getobs support for Tables.jl's tables #124

Merged 5 commits on Oct 23, 2022
11 changes: 9 additions & 2 deletions Project.toml
@@ -1,34 +1,41 @@
name = "MLUtils"
uuid = "f1d291b0-491e-4a28-83b9-f70985020b54"
authors = ["Carlo Lucibello <[email protected]> and contributors"]
version = "0.2.11"
version = "0.2.12"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
DataAPI = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a"
DelimitedFiles = "8bb1440f-4735-579b-a4ab-409b98df4dab"
FLoops = "cc61a311-1640-44b5-9fba-1b764f453329"
FoldsThreads = "9c68100b-dfe1-47cf-94c8-95104e173443"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
ShowCases = "605ecd9f-84a6-4c9e-81e2-4798472b76a3"
SimpleTraits = "699a6c99-e7fa-54fc-8d76-47d257e15c1d"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
Transducers = "28d57a85-8fef-5791-bfe6-a80928e7c999"

[compat]
ChainRulesCore = "1.0"
DataAPI = "1.0"
DelimitedFiles = "1.0"
FLoops = "0.2"
FoldsThreads = "0.1"
SimpleTraits = "0.9"
ShowCases = "0.1"
StatsBase = "0.33"
Tables = "1.10"
Transducers = "0.4"
julia = "1.6"

[extras]
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["ChainRulesTestUtils", "SparseArrays", "Test", "Zygote"]
test = ["ChainRulesTestUtils", "DataFrames", "SparseArrays", "Test", "Zygote"]
7 changes: 7 additions & 0 deletions src/MLUtils.jl
@@ -8,12 +8,19 @@ using FLoops.Transducers: Executor, ThreadedEx
using FoldsThreads: TaskPoolEx
import StatsBase: sample
using Transducers
using Tables
using DataAPI
using Base: @propagate_inbounds
using Random: AbstractRNG, shuffle!, GLOBAL_RNG, rand!, randn!
import ChainRulesCore: rrule
using ChainRulesCore: @non_differentiable, unthunk, AbstractZero,
NoTangent, ZeroTangent, ProjectTo

using SimpleTraits

@traitdef IsTable{X}
@traitimpl IsTable{X} <- Tables.istable(X)


include("observation.jl")
export numobs,
86 changes: 76 additions & 10 deletions src/observation.jl
@@ -3,30 +3,68 @@

Return the total number of observations contained in `data`.

If `data` does not have `numobs` defined, then this function
falls back to `length(data)`.
If `data` does not have `numobs` defined,
then in the case of `Tables.istable(data) == true`
returns the number of rows, otherwise returns `length(data)`.

Authors of custom data containers should implement
`Base.length` for their type instead of `numobs`.
`numobs` should only be implemented for types where there is a
difference between `numobs` and `Base.length`
(such as multi-dimensional arrays).

See also [`getobs`](@ref)
`numobs` supports by default nested combinations of arrays, tuples,
named tuples, and dictionaries.

See also [`getobs`](@ref).

# Examples
```jldoctest

# named tuples
x = (a = [1, 2, 3], b = rand(6, 3))
numobs(x) == 3

# dictionaries
x = Dict(:a => [1, 2, 3], :b => rand(6, 3))
numobs(x) == 3
```
All internal containers must have the same number of observations:
```julia-repl
julia> x = (a = [1, 2, 3, 4], b = rand(6, 3));

julia> numobs(x)
ERROR: DimensionMismatch: All data containers must have the same number of observations.
Stacktrace:
[1] _check_numobs_error()
@ MLUtils ~/.julia/dev/MLUtils/src/observation.jl:163
[2] _check_numobs
@ ~/.julia/dev/MLUtils/src/observation.jl:130 [inlined]
[3] numobs(data::NamedTuple{(:a, :b), Tuple{Vector{Int64}, Matrix{Float64}}})
@ MLUtils ~/.julia/dev/MLUtils/src/observation.jl:177
[4] top-level scope
@ REPL[35]:1
```
"""
function numobs end

# Generic Fallbacks
numobs(data) = length(data)
@traitfn numobs(data::X) where {X; IsTable{X}} = DataAPI.nrow(data)
@traitfn numobs(data::X) where {X; !IsTable{X}} = length(data)


"""
getobs(data, [idx])

Return the observations corresponding to the observation-index `idx`.
Return the observations corresponding to the observation index `idx`.
Note that `idx` can be any type as long as `data` has defined
`getobs` for that type.
`getobs` for that type. If `idx` is not provided, then materialize
all observations in `data`.

If `data` does not have `getobs` defined,
then in the case of `Tables.istable(data) == true`
returns the row(s) in position `idx`, otherwise returns `data[idx]`.

If `data` does not have `getobs` defined, then this function
falls back to `data[idx]`.
Authors of custom data containers should implement
`Base.getindex` for their type instead of `getobs`.
`getobs` should only be implemented for types where there is a
@@ -40,13 +78,37 @@ Every author behind some custom data container can make this
decision themselves.
The output should be consistent when `idx` is a scalar vs vector.

See also [`getobs!`](@ref) and [`numobs`](@ref)
`getobs` supports by default nested combinations of arrays, tuples,
named tuples, and dictionaries.

See also [`getobs!`](@ref) and [`numobs`](@ref).

# Examples

```jldoctest
# named tuples
x = (a = [1, 2, 3], b = rand(6, 3))

getobs(x, 2) == (a = 2, b = x.b[:, 2])
getobs(x, [1, 3]) == (a = [1, 3], b = x.b[:, [1, 3]])


# dictionaries
x = Dict(:a => [1, 2, 3], :b => rand(6, 3))

getobs(x, 2) == Dict(:a => 2, :b => x[:b][:, 2])
getobs(x, [1, 3]) == Dict(:a => [1, 3], :b => x[:b][:, [1, 3]])
```
"""
function getobs end

# Generic Fallbacks

getobs(data) = data
getobs(data, idx) = data[idx]

@traitfn getobs(data::X, idx) where {X; IsTable{X}} = Tables.subset(data, idx, viewhint=false)
@traitfn getobs(data::X, idx) where {X; !IsTable{X}} = data[idx]


"""
getobs!(buffer, data, idx)
@@ -61,6 +123,8 @@ method is provided for the type of `data`, then `buffer` will be
because the type of `data` may not lend itself to the concept
of `copy!`. Thus, supporting a custom `getobs!` is optional
and not required.

See also [`getobs`](@ref) and [`numobs`](@ref).
"""
function getobs! end
# getobs!(buffer, data) = getobs(data)
@@ -161,3 +225,5 @@ function getobs!(buffers, data::Dict, i)

return buffers
end


8 changes: 8 additions & 0 deletions test/observation.jl
@@ -190,4 +190,12 @@ end
@test getobs!((nothing,xbuf),(Xs,X), 3:4) == (getobs(Xs,3:4),xbuf)
@test xbuf == getobs(X,3:4)
end

@testset "tables" begin
df = DataFrame(a=[1,2,3], y=["a","b","c"])
@test getobs(df) == df
@test getobs(df, 1) == df[1,:]
@test getobs(df, 2:3) == df[2:3,:]
@test numobs(df) == 3
end
end
2 changes: 2 additions & 0 deletions test/runtests.jl
@@ -1,6 +1,7 @@
using MLUtils
using MLUtils.Datasets
using MLUtils: RingBuffer, eachobsparallel
using MLUtils: flatten, stack, unstack # also exported by DataFrames.jl
using SparseArrays
using Random, Statistics
using Test
@@ -9,6 +10,7 @@ using FoldsThreads: TaskPoolEx
using ChainRulesTestUtils: test_rrule
using Zygote: ZygoteRuleConfig
using ChainRulesCore: rrule_via_ad
using DataFrames

showcompact(io, x) = show(IOContext(io, :compact => true), x)
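
The behaviour covered by the new `tables` testset can be sketched from the user's side as follows. This is a minimal illustration, not part of the PR itself: it assumes MLUtils 0.2.12 or later (the version this change bumps to) and DataFrames are both installed, and it reuses the same small `DataFrame` as the test.

```julia
# Assumption: MLUtils >= 0.2.12 and DataFrames are available in the environment.
using MLUtils
using DataFrames

df = DataFrame(a = [1, 2, 3], y = ["a", "b", "c"])

# A DataFrame satisfies Tables.istable, so numobs counts rows
# instead of falling back to length(data).
@assert numobs(df) == 3

# Without an index, getobs returns the data as-is.
@assert getobs(df) == df

# An integer index selects a single row; a range selects several rows.
@assert getobs(df, 1) == df[1, :]
@assert getobs(df, 2:3) == df[2:3, :]
```

Other Tables.jl-compatible sources should dispatch the same way, as long as `Tables.istable` returns `true` for their type and no more specific `numobs`/`getobs` method applies (named tuples, for instance, already have their own methods in MLUtils).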
