diff --git a/src/communication.jl b/src/communication.jl index a865d59..7d62d12 100644 --- a/src/communication.jl +++ b/src/communication.jl @@ -45,10 +45,38 @@ function expand(r::UnitRange, factor) return a:b end -expand(v::AbstractVector, factor) = vec((0x1:factor) .+ (v' .- 0x1) .* factor) +@kernel function expand_vector_kernel!(dest, src, factor, offset) + i = @index(Global) + + @inbounds begin + b, a = fldmod1(src[i], offset) + + for f = 1:factor + dest[f+(i-0x1)*factor] = a + (f - 0x1) * offset + (b - 0x1) * factor * offset + end + end +end + +function expand(v::AbstractVector, factor, offset = 1) + w = similar(v, length(v) * factor) + + expand_vector_kernel!(get_backend(w), 256)(w, v, factor, offset, ndrange = length(v)) + + return w +end + +""" + expand(pattern::CommPattern, factor, offset) + +Create a new `CommPattern` where the `recvindices` and `sendindices` are +expanded in size by `factor` entries and offset by `offset`. -function expand(pattern::CommPattern{AT}, factor) where {AT} - recvindices = expand(pattern.recvindices, factor) +For example, to expand a `nodecommpattern` to communicate all fields of an +array that is indexed via `(node, field, element)` use +`expand(nodecommpattern(grid), nfields, nnodes)`. +""" +function expand(pattern::CommPattern{AT}, factor, offset = 1) where {AT} + recvindices = expand(pattern.recvindices, factor, offset) recvranks = copy(pattern.recvranks) recvrankindices = similar(pattern.recvrankindices) @assert eltype(recvrankindices) <: UnitRange @@ -56,7 +84,7 @@ function expand(pattern::CommPattern{AT}, factor) where {AT} recvrankindices[i] = expand(pattern.recvrankindices[i], factor) end - sendindices = expand(pattern.sendindices, factor) + sendindices = expand(pattern.sendindices, factor, offset) sendranks = copy(pattern.sendranks) sendrankindices = similar(pattern.sendrankindices) @assert eltype(sendrankindices) <: UnitRange @@ -109,7 +137,7 @@ end function _get_mpi_buffers(buffer, rankindices) # Hack to make the element type of the buffer arrays concrete - @assert eltype(rankindices) == typeof(1:length(rankindices)) + @assert eltype(rankindices) == UnitRange{eltype(eltype(rankindices))} T = typeof(view(buffer, 1:length(rankindices))) bufs = Array{MPI.Buffer{T}}(undef, length(rankindices)) diff --git a/src/grids.jl b/src/grids.jl index 4a0255c..8090a49 100644 --- a/src/grids.jl +++ b/src/grids.jl @@ -8,7 +8,7 @@ floattype(grid::AbstractGrid) = floattype(typeof(grid)) arraytype(grid::AbstractGrid) = arraytype(typeof(grid)) celltype(grid::AbstractGrid) = celltype(typeof(grid)) -struct Grid{C<:AbstractCell,P,V,S,L,T,F,B,PN,N,CTOD,DTOC,CC,FM} <: AbstractGrid{C} +struct Grid{C<:AbstractCell,P,V,S,L,T,F,B,PN,N,CTOD,DTOC,CC,NCC,FM} <: AbstractGrid{C} comm::MPI.Comm part::Int nparts::Int @@ -27,7 +27,7 @@ struct Grid{C<:AbstractCell,P,V,S,L,T,F,B,PN,N,CTOD,DTOC,CC,FM} <: AbstractGrid{ continuoustodiscontinuous::CTOD discontinuoustocontinuous::DTOC communicatingcells::CC - noncommunicatingcells::CC + noncommunicatingcells::NCC facemaps::FM end diff --git a/test/communication.jl b/test/communication.jl index f6c8ca2..9fb6056 100644 --- a/test/communication.jl +++ b/test/communication.jl @@ -14,7 +14,7 @@ sendranks = Cint[1, 2] sendrankindices = [1:2, 3:5] - pattern = Raven.CommPattern{Array}( + originalpattern = Raven.CommPattern{Array}( recvindices, recvranks, recvrankindices, @@ -23,7 +23,16 @@ sendrankindices, ) - pattern = Raven.expand(pattern, 3) + pattern = Raven.expand(originalpattern, 3, 5) + + @test pattern.recvranks == recvranks + @test pattern.recvindices == [17, 22, 27, 18, 23, 28, 19, 24, 29, 20, 25, 30] + @test pattern.recvrankindices == UnitRange{Int64}[1:3, 4:12] + @test pattern.sendranks == sendranks + @test pattern.sendindices == [1, 6, 11, 3, 8, 13, 1, 6, 11, 3, 8, 13, 5, 10, 15] + @test pattern.sendrankindices == UnitRange{Int64}[1:6, 7:15] + + pattern = Raven.expand(originalpattern, 3) @test pattern.recvranks == recvranks @test pattern.recvindices == [19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30]