diff --git a/Project.toml b/Project.toml index 45f22a69..79254b35 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "MLJBase" uuid = "a7f614a8-145f-11e9-1d2a-a57a1082229d" authors = ["Anthony D. Blaom "] -version = "0.13.4" +version = "0.13.5" [deps] CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597" diff --git a/src/MLJBase.jl b/src/MLJBase.jl index 253f23a4..fb75b5e4 100644 --- a/src/MLJBase.jl +++ b/src/MLJBase.jl @@ -96,8 +96,7 @@ export coerce, coerce!, autotype, schema, info # ------------------------------------------------------------------- # exports from MLJBase -export DeterministicNetwork, ProbabilisticNetwork, UnsupervisedNetwork, - best, @load +export DeterministicNetwork, ProbabilisticNetwork, UnsupervisedNetwork, @load # computational_resources.jl: export default_resource @@ -109,8 +108,7 @@ export ParamRange, NumericRange, NominalRange, iterator, scale export params # note this is *not* an extension of StatsBase.params # data.jl: -export reconstruct, levels_seen, container_type, - partition, unpack, complement, restrict, corestrict +export partition, unpack, complement, restrict, corestrict # utilities.jl: export @set_defaults, flat_values, recursive_setproperty!, @@ -122,10 +120,6 @@ export HANDLE_GIVEN_ID, @more, @constant, color_on, color_off # distributions.jl: export average -# tasks.jl: -export SupervisedTask, UnsupervisedTask, MLJTask, - X_and_y, X_, y_, nfeatures - # info_dict.jl: export info_dict diff --git a/src/data/data.jl b/src/data/data.jl index c9149c19..caae59b4 100644 --- a/src/data/data.jl +++ b/src/data/data.jl @@ -28,10 +28,20 @@ function _partition(rows, fractions, ::Nothing) return tuple((rows[h:t] for (h, t) in zip(heads, tails))...) end +_make_numerical(v::AbstractVector) = + throw(ArgumentError("`stratify` must have `Count`, `Continuous` "* + "or `Finite` element scitpye. Consider "* + "`coerce(stratify, Finite)`. ")) +_make_numerical(v::AbstractVector{<:Union{Missing,Real}}) = v +_make_numerical(v::AbstractVector{<:Union{Missing,CategoricalElement}}) = + int.(v) + # Helper function for partitioning in the stratified case -function _partition(rows, fractions, stratify::AbstractVector) +function _partition(rows, fractions, raw_stratify::AbstractVector) + stratify = _make_numerical(raw_stratify) length(stratify) == length(rows) || - throw(ArgumentError("The stratification vector must have as many entries as " * + throw(ArgumentError("The stratification vector must "* + "have as many entries as " * "the rows to partition.")) uv = unique(stratify) # construct table (n_classes * idx_of_that_class) @@ -72,7 +82,7 @@ end """ partition(rows::AbstractVector{Int}, fractions...; - shuffle=nothing, rng=Random.GLOBAL_RNG) + shuffle=nothing, rng=Random.GLOBAL_RNG, stratify=nothing) Splits the vector `rows` into a tuple of vectors whose lengths are given by the corresponding `fractions` of `length(rows)` where valid diff --git a/src/resampling.jl b/src/resampling.jl index 5b46a9a0..919b4d27 100644 --- a/src/resampling.jl +++ b/src/resampling.jl @@ -630,7 +630,7 @@ function _evaluate!(func, machines, ::CPUThreads, nfolds,verbosity) @sync begin - @sync for parts in Iterators.partition(1:nfolds, max(1,floor(Int, nfolds/n_threads))) + @sync for parts in Iterators.partition(1:nfolds, max(1,cld(nfolds, n_threads))) Threads.@spawn begin for k in parts id = Threads.threadid() diff --git a/test/data/data.jl b/test/data/data.jl index 07ac3793..dde6357d 100644 --- a/test/data/data.jl +++ b/test/data/data.jl @@ -44,7 +44,11 @@ import MLJBase: decoder, int, classes, partition, unpack, selectcols, matrix, y = ones(Int, 1000) y[end-100:end] .= 0; # 90% + train1, test1 = + partition(eachindex(y), 0.8, stratify=categorical(y), rng=34) train, test = partition(eachindex(y), 0.8, stratify=y, rng=34) + @test train == train1 + @test test == test1 @test isapprox(sum(y[train])/length(train), 0.9, rtol=1e-2) @test isapprox(sum(y[test])/length(test), 0.9, rtol=1e-2)