Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix corner cases of cut with duplicated breaks #410

Merged
merged 7 commits into from
Dec 30, 2024
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 39 additions & 24 deletions src/extras.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,14 @@ function fill_refs!(refs::AbstractArray, X::AbstractArray,
@inbounds for i in eachindex(X)
x = X[i]

if ismissing(x)
if x isa Number && isnan(x)
throw(ArgumentError("NaN values are not allowed in input vector"))
elseif ismissing(x)
refs[i] = 0
elseif x == upper
elseif isequal(x, upper)
refs[i] = n-1
elseif extend !== true && !(lower <= x <= upper)
elseif extend !== true &&
!((isless(lower, x) || isequal(x, lower)) && isless(x, upper))
extend === missing ||
throw(ArgumentError("value $x (at index $i) does not fall inside the breaks: " *
"adapt them manually, or pass extend=true or extend=missing"))
Expand Down Expand Up @@ -55,10 +58,10 @@ also accept them.
the intervals; or a function `f(from, to, i; leftclosed, rightclosed)` that generates
the labels from the left and right interval boundaries and the group index. Defaults to
`"[from, to)"` (or `"[from, to]"` for the rightmost interval if `extend == true`).
* `allowempty::Bool=false`: when `false`, an error is raised if some breaks appear
multiple times, generating empty intervals; when `true`, duplicate breaks are allowed
and the intervals they generate are kept as unused levels
(but duplicate labels are not allowed).
* `allowempty::Bool=false`: when `false`, an error is raised if some breaks other than
the last one appear multiple times, generating empty intervals; when `true`,
duplicate breaks are allowed and the intervals they generate are kept as
unused levels (but duplicate labels are not allowed).

# Examples
```jldoctest
Expand Down Expand Up @@ -132,14 +135,19 @@ function _cut(x::AbstractArray{T, N}, breaks::AbstractVector,
extend::Union{Bool, Missing},
labels::Union{AbstractVector{<:SupportedTypes},Function},
allowempty::Bool=false) where {T, N}
if !allowempty && !allunique(breaks)
throw(ArgumentError("all breaks must be unique unless `allowempty=true`"))
end

if !issorted(breaks)
breaks = sort(breaks)
end

if any(x -> x isa Number && isnan(x), breaks)
throw(ArgumentError("NaN values are not allowed in breaks"))
end

if !allowempty && !allunique(@view breaks[1:end-1])
throw(ArgumentError("all breaks other than the last one must be unique " *
"unless `allowempty=true`"))
end

if extend === true
xnm = T >: Missing ? skipmissing(x) : x
length(breaks) >= 1 || throw(ArgumentError("at least one break must be provided"))
Expand All @@ -158,11 +166,11 @@ function _cut(x::AbstractArray{T, N}, breaks::AbstractVector,
rethrow(err)
end
end
if !ismissing(min_x) && breaks[1] > min_x
if !ismissing(min_x) && isless(min_x, breaks[1])
# this type annotation is needed on Julia<1.7 for stable inference
breaks = [min_x::nonmissingtype(eltype(x)); breaks]
end
if !ismissing(max_x) && breaks[end] < max_x
if !ismissing(max_x) && isless(breaks[end], max_x)
breaks = [breaks; max_x::nonmissingtype(eltype(x))]
end
length(breaks) > 1 ||
Expand All @@ -189,16 +197,15 @@ function _cut(x::AbstractArray{T, N}, breaks::AbstractVector,
from = breaks[1:n-1]
to = breaks[2:n]
firstlevel = labels(from[1], to[1], 1,
leftclosed=breaks[1] != breaks[2], rightclosed=false)
leftclosed=!isequal(breaks[1], breaks[2]), rightclosed=false)
levs = Vector{typeof(firstlevel)}(undef, n-1)
levs[1] = firstlevel
for i in 2:n-2
levs[i] = labels(from[i], to[i], i,
leftclosed=breaks[i] != breaks[i+1], rightclosed=false)
leftclosed=!isequal(breaks[i], breaks[i+1]), rightclosed=false)
end
levs[end] = labels(from[end], to[end], n-1,
leftclosed=breaks[end-1] != breaks[end],
rightclosed=true)
leftclosed=true, rightclosed=true)
else
length(labels) == n-1 ||
throw(ArgumentError("labels must be of length $(n-1), but got length $(length(labels))"))
Expand Down Expand Up @@ -243,21 +250,29 @@ quantiles.
the labels from the left and right interval boundaries and the group index. Defaults to
`"Qi: [from, to)"` (or `"Qi: [from, to]"` for the rightmost interval).
* `allowempty::Bool=false`: when `false`, an error is raised if some quantiles breakpoints
are equal, generating empty intervals; when `true`, duplicate breaks are allowed
and the intervals they generate are kept as unused levels
(but duplicate labels are not allowed).
other than the last one are equal, generating empty intervals;
when `true`, duplicate breaks are allowed and the intervals they generate are kept as
unused levels (but duplicate labels are not allowed).
"""
function cut(x::AbstractArray, ngroups::Integer;
labels::Union{AbstractVector{<:SupportedTypes},Function}=quantile_formatter,
allowempty::Bool=false)
ngroups >= 1 || throw(ArgumentError("ngroups must be strictly positive (got $ngroups)"))
xnm = eltype(x) >: Missing ? skipmissing(x) : x
breaks = Statistics.quantile(xnm, (1:ngroups-1)/ngroups)
if !allowempty && !allunique(breaks)
# Computing extrema is faster than taking 0 and 1 quantiles
min_x, max_x = extrema(xnm)
if (min_x isa Number && isnan(min_x)) ||
(max_x isa Number && isnan(max_x))
throw(ArgumentError("NaN values are not allowed in input vector"))
end
breaks = quantile(xnm, (1:ngroups-1)/ngroups)
breaks = [min_x; breaks; max_x]
if !allowempty && !allunique(@view breaks[1:end-1])
n = length(unique(breaks)) - 1
throw(ArgumentError("cannot compute $ngroups quantiles: `quantile` " *
"returned only $n groups due to duplicated values in `x`." *
"returned only $n group(s) due to duplicated values in `x`. " *
"Pass `allowempty=true` to allow empty quantiles or " *
"choose a lower value for `ngroups`."))
end
cut(x, breaks; extend=true, labels=labels, allowempty=allowempty)
cut(x, breaks; labels=labels, allowempty=allowempty)
end
91 changes: 88 additions & 3 deletions test/15_extras.jl
Original file line number Diff line number Diff line change
Expand Up @@ -111,9 +111,6 @@ const ≅ = isequal
@test isa(x, CategoricalVector{Union{Int, String, T}})
@test isordered(x)
@test levels(x) == [0, "2", 4, "6", 8]

@test_throws ArgumentError cut([-0.0, 0.0], 2)
@test_throws ArgumentError cut([-0.0, 0.0], 2, labels=[-0.0, 0.0])
end

@testset "cut with missing values in input" begin
Expand Down Expand Up @@ -144,6 +141,11 @@ end
@test levels(x) == ["Q1: [2.0, 3.5)", "Q2: [3.5, 5.0]"]
end

@testset "cut(x, n) with invalid n" begin
@test_throws ArgumentError cut(1:10, 0)
@test_throws ArgumentError cut(1:10, -1)
end

@testset "cut with formatter function" begin
my_formatter(from, to, i; leftclosed, rightclosed) = "$i: $from -- $to"

Expand Down Expand Up @@ -185,11 +187,20 @@ end
x = [zeros(10); ones(10)]
@test_throws ArgumentError cut(x, [0, 0.1, 0.1, 10])
@test_throws ArgumentError cut(x, 10)
y = cut(x, [0, 0.1, 10, 10])
@test y == [fill("[0.0, 0.1)", 10); fill("[0.1, 10.0)", 10)]
@test levels(y) == ["[0.0, 0.1)", "[0.1, 10.0)", "[10.0, 10.0]"]

@test_throws ArgumentError cut(1:10, [1, 5, 5, 11])
y = cut(1:10, [1, 5, 5, 11], allowempty=true)
@test y == cut(1:10, [1, 5, 11])
@test levels(y) == ["[1, 5)", "(5, 5)", "[5, 11]"]
y = cut(1:10, [1, 5, 11, 11])
@test y == [fill("[1, 5)", 4); fill("[5, 11)", 6)]
@test levels(y) == ["[1, 5)", "[5, 11)", "[11, 11]"]
y = cut(1:10, [1, 5, 10, 10])
@test y == [fill("[1, 5)", 4); fill("[5, 10)", 5); "[10, 10]"]
@test levels(y) == ["[1, 5)", "[5, 10)", "[10, 10]"]

@test_throws ArgumentError cut(1:10, [1, 5, 5, 5, 11])
@test_throws ArgumentError cut(1:10, [1, 5, 5, 11],
Expand Down Expand Up @@ -242,6 +253,49 @@ end

fmt = (from, to, i; leftclosed, rightclosed) -> (i % 2 == 0 ? to : 0.0)
@test_throws ArgumentError cut(1:8, 0:2:10, labels=fmt)

@test_throws ArgumentError cut([fill(1, 10); 4], 2)
@test_throws ArgumentError cut([fill(1, 10); 4], 3)
x = cut([fill(1, 10); 4], 2, allowempty=true)
@test unique(x) == ["Q2: [1.0, 4.0]"]
x = cut([fill(1, 10); 4], 3, allowempty=true)
@test unique(x) == ["Q3: [1.0, 4.0]"]
@test levels(x) == ["Q1: (1.0, 1.0)", "Q2: (1.0, 1.0)", "Q3: [1.0, 4.0]"]

x = cut([fill(1, 5); fill(4, 5)], 2)
@test x == [fill("Q1: [1.0, 2.5)", 5); fill("Q2: [2.5, 4.0]", 5)]
@test levels(x) == ["Q1: [1.0, 2.5)", "Q2: [2.5, 4.0]"]
@test_throws ArgumentError cut([fill(1, 5); fill(4, 5)], 3)
x = cut([fill(1, 5); fill(4, 5)], 3, allowempty=true)
@test x == [fill("Q2: [1.0, 4.0)", 5); fill("Q3: [4.0, 4.0]", 5)]
@test levels(x) == ["Q1: (1.0, 1.0)", "Q2: [1.0, 4.0)", "Q3: [4.0, 4.0]"]
end

@testset "cut with -0.0" begin
x = cut([-0.0, 0.0, 0.0, -0.0], 2)
@test x == ["Q1: [-0.0, 0.0)", "Q2: [0.0, 0.0]", "Q2: [0.0, 0.0]", "Q1: [-0.0, 0.0)"]
@test levels(x) == ["Q1: [-0.0, 0.0)", "Q2: [0.0, 0.0]"]

x = cut([-0.0, 0.0, 0.0, -0.0], [-0.0, 0.0, 0.0])
@test x == ["[-0.0, 0.0)", "[0.0, 0.0]", "[0.0, 0.0]", "[-0.0, 0.0)"]
@test levels(x) == ["[-0.0, 0.0)", "[0.0, 0.0]"]

x = cut([-0.0, 0.0, 0.0, -0.0], [-0.0, 0.0])
@test x == fill("[-0.0, 0.0]", 4)
@test levels(x) == ["[-0.0, 0.0]"]

x = cut([-0.0, 0.0, 0.0, -0.0], [0.0], extend=true)
@test x == fill("[-0.0, 0.0]", 4)
@test levels(x) == ["[-0.0, 0.0]"]

x = cut([-0.0, 0.0, 0.0, -0.0], [-0.0], extend=true)
@test x == fill("[-0.0, 0.0]", 4)
@test levels(x) == ["[-0.0, 0.0]"]

x = cut([-0.0, 0.0, 0.0, -0.0], 2, labels=[-0.0, 0.0])
@test x == [-0.0, 0.0, 0.0, -0.0]

@test_throws ArgumentError cut([-0.0, 0.0, 0.0, -0.0], [-0.0, -0.0, 0.0])
end

@testset "cut with extend=true" begin
Expand Down Expand Up @@ -276,4 +330,35 @@ end
@test x == ["[-1.0, 0.0)", "[-1.0, 0.0)", "[0.0, 1.0]", "[0.0, 1.0]", "[0.0, 1.0]"]
end

@testset "cut with NaN and Inf" begin
@test_throws ArgumentError("NaN values are not allowed in input vector") cut([1, NaN, 2, 3], [1, 10])
@test_throws ArgumentError("NaN values are not allowed in input vector") cut([1, NaN, 2, 3], [1], extend=true)
@test_throws ArgumentError("NaN values are not allowed in input vector") cut([1, NaN, 2, 3], 2)
@test_throws ArgumentError("NaN values are not allowed in breaks") cut([1, 2], [1, NaN])

x = cut([1, Inf], [1], extend=true)
@test x ≅ ["[1.0, Inf]", "[1.0, Inf]"]
@test levels(x) == ["[1.0, Inf]"]

x = cut([1, -Inf], [1], extend=true)
@test x ≅ ["[-Inf, 1.0]", "[-Inf, 1.0]"]
@test levels(x) == ["[-Inf, 1.0]"]

x = cut([1:5; Inf], [1, 2, Inf])
@test x ≅ ["[1.0, 2.0)"; fill("[2.0, Inf]", 5)]
@test levels(x) == ["[1.0, 2.0)", "[2.0, Inf]"]

x = cut([1:5; -Inf], [-Inf, 2, 5])
@test x ≅ ["[-Inf, 2.0)"; fill("[2.0, 5.0]", 4); "[-Inf, 2.0)"]
@test levels(x) == ["[-Inf, 2.0)", "[2.0, 5.0]"]

x = cut([1:5; Inf], 2)
@test x ≅ [fill("Q1: [1.0, 3.5)", 3); fill("Q2: [3.5, Inf]", 3)]
@test levels(x) == ["Q1: [1.0, 3.5)", "Q2: [3.5, Inf]"]

x = cut([1:5; -Inf], 2)
@test x ≅ [fill("Q1: [-Inf, 2.5)", 2); fill("Q2: [2.5, 5.0]", 3); "Q1: [-Inf, 2.5)"]
@test levels(x) == ["Q1: [-Inf, 2.5)", "Q2: [2.5, 5.0]"]
end

end
2 changes: 2 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ module TestCategoricalArrays
using Test
using CategoricalArrays

const ≊ = isequal

tests = [
"01_value.jl",
"04_constructors.jl",
Expand Down
Loading