Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use multithreading in row_group_slots refarray method #2661

Merged
merged 8 commits into from
Mar 22, 2021
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/src/lib/internals.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,5 @@ gennames
getmaxwidths
ourshow
ourstrwidth
tforeach
```
71 changes: 39 additions & 32 deletions src/dataframerow/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -338,46 +338,53 @@ function row_group_slots(cols::NTuple{N, AbstractVector},
end
refmap
end
@inbounds for i in eachindex(groups)
local refs_i
let i=i # Workaround for julia#15276
refs_i = map(c -> c[i], refarrays)
end
vals = map((m, r, s, fi) -> m[r-fi+1] * s, refmaps, refs_i, strides, firstinds)
j = sum(vals) + 1
# x < 0 happens with -1 in refmap, which corresponds to missing
if skipmissing && any(x -> x < 0, vals)
j = 0
else
seen[j] = true
tforeach(eachindex(groups), basesize=1_000_000) do i
@inbounds begin
local refs_i
let i=i # Workaround for julia#15276
refs_i = map(c -> c[i], refarrays)
end
vals = map((m, r, s, fi) -> m[r-fi+1] * s, refmaps, refs_i, strides, firstinds)
j = sum(vals) + 1
# x < 0 happens with -1 in refmap, which corresponds to missing
if skipmissing && any(x -> x < 0, vals)
j = 0
else
seen[j] = true
end
groups[i] = j
end
groups[i] = j
end
else
@inbounds for i in eachindex(groups)
local refs_i
let i=i # Workaround for julia#15276
refs_i = map(refarrays, missinginds) do ref, missingind
r = Int(ref[i])
if skipmissing
return r == missingind ? -1 : (r > missingind ? r-1 : r)
else
return r
tforeach(eachindex(groups), basesize=1_000_000) do i
@inbounds begin
local refs_i
let i=i # Workaround for julia#15276
refs_i = map(refarrays, missinginds) do ref, missingind
r = Int(ref[i])
if skipmissing
return r == missingind ? -1 : (r > missingind ? r-1 : r)
else
return r
end
end
end
vals = map((r, s, fi) -> (r-fi) * s, refs_i, strides, firstinds)
j = sum(vals) + 1
# x < 0 happens with -1, which corresponds to missing
if skipmissing && any(x -> x < 0, vals)
j = 0
else
seen[j] = true
end
groups[i] = j
end
vals = map((r, s, fi) -> (r-fi) * s, refs_i, strides, firstinds)
j = sum(vals) + 1
# x < 0 happens with -1, which corresponds to missing
if skipmissing && any(x -> x < 0, vals)
j = 0
else
seen[j] = true
end
groups[i] = j
end
end
if !all(seen) # Compress group indices to remove unused ones
# If some groups are unused, compress group indices to drop them
# sum(seen) is faster than all(seen) when not short-circuiting,
# and short-circuit would only happen in the slower case anyway
if sum(seen) < length(seen)
oldngroups = ngroups
remap = zeros(Int, ngroups)
ngroups = 0
Expand Down
45 changes: 45 additions & 0 deletions src/other/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -83,3 +83,48 @@ else
end

funname(c::ComposedFunction) = Symbol(funname(c.outer), :_, funname(c.inner))

# Compute chunks of indices, each with at least `basesize` entries
# This method ensures balanced sizes by avoiding a small last chunk
function split_indices(len::Integer, basesize::Integer)
len′ = Int64(len) # Avoid overflow on 32-bit machines
np = max(1, div(len, basesize))
return (Int(1 + ((i - 1) * len′) ÷ np):Int((i * len′) ÷ np) for i in 1:np)
end

"""
tforeach(f, x::AbstractArray; basesize::Integer)

Apply function `f` to each entry in `x` in parallel, spawning
one separate task for each block of at least `basesize` entries.

A number of task higher than `Threads.nthreads()` may be spawned,
since that can allow for a more efficient load balancing in case
some threads are busy (nested parallelism).
"""
function tforeach(f, x::AbstractArray; basesize::Integer)
@assert firstindex(x) == 1

@static if VERSION >= v"1.4"
nt = Threads.nthreads()
len = length(x)
if nt > 1 && len > basesize
@sync for p in split_indices(len, basesize)
Threads.@spawn begin
for i in p
f(@inbounds x[i])
end
end
end
else
for i in eachindex(x)
f(@inbounds x[i])
end
end
else
for i in eachindex(x)
f(@inbounds x[i])
end
end
return
end
18 changes: 18 additions & 0 deletions test/grouping.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3831,4 +3831,22 @@ end
((x, y, z) -> x[1] <= 5 ? unwrap(y[1]) : unwrap(z[1])) => :res)
end

@testset "groupby multithreading" begin
for x in (PooledArray(rand(1:10, 1_100_000)),
PooledArray(rand([1:9; missing], 1_100_000))),
y in (PooledArray(rand(["a", "b", "c", "d"], 1_100_000)),
PooledArray(rand(["a"; "b"; "c"; missing], 1_100_000)))
df = DataFrame(x=x, y=y)

# Checks are done by groupby_checked
@test length(groupby_checked(df, :x)) == 10
@test length(groupby_checked(df, :x, skipmissing=true)) ==
length(unique(skipmissing(x)))

@test length(groupby_checked(df, [:x, :y])) == 40
@test length(groupby_checked(df, [:x, :y], skipmissing=true)) ==
length(unique(skipmissing(x))) * length(unique(skipmissing(y)))
end
end

end # module
24 changes: 24 additions & 0 deletions test/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -101,4 +101,28 @@ end
@test fetch(t) === true
end

@testset "split_indices" begin
for len in 0:12
basesize = 10
x = DataFrames.split_indices(len, basesize)

@test length(x) == max(1, div(len, basesize))
@test reduce(vcat, x) === 1:len
vmin, vmax = extrema(length, x)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To fix 1.0:

Suggested change
vmin, vmax = extrema(length, x)
vmin, vmax = extrema(length(v) for v in x)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ah - this pesky 1.0 LTS :)

@test vmin + 1 == vmax || vmin == vmax
@test len < basesize || vmin >= basesize
end

# Check overflow on 32-bit
len = typemax(Int32)
basesize = 100_000_000
x = collect(DataFrames.split_indices(len, basesize))
@test length(x) == div(len, basesize)
@test x[1][1] === 1
@test x[end][end] === Int(len)
vmin, vmax = extrema(length, x)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
vmin, vmax = extrema(length, x)
vmin, vmax = extrema(length(v) for v in x)

@test vmin + 1 == vmax || vmin == vmax
@test len < basesize || vmin >= basesize
end

end # module