Skip to content

Commit

Permalink
Improve examples for splitobs (#165)
Browse files Browse the repository at this point in the history
* better examples for splitobs

* tweak

* Update src/splitobs.jl
  • Loading branch information
mcabbott authored Aug 26, 2023
1 parent 3a7d3e3 commit af7ebea
Showing 1 changed file with 17 additions and 11 deletions.
28 changes: 17 additions & 11 deletions src/splitobs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,26 +40,32 @@ end
"""
splitobs(data; at, shuffle=false) -> Tuple
Split the `data` into multiple subsets proportional to the
value(s) of `at`.
Partition the `data` into two or more subsets.
When `at` is a number (between 0 and 1) this specifies the proportion in the first subset.
When `at` is a tuple, each entry specifies the proportion an a subset,
with the last having `1-sum(at)`. In all there are `length(at)+1` subsets returned.
If `shuffle=true`, randomly permute the observations before splitting.
Supports any datatype implementing the [`numobs`](@ref) and
[`getobs`](@ref) interfaces.
[`getobs`](@ref) interfaces -- including arrays, tuples & NamedTuples of arrays.
# Examples
```julia
# A 70%-30% split
train, test = splitobs(X, at=0.7)
```jldoctest
julia> splitobs(permutedims(1:100); at=0.7) # simple 70%-30% split, of a matrix
([1 2 … 69 70], [71 72 … 99 100])
julia> data = (x=ones(2,10), n=1:10) # a NamedTuple, consistent last dimension
(x = [1.0 1.0 … 1.0 1.0; 1.0 1.0 … 1.0 1.0], n = 1:10)
julia> splitobs(data, at=(0.5, 0.3)) # a 50%-30%-20% split, e.g. train/test/validation
((x = [1.0 1.0 … 1.0 1.0; 1.0 1.0 … 1.0 1.0], n = 1:5), (x = [1.0 1.0 1.0; 1.0 1.0 1.0], n = 6:8), (x = [1.0 1.0; 1.0 1.0], n = 9:10))
# A 50%-30%-20% split
train, val, test = splitobs(X, at=(0.5, 0.3))
julia> train, test = splitobs((permutedims(1.0:100.0), 101:200), at=0.7, shuffle=true); # split a Tuple
# A 70%-30% split with multiple arrays and shuffling
train, test = splitobs((X, y), at=0.7, shuffle=true)
Xtrain, Ytrain = train
julia> vec(test[1]) .+ 100 == test[2]
true
```
"""
function splitobs(data; at, shuffle::Bool=false)
Expand Down

0 comments on commit af7ebea

Please sign in to comment.