improve performance of recode! for array dest #355

Open
wants to merge 1 commit into
base: master
51 changes: 38 additions & 13 deletions src/recode.jl
@@ -33,6 +33,8 @@ recode!(dest::AbstractArray, src::AbstractArray, pairs::Pair...) =
# To fix ambiguity
recode!(dest::CategoricalArray, src::AbstractArray, pairs::Pair...) =
recode!(dest, src, nothing, pairs...)
recode!(dest::AbstractArray, src::CategoricalArray, pairs::Pair...) =
recode!(dest, src, nothing, pairs...)
recode!(dest::CategoricalArray, src::CategoricalArray, pairs::Pair...) =
recode!(dest, src, nothing, pairs...)

@@ -52,22 +54,47 @@ A user defined type could override this method to define an appropriate test fun
optimize_pair(pair::Pair) = pair
optimize_pair(pair::Pair{<:AbstractArray}) = Set(pair.first) => pair.second

function missing_check(value)
ismissing(value) && throw(MissingException("missing value found, but dest does not support them: " *
"recode them to a supported value"))
value
end

function recode!(dest::AbstractArray{T}, src::CategoricalArray, default::Any, pairs::Pair...) where {T}
if length(dest) != length(src)
throw(DimensionMismatch("dest and src must be of the same length (got $(length(dest)) and $(length(src)))"))
end

pairs = map(pairs) do p
p.first => convert(T, p.second)
end
recoded = recode(src, default, pairs...)
if T >: Missing
dest .= unwrap.(recoded)
else
dest .= missing_check.(unwrap.(recoded))
end
Comment on lines +68 to +76
Member

Rather than doing this, to avoid making a copy and two passes over the data, we should call recode on levels(src), and then do something like:

@inbounds for i in eachindex(dest, src)
    dest[i] = newlevels[src.refs[i]+1]
end

The actual implementation needs to be a bit more complex so that the first entry in newlevels is missing (to handle the case when src.refs is 0).
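
As a rough illustration of the suggestion above (not the implementation from this PR), here is a minimal sketch that uses plain vectors as stand-ins for the categorical internals. The names `lvls`, `refs`, `mapping`, and `recoded_levels` are hypothetical: `lvls` plays the role of `levels(src)` and `refs` the role of `src.refs`, with a ref of 0 marking a missing entry.

```julia
# Hypothetical stand-ins for levels(src) and src.refs; a ref of 0 means missing.
lvls = ["X", "Y", "Z"]
refs = UInt32[1, 3, 0, 2, 1]

# Recode each level once: the cost depends on the number of levels,
# not on the length of the data.
mapping = Dict("X" => 1, "Y" => 2, "Z" => 3)
recoded_levels = [mapping[l] for l in lvls]

# Build the lookup table with `missing` in the first slot, so that
# ref 0 maps to index 1 and ref k maps to index k + 1.
newlevels = Vector{Union{Int, Missing}}(missing, length(lvls) + 1)
newlevels[2:end] .= recoded_levels

# Single pass over the refs, with no intermediate recoded array.
dest = Vector{Union{Int, Missing}}(undef, length(refs))
@inbounds for i in eachindex(dest, refs)
    dest[i] = newlevels[refs[i] + 1]
end
```

If `dest` did not allow `missing`, the lookup for a ref of 0 would need an explicit check instead of storing `missing` in the first slot.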

Author

Rather than doing this, to avoid making a copy and two passes over the data

By "copy", do you mean a copy of src.levels? In this implementation no copy of the actual array (or its refs) is made, which is the main reason it is so much faster (as outlined in my StackOverflow answer): all the actual copying of the refs happens only once, at the last line, dest .= unwrap.(recoded), because the recoded variable shares its refs with src.

Member

recode(src, default, pairs...) allocates a new vector, right? That's relatively fast, but it's even better to avoid it.

Author

@ahnlabb ahnlabb Jul 24, 2021

You're right, I remembered the details wrong. In the StackOverflow example I did:

mapping = Dict("X"=>1, "Y"=>2, "Z"=>3)
b = CategoricalArray{Int64,1,UInt32}(undef, 0)
b.refs = a.refs
levels!(b.pool, [mapping[l] for l in levels(a.pool)])

which is similar to what you're suggesting. However, in this PR we initialize the CategoricalArray that will be put in the recoded variable with something like CategoricalArray{S, N, R}(undef, size(a)) so the refs are not shared between src and recoded.

EDIT: As I noted on SO, using levels! does not work in the general case.

dest
end

function recode!(dest::AbstractArray{T}, src::AbstractArray, default::Any, pairs::Pair...) where {T}
if length(dest) != length(src)
throw(DimensionMismatch("dest and src must be of the same length (got $(length(dest)) and $(length(src)))"))
end

opt_pairs = map(optimize_pair, pairs)

@inbounds for i in eachindex(dest, src)
x = src[i]
map!(dest, src) do x

for j in 1:length(opt_pairs)
p = opt_pairs[j]
# we use isequal and recode_in because we cannot really distinguish scalars from collections
if x ≅ p.first || recode_in(x, p.first)
dest[i] = p.second
@goto nextitem
# we use isequal and recode_in because we cannot really distinguish scalars from collections
for p in opt_pairs
if x ≅ p.first
return p.second
end
end
for p in opt_pairs
if recode_in(x, p.first)
return p.second
end
Comment on lines +90 to 98
Member

This could change the behavior in case of overlap between pairs. Why did you change this?

Author

Since this was a while ago, I'll need to take some time with it and rerun my (micro-)benchmarks to be sure, but unless my memory fails me, recode_in was a performance bottleneck, and splitting the checks (in addition to switching to map!) made a noticeable difference for highly optimizable cases. You're absolutely right that it is a breaking change; it should have been highlighted in the PR since it warrants discussion. I spent some time trying to get recode_in to optimize away but was not satisfied with the result. The most troublesome part is, of course, the any(x ≅ y for y in collection) call for the case when collection is a primitive. I'll get back to you with data.

Member

OK. Maybe better do this in a separate PR since it's a bit more tricky.
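
To make the overlap concern in this thread concrete, here is a small self-contained sketch comparing the two matching orders. It does not use the package itself: `recode_in_sketch`, `first_match_interleaved`, `first_match_split`, and `overlapping` are hypothetical names, and `recode_in_sketch` is only a stand-in built from the `any(x ≅ y for y in collection)` expression quoted above, with `isequal` in place of `≅`.

```julia
# Hypothetical stand-in for recode_in, based on the expression quoted in this thread.
recode_in_sketch(x, collection) = any(isequal(x, y) for y in collection)

# Previous order: for each pair in turn, try equality, then membership.
function first_match_interleaved(x, pairs)
    for p in pairs
        if isequal(x, p.first) || recode_in_sketch(x, p.first)
            return p.second
        end
    end
    return nothing
end

# Order in this diff: equality against every pair first, then membership.
function first_match_split(x, pairs)
    for p in pairs
        isequal(x, p.first) && return p.second
    end
    for p in pairs
        recode_in_sketch(x, p.first) && return p.second
    end
    return nothing
end

# The value 1 is a member of the first key and equal to the second key.
overlapping = ([1, 2] => "a", 1 => "b")

old_result = first_match_interleaved(1, overlapping)  # membership in [1, 2] wins
new_result = first_match_split(1, overlapping)        # equality with 1 wins
```

With these overlapping pairs the interleaved order returns "a" and the split order returns "b", which is the kind of behavior change being discussed.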

end

@@ -76,21 +103,19 @@ function recode!(dest::AbstractArray{T}, src::AbstractArray, default::Any, pairs
eltype(dest) >: Missing ||
throw(MissingException("missing value found, but dest does not support them: " *
"recode them to a supported value"))
dest[i] = missing
return missing
elseif default isa Nothing
try
dest[i] = x
return convert(T, x)
catch err
isa(err, MethodError) || rethrow(err)
throw(ArgumentError("cannot `convert` value $(repr(x)) (of type $(typeof(x))) to type of recoded levels ($T). " *
"This will happen with recode() when not all original levels are recoded " *
"(i.e. some are preserved) and their type is incompatible with that of recoded levels."))
end
else
dest[i] = default
return default
end

@label nextitem
nalimilan marked this conversation as resolved.
end

dest