Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

make recode! type stable #407

Merged
merged 17 commits into from
Jan 3, 2025
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ uuid = "324d7699-5711-5eae-9e2f-1d82baa6b597"
version = "0.10.8"

[deps]
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
DataAPI = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a"
Future = "9fa8497b-333b-5362-9e8d-4d0656e87820"
Missings = "e1d29d7a-bbdc-5cf2-9ac0-f12de2c33e28"
Expand All @@ -24,6 +25,7 @@ CategoricalArraysSentinelArraysExt = "SentinelArrays"
CategoricalArraysStructTypesExt = "StructTypes"

[compat]
Compat = "3.37, 4"
DataAPI = "1.6"
JSON = "0.15, 0.16, 0.17, 0.18, 0.19, 0.20, 0.21"
JSON3 = "1.1.2"
Expand Down
1 change: 1 addition & 0 deletions src/CategoricalArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ module CategoricalArrays
using DataAPI
using Missings
using Printf
import Compat

# JuliaLang/julia#36810
if VERSION < v"1.5.2"
Expand Down
84 changes: 38 additions & 46 deletions src/recode.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import Compat
tiemvanderdeure marked this conversation as resolved.
Show resolved Hide resolved
const ≅ = isequal

"""
Expand Down Expand Up @@ -52,27 +53,33 @@ A user defined type could override this method to define an appropriate test fun
# Normalize a single `from => to` recoding pair.
# Generic fallback: the pair is handed back untouched.
function optimize_pair(pair::Pair)
    return pair
end

# When the `from` side is an array, wrap its elements in a `Set`
# (presumably so later membership checks are cheaper -- confirm against
# `recode_in`); the `to` side is carried over unchanged.
function optimize_pair(pair::Pair{<:AbstractArray})
    from, to = pair
    return Pair(Set(from), to)
end

function recode!(dest::AbstractArray{T}, src::AbstractArray, default::Any, pairs::Pair...) where {T}
function recode!(dest::AbstractArray, src::AbstractArray, default::Any, pairs::Pair...)
if length(dest) != length(src)
throw(DimensionMismatch("dest and src must be of the same length (got $(length(dest)) and $(length(src)))"))
end

opt_pairs = map(optimize_pair, pairs)
opt_pairs = optimize_pair.(pairs)

_recode!(dest, src, default, opt_pairs)
end

function _recode!(dest::AbstractArray{T}, src::AbstractArray, default, pairs::NTuple{<:Any, Pair}) where {T}
tiemvanderdeure marked this conversation as resolved.
Show resolved Hide resolved
recode_to = last.(pairs)
recode_from = first.(pairs)

@inbounds for i in eachindex(dest, src)
x = src[i]

for j in 1:length(opt_pairs)
p = opt_pairs[j]
# we use isequal and recode_in because we cannot really distinguish scalars from collections
if x ≅ p.first || recode_in(x, p.first)
dest[i] = p.second
@goto nextitem
end
end

# the call-site @inline is needed for type stability; it is taken from Compat
# (Compat.@inline) so that it also works on Julia versions before v1.8
# we use isequal and recode_in because we cannot really
# distinguish scalars from collections
j = Compat.@inline findfirst(y -> isequal(x, y) || recode_in(x,y), recode_from)
tiemvanderdeure marked this conversation as resolved.
Show resolved Hide resolved

# Value in one of the pairs
if j !== nothing
dest[i] = recode_to[j]
# Value not in any of the pairs
if ismissing(x)
elseif ismissing(x)
eltype(dest) >: Missing ||
throw(MissingException("missing value found, but dest does not support them: " *
"recode them to a supported value"))
Expand All @@ -89,21 +96,16 @@ function recode!(dest::AbstractArray{T}, src::AbstractArray, default::Any, pairs
else
dest[i] = default
end

@label nextitem
end

dest
end

function recode!(dest::CategoricalArray{T}, src::AbstractArray, default::Any, pairs::Pair...) where {T}
if length(dest) != length(src)
throw(DimensionMismatch("dest and src must be of the same length (got $(length(dest)) and $(length(src)))"))
end

opt_pairs = map(optimize_pair, pairs)
function _recode!(dest::CategoricalArray{T, <:Any, R}, src::AbstractArray, default::Any,
pairs::NTuple{<:Any, Pair}) where {T, R}
tiemvanderdeure marked this conversation as resolved.
Show resolved Hide resolved
recode_from = first.(pairs)
vals = T[p.second for p in pairs]
tiemvanderdeure marked this conversation as resolved.
Show resolved Hide resolved

vals = T[p.second for p in opt_pairs]
default !== nothing && push!(vals, default)

levels!(dest.pool, filter!(!ismissing, unique(vals)))
Expand All @@ -112,22 +114,18 @@ function recode!(dest::CategoricalArray{T}, src::AbstractArray, default::Any, pa
dupvals = length(vals) != length(levels(dest.pool))

drefs = dest.refs
pairmap = [ismissing(v) ? 0 : get(dest.pool, v) for v in vals]
defaultref = default === nothing || ismissing(default) ? 0 : get(dest.pool, default)
pairmap = [ismissing(v) ? zero(R) : get(dest.pool, v) for v in vals]
defaultref = default === nothing || ismissing(default) ? zero(R) : get(dest.pool, default)

@inbounds for i in eachindex(drefs, src)
x = src[i]

for j in 1:length(opt_pairs)
p = opt_pairs[j]
# we use isequal and recode_in because we cannot really distinguish scalars from collections
if x ≅ p.first || recode_in(x, p.first)
drefs[i] = dupvals ? pairmap[j] : j
@goto nextitem
end
end

# Value not in any of the pairs
if ismissing(x)
# we use isequal and recode_in because we cannot really
tiemvanderdeure marked this conversation as resolved.
Show resolved Hide resolved
# distinguish scalars from collections
j = Compat.@inline findfirst(y -> isequal(x, y) || recode_in(x, y), recode_from)
if j !== nothing
drefs[i] = dupvals ? pairmap[j] : j
elseif ismissing(x)
eltype(dest) >: Missing ||
throw(MissingException("missing value found, but dest does not support them: " *
"recode them to a supported value"))
Expand All @@ -144,8 +142,6 @@ function recode!(dest::CategoricalArray{T}, src::AbstractArray, default::Any, pa
else
drefs[i] = defaultref
end

@label nextitem
end

# Put existing levels first, and sort them if possible
Expand All @@ -168,25 +164,21 @@ function recode!(dest::CategoricalArray{T}, src::AbstractArray, default::Any, pa
dest
end

function recode!(dest::CategoricalArray{T, N, R}, src::CategoricalArray,
default::Any, pairs::Pair...) where {T, N, R<:Integer}
if length(dest) != length(src)
throw(DimensionMismatch("dest and src must be of the same length " *
"(got $(length(dest)) and $(length(src)))"))
end

function _recode!(dest::CategoricalArray{T, N, R}, src::CategoricalArray,
default::Any, pairs::NTuple{<:Any, Pair}) where {T, N, R<:Integer}
tiemvanderdeure marked this conversation as resolved.
Show resolved Hide resolved
recode_from = first.(pairs)
vals = T[p.second for p in pairs]

if default === nothing
srclevels = levels(src)
tiemvanderdeure marked this conversation as resolved.
Show resolved Hide resolved

# Remove recoded levels as they won't appear in result
firsts = (p.first for p in pairs)
keptlevels = Vector{T}(undef, 0)
sizehint!(keptlevels, length(srclevels))

for l in srclevels
if !(any(x -> x ≅ l, firsts) ||
any(f -> recode_in(l, f), firsts))
if !(any(x -> x ≅ l, recode_from) ||
any(f -> recode_in(l, f), recode_from))
try
push!(keptlevels, l)
catch err
Expand Down
Loading