From 0581607316b5260ca05310417e6f46903cadd00d Mon Sep 17 00:00:00 2001 From: tiemvanderdeure Date: Fri, 22 Nov 2024 14:16:42 +0100 Subject: [PATCH] pass first and last pairs as tuples --- src/recode.jl | 52 +++++++++++++++++++++++++-------------------------- 1 file changed, 25 insertions(+), 27 deletions(-) diff --git a/src/recode.jl b/src/recode.jl index c72b8788..1b6aeae9 100644 --- a/src/recode.jl +++ b/src/recode.jl @@ -45,11 +45,8 @@ The default method is to test if any element in the `collection` `isequal` to `x`. For `Set`s `in` is used as it is faster than the default method and equivalent to it. A user defined type could override this method to define an appropriate test function. """ -recode_in(x, ::Missing) = false -recode_in(::Missing, ::Missing) = true -recode_in(x, collection::Set) = x in collection -recode_in(x, collection) = x ≅ collection || any(x ≅ y for y in collection) -recode_in(x::T, y::T) where T = x === y +@inline recode_in(x, collection) = any(x ≅ y for y in collection) +@inline recode_in(x, ::Missing) = false optimize_pair(pair::Pair) = pair optimize_pair(pair::Pair{<:AbstractArray}) = Set(pair.first) => pair.second @@ -59,21 +56,25 @@ function recode!(dest::AbstractArray, src::AbstractArray, default::Any, pairs::P throw(DimensionMismatch("dest and src must be of the same length (got $(length(dest)) and $(length(src)))")) end - opt_pairs = map(optimize_pair, pairs) + opt_pairs = optimize_pair.(pairs) - _recode!(dest, src, default, opt_pairs...) + if dest isa CategoricalArray && src isa CategoricalArray + # in this case, we don't need to do much for type stability + _recode!(dest, src, default, opt_pairs...) + else + # in these cases, this is only type stable if we pass the pairs as tuples + _recode!(dest, src, default, first.(opt_pairs), last.(opt_pairs)) + end end -function _recode!(dest::AbstractArray{T}, src::AbstractArray, default::Any, pairs::Pair...) where {T} +function _recode!(dest::AbstractArray{T}, src::AbstractArray, default, recode_from::Tuple, recode_to::Tuple) where {T} @inbounds for i in eachindex(dest, src) x = src[i] - for p in pairs - # we use isequal and recode_in because we cannot really distinguish scalars from collections - if recode_in(x, p.first) - dest[i] = p.second - @goto nextitem - end + j = findfirst(y -> isequal(x, y) || recode_in(x,y), recode_from) + if !isnothing(j) + dest[i] = recode_to[j] + @goto nextitem end # Value not in any of the pairs @@ -101,9 +102,9 @@ function _recode!(dest::AbstractArray{T}, src::AbstractArray, default::Any, pair dest end -function _recode!(dest::CategoricalArray{T}, src::AbstractArray, default::Any, pairs::Pair...) where {T} - vals = T[p.second for p in pairs] - default !== nothing && push!(vals, default) +function _recode!(dest::CategoricalArray{T, <:Any, R}, src::AbstractArray, default::Any, recode_from::Tuple, recode_to::Tuple) where {T, R} + vals = convert.(T, recode_to) + vals = default === nothing ? vals : (vals..., default) levels!(dest.pool, filter!(!ismissing, unique(vals))) # In the absence of duplicated recoded values, we do not need to lookup the reference @@ -111,19 +112,16 @@ function _recode!(dest::CategoricalArray{T}, src::AbstractArray, default::Any, p dupvals = length(vals) != length(levels(dest.pool)) drefs = dest.refs - pairmap = [ismissing(v) ? 0 : get(dest.pool, v) for v in vals] - defaultref = default === nothing || ismissing(default) ? 0 : get(dest.pool, default) + pairmap = [ismissing(v) ? zero(R) : get(dest.pool, v) for v in vals] + defaultref = default === nothing ? nothing : ismissing(default) ? 0 : get(dest.pool, default) @inbounds for i in eachindex(drefs, src) x = src[i] - for j in eachindex(pairs) - p = pairs[j] - # we use isequal and recode_in because we cannot really distinguish scalars from collections - if recode_in(x, p.first) - drefs[i] = dupvals ? pairmap[j] : j - @goto nextitem - end + j = findfirst(y -> isequal(x, y) || recode_in(x,y), recode_from) + if !isnothing(j) + drefs[i] = dupvals ? pairmap[j] : j + @goto nextitem end # Value not in any of the pairs @@ -228,7 +226,7 @@ function _recode!(dest::CategoricalArray{T, N, R}, src::CategoricalArray, @inbounds for (i, l) in enumerate(srclevels) for j in 1:length(pairs) p = pairs[j] - if recode_in(l, p.first) + if l ≅ p.first ||recode_in(l, p.first) levelsmap[i+1] = pairmap[j] @goto nextitem end