Skip to content

Commit

Permalink
Add field array support to expand comm pattern
Browse files Browse the repository at this point in the history
  • Loading branch information
lcw committed Jul 23, 2024
1 parent 296646d commit 3dfa96a
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 9 deletions.
38 changes: 33 additions & 5 deletions src/communication.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,18 +45,46 @@ function expand(r::UnitRange, factor)
return a:b
end

expand(v::AbstractVector, factor) = vec((0x1:factor) .+ (v' .- 0x1) .* factor)
@kernel function expand_vector_kernel!(dest, src, factor, offset)
i = @index(Global)

@inbounds begin
b, a = fldmod1(src[i], offset)

for f = 1:factor
dest[f+(i-0x1)*factor] = a + (f - 0x1) * offset + (b - 0x1) * factor * offset
end
end
end

function expand(v::AbstractVector, factor, offset = 1)
w = similar(v, length(v) * factor)

expand_vector_kernel!(get_backend(w), 256)(w, v, factor, offset, ndrange = length(v))

return w
end

"""
expand(patten::CommPattern, factor, offset)

Check warning on line 69 in src/communication.jl

View workflow job for this annotation

GitHub Actions / Spell Check with Typos

"patten" should be "pattern" or "patent".
Create a new `CommPattern` where the `recvindices` and `sendindices` are
expanded in size by `factor` entries and offset by `offset`.
function expand(pattern::CommPattern{AT}, factor) where {AT}
recvindices = expand(pattern.recvindices, factor)
For example, to expand a `nodecommpattern` to communicate all fields of an
array that is indexed via `(node, field, element)` use
`expand(nodecommpattern(grid), nfields, nnodes)`.
"""
function expand(pattern::CommPattern{AT}, factor, offset = 1) where {AT}
recvindices = expand(pattern.recvindices, factor, offset)
recvranks = copy(pattern.recvranks)
recvrankindices = similar(pattern.recvrankindices)
@assert eltype(recvrankindices) <: UnitRange
for i in eachindex(recvrankindices)
recvrankindices[i] = expand(pattern.recvrankindices[i], factor)
end

sendindices = expand(pattern.sendindices, factor)
sendindices = expand(pattern.sendindices, factor, offset)
sendranks = copy(pattern.sendranks)
sendrankindices = similar(pattern.sendrankindices)
@assert eltype(sendrankindices) <: UnitRange
Expand Down Expand Up @@ -109,7 +137,7 @@ end

function _get_mpi_buffers(buffer, rankindices)
# Hack to make the element type of the buffer arrays concrete
@assert eltype(rankindices) == typeof(1:length(rankindices))
@assert eltype(rankindices) == UnitRange{eltype(eltype(rankindices))}
T = typeof(view(buffer, 1:length(rankindices)))

bufs = Array{MPI.Buffer{T}}(undef, length(rankindices))
Expand Down
4 changes: 2 additions & 2 deletions src/grids.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ floattype(grid::AbstractGrid) = floattype(typeof(grid))
arraytype(grid::AbstractGrid) = arraytype(typeof(grid))
celltype(grid::AbstractGrid) = celltype(typeof(grid))

struct Grid{C<:AbstractCell,P,V,S,L,T,F,B,PN,N,CTOD,DTOC,CC,FM} <: AbstractGrid{C}
struct Grid{C<:AbstractCell,P,V,S,L,T,F,B,PN,N,CTOD,DTOC,CC,NCC,FM} <: AbstractGrid{C}
comm::MPI.Comm
part::Int
nparts::Int
Expand All @@ -27,7 +27,7 @@ struct Grid{C<:AbstractCell,P,V,S,L,T,F,B,PN,N,CTOD,DTOC,CC,FM} <: AbstractGrid{
continuoustodiscontinuous::CTOD
discontinuoustocontinuous::DTOC
communicatingcells::CC
noncommunicatingcells::CC
noncommunicatingcells::NCC
facemaps::FM
end

Expand Down
13 changes: 11 additions & 2 deletions test/communication.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
sendranks = Cint[1, 2]
sendrankindices = [1:2, 3:5]

pattern = Raven.CommPattern{Array}(
originalpattern = Raven.CommPattern{Array}(
recvindices,
recvranks,
recvrankindices,
Expand All @@ -23,7 +23,16 @@
sendrankindices,
)

pattern = Raven.expand(pattern, 3)
pattern = Raven.expand(originalpattern, 3, 5)

@test pattern.recvranks == recvranks
@test pattern.recvindices == [17, 22, 27, 18, 23, 28, 19, 24, 29, 20, 25, 30]
@test pattern.recvrankindices == UnitRange{Int64}[1:3, 4:12]
@test pattern.sendranks == sendranks
@test pattern.sendindices == [1, 6, 11, 3, 8, 13, 1, 6, 11, 3, 8, 13, 5, 10, 15]
@test pattern.sendrankindices == UnitRange{Int64}[1:6, 7:15]

pattern = Raven.expand(originalpattern, 3)

@test pattern.recvranks == recvranks
@test pattern.recvindices == [19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30]
Expand Down

0 comments on commit 3dfa96a

Please sign in to comment.