Skip to content

Commit

Permalink
pass first and last pairs as tuples
Browse files Browse the repository at this point in the history
  • Loading branch information
tiemvanderdeure committed Nov 22, 2024
1 parent b412ce4 commit 0581607
Showing 1 changed file with 25 additions and 27 deletions.
52 changes: 25 additions & 27 deletions src/recode.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,8 @@ The default method is to test if any element in the `collection` `isequal` to
`x`. For `Set`s `in` is used as it is faster than the default method and equivalent to it.
A user defined type could override this method to define an appropriate test function.
"""
recode_in(x, ::Missing) = false
recode_in(::Missing, ::Missing) = true
recode_in(x, collection::Set) = x in collection
recode_in(x, collection) = x collection || any(x y for y in collection)
recode_in(x::T, y::T) where T = x === y
@inline recode_in(x, collection) = any(x y for y in collection)
@inline recode_in(x, ::Missing) = false

optimize_pair(pair::Pair) = pair
optimize_pair(pair::Pair{<:AbstractArray}) = Set(pair.first) => pair.second
Expand All @@ -59,21 +56,25 @@ function recode!(dest::AbstractArray, src::AbstractArray, default::Any, pairs::P
throw(DimensionMismatch("dest and src must be of the same length (got $(length(dest)) and $(length(src)))"))
end

opt_pairs = map(optimize_pair, pairs)
opt_pairs = optimize_pair.(pairs)

_recode!(dest, src, default, opt_pairs...)
if dest isa CategoricalArray && src isa CategoricalArray
# in this case, we don't need to do much for type stability
_recode!(dest, src, default, opt_pairs...)
else
# in these cases, this is only type stable if we pass the pairs as tuples
_recode!(dest, src, default, first.(opt_pairs), last.(opt_pairs))
end
end

function _recode!(dest::AbstractArray{T}, src::AbstractArray, default::Any, pairs::Pair...) where {T}
function _recode!(dest::AbstractArray{T}, src::AbstractArray, default, recode_from::Tuple, recode_to::Tuple) where {T}
@inbounds for i in eachindex(dest, src)
x = src[i]

for p in pairs
# we use isequal and recode_in because we cannot really distinguish scalars from collections
if recode_in(x, p.first)
dest[i] = p.second
@goto nextitem
end
j = findfirst(y -> isequal(x, y) || recode_in(x,y), recode_from)
if !isnothing(j)
dest[i] = recode_to[j]
@goto nextitem
end

# Value not in any of the pairs
Expand Down Expand Up @@ -101,29 +102,26 @@ function _recode!(dest::AbstractArray{T}, src::AbstractArray, default::Any, pair
dest
end

function _recode!(dest::CategoricalArray{T}, src::AbstractArray, default::Any, pairs::Pair...) where {T}
vals = T[p.second for p in pairs]
default !== nothing && push!(vals, default)
function _recode!(dest::CategoricalArray{T, <:Any, R}, src::AbstractArray, default::Any, recode_from::Tuple, recode_to::Tuple) where {T, R}
vals = convert.(T, recode_to)
vals = default === nothing ? vals : (vals..., default)

levels!(dest.pool, filter!(!ismissing, unique(vals)))
# In the absence of duplicated recoded values, we do not need to lookup the reference
# for each pair in the loop, which is more efficient (with loop unswitching)
dupvals = length(vals) != length(levels(dest.pool))

drefs = dest.refs
pairmap = [ismissing(v) ? 0 : get(dest.pool, v) for v in vals]
defaultref = default === nothing || ismissing(default) ? 0 : get(dest.pool, default)
pairmap = [ismissing(v) ? zero(R) : get(dest.pool, v) for v in vals]
defaultref = default === nothing ? nothing : ismissing(default) ? 0 : get(dest.pool, default)

@inbounds for i in eachindex(drefs, src)
x = src[i]

for j in eachindex(pairs)
p = pairs[j]
# we use isequal and recode_in because we cannot really distinguish scalars from collections
if recode_in(x, p.first)
drefs[i] = dupvals ? pairmap[j] : j
@goto nextitem
end
j = findfirst(y -> isequal(x, y) || recode_in(x,y), recode_from)
if !isnothing(j)
drefs[i] = dupvals ? pairmap[j] : j
@goto nextitem
end

# Value not in any of the pairs
Expand Down Expand Up @@ -228,7 +226,7 @@ function _recode!(dest::CategoricalArray{T, N, R}, src::CategoricalArray,
@inbounds for (i, l) in enumerate(srclevels)
for j in 1:length(pairs)
p = pairs[j]
if recode_in(l, p.first)
if l p.first ||recode_in(l, p.first)
levelsmap[i+1] = pairmap[j]
@goto nextitem
end
Expand Down

0 comments on commit 0581607

Please sign in to comment.