Skip to content

Commit

Permalink
Improve handling of -0.0
Browse files Browse the repository at this point in the history
  • Loading branch information
nalimilan committed Dec 27, 2024
1 parent 5664fef commit 4c691ab
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 32 deletions.
42 changes: 16 additions & 26 deletions src/extras.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,10 @@ function fill_refs!(refs::AbstractArray, X::AbstractArray,

if ismissing(x)
refs[i] = 0
elseif x == upper
elseif isequal(x, upper)
refs[i] = n-1
elseif extend !== true && !(lower <= x <= upper)
elseif extend !== true &&
!((isless(lower, x) || isequal(x, lower)) && isless(x, upper))
extend === missing ||
throw(ArgumentError("value $x (at index $i) does not fall inside the breaks: " *
"adapt them manually, or pass extend=true or extend=missing"))
Expand Down Expand Up @@ -136,14 +137,9 @@ function _cut(x::AbstractArray{T, N}, breaks::AbstractVector,
breaks = sort(breaks)
end

if !allowempty
num_eq = 0
for i in 2:length(breaks)
num_eq += breaks[i] == breaks[i-1]
end
num_eq > 0 &&
throw(ArgumentError("all breaks other than the last one must be unique " *
"unless `allowempty=true`"))
if !allowempty && !allunique(@view breaks[1:end-1])
throw(ArgumentError("all breaks other than the last one must be unique " *
"unless `allowempty=true`"))
end

if extend === true
Expand All @@ -164,11 +160,11 @@ function _cut(x::AbstractArray{T, N}, breaks::AbstractVector,
rethrow(err)
end
end
if !ismissing(min_x) && breaks[1] > min_x
if !ismissing(min_x) && isless(min_x, breaks[1])
# this type annotation is needed on Julia<1.7 for stable inference
breaks = [min_x::nonmissingtype(eltype(x)); breaks]
end
if !ismissing(max_x) && breaks[end] < max_x
if !ismissing(max_x) && isless(breaks[end], max_x)
breaks = [breaks; max_x::nonmissingtype(eltype(x))]
end
length(breaks) > 1 ||
Expand All @@ -195,12 +191,12 @@ function _cut(x::AbstractArray{T, N}, breaks::AbstractVector,
from = breaks[1:n-1]
to = breaks[2:n]
firstlevel = labels(from[1], to[1], 1,
leftclosed=breaks[1] != breaks[2], rightclosed=false)
leftclosed=!isequal(breaks[1], breaks[2]), rightclosed=false)
levs = Vector{typeof(firstlevel)}(undef, n-1)
levs[1] = firstlevel
for i in 2:n-2
levs[i] = labels(from[i], to[i], i,
leftclosed=breaks[i] != breaks[i+1], rightclosed=false)
leftclosed=!isequal(breaks[i], breaks[i+1]), rightclosed=false)
end
levs[end] = labels(from[end], to[end], n-1,
leftclosed=true, rightclosed=true)
Expand Down Expand Up @@ -261,18 +257,12 @@ function cut(x::AbstractArray, ngroups::Integer;
# Computing extrema is faster than taking 0 and 1 quantiles
min_x, max_x = extrema(xnm)
breaks = [min_x; breaks; max_x]
if !allowempty # Only two last breaks are allowed to be equal
num_eq = 0
for i in 2:length(breaks)
num_eq += breaks[i] == breaks[i-1]
end
if num_eq > 0
n = length(breaks) - num_eq
throw(ArgumentError("cannot compute $ngroups quantiles: `quantile` " *
"returned only $n group(s) due to duplicated values in `x`." *
"Pass `allowempty=true` to allow empty quantiles or " *
"choose a lower value for `ngroups`."))
end
if !allowempty && !allunique(@view breaks[1:end-1])
n = length(unique(breaks))
throw(ArgumentError("cannot compute $ngroups quantiles: `quantile` " *
"returned only $n group(s) due to duplicated values in `x`." *
"Pass `allowempty=true` to allow empty quantiles or " *
"choose a lower value for `ngroups`."))
end
cut(x, breaks; labels=labels, allowempty=allowempty)
end
32 changes: 26 additions & 6 deletions test/15_extras.jl
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ end
@test cut(0.0:8.0, 3, labels=[-0.0, 0.0, 1.0]) ==
[-0.0, -0.0, -0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0]

@test cut([-0.0, 0.0, 1.0, 2.0, 3.0, 4.0], [-0.0, 1.0, 5.0], labels=[-0.0, 0.0]) ==
@test cut([-0.0, 0.0, 1.0, 2.0, 3.0, 4.0], [-0.0, 0.0, 5.0], labels=[-0.0, 0.0]) ==
[-0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
end

Expand Down Expand Up @@ -269,13 +269,33 @@ end
x = cut([fill(1, 5); fill(4, 5)], 3, allowempty=true)
@test x == [fill("Q2: [1.0, 4.0)", 5); fill("Q3: [4.0, 4.0]", 5)]
@test levels(x) == ["Q1: (1.0, 1.0)", "Q2: [1.0, 4.0)", "Q3: [4.0, 4.0]"]
end

@testset "cut with -0.0" begin
x = cut([-0.0, 0.0, 0.0, -0.0], 2)
@test x == ["Q1: [-0.0, 0.0)", "Q2: [0.0, 0.0]", "Q2: [0.0, 0.0]", "Q1: [-0.0, 0.0)"]
@test levels(x) == ["Q1: [-0.0, 0.0)", "Q2: [0.0, 0.0]"]

x = cut([-0.0, 0.0, 0.0, -0.0], [-0.0, 0.0, 0.0])
@test x == ["[-0.0, 0.0)", "[0.0, 0.0]", "[0.0, 0.0]", "[-0.0, 0.0)"]
@test levels(x) == ["[-0.0, 0.0)", "[0.0, 0.0]"]

x = cut([-0.0, 0.0, 0.0, -0.0], [-0.0, 0.0])
@test x == fill("[-0.0, 0.0]", 4)
@test levels(x) == ["[-0.0, 0.0]"]

x = cut([-0.0, 0.0, 0.0, -0.0], [0.0], extend=true)
@test x == fill("[-0.0, 0.0]", 4)
@test levels(x) == ["[-0.0, 0.0]"]

x = cut([-0.0, 0.0, 0.0, -0.0], [-0.0], extend=true)
@test x == fill("[-0.0, 0.0]", 4)
@test levels(x) == ["[-0.0, 0.0]"]

@test_throws ArgumentError cut([-0.0, 0.0], 2)
@test_throws ArgumentError cut([-0.0, 0.0], 2, labels=[-0.0, 0.0])
@test_throws ArgumentError cut([-0.0, 0.0], [0.0], extend=true)
@test_throws ArgumentError cut([-0.0, 0.0], [-0.0, 0.0])
x = cut([-0.0, 0.0, 0.0, -0.0], 2, labels=[-0.0, 0.0])
@test x == [-0.0, 0.0, 0.0, -0.0]

@test_throws ArgumentError cut([-0.0, 0.0, 1.0, 2.0, 3.0, 4.0], [-0.0, 0.0, 5.0])
@test_throws ArgumentError cut([-0.0, 0.0, 0.0, -0.0], [-0.0, -0.0, 0.0])
end

@testset "cut with extend=true" begin
Expand Down

0 comments on commit 4c691ab

Please sign in to comment.