Skip to content

Commit

Permalink
Cleanup MPI Requests
Browse files Browse the repository at this point in the history
  • Loading branch information
lcw committed Apr 19, 2024
1 parent 16e6e3f commit 67cd741
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 6 deletions.
10 changes: 8 additions & 2 deletions src/Raven.jl
Original file line number Diff line number Diff line change
Expand Up @@ -65,15 +65,21 @@ if !isdefined(Base, :get_extension)
using Requires
end

@static if !isdefined(Base, :get_extension)
function __init__()
function __init__()
@static if !isdefined(Base, :get_extension)
@require CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" include(
"../ext/RavenCUDAExt.jl",
)
@require WriteVTK = "64499a7a-5c06-52f2-abe2-ccb03c286192" include(
"../ext/RavenWriteVTKExt.jl",
)
end

MPI.add_finalize_hook!() do
for cm in COMM_MANAGERS
finalize(cm.value)
end
end
end

end # module Raven
22 changes: 18 additions & 4 deletions src/communication.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
const COMM_MANAGERS = WeakRef[]

struct CommPattern{AT,RI,RR,RRI,SI,SR,SRI}
recvindices::RI
recvranks::RR
Expand Down Expand Up @@ -74,7 +76,7 @@ end

abstract type AbstractCommManager end

struct CommManagerBuffered{CP,RBD,RB,SBD,SB} <: AbstractCommManager
mutable struct CommManagerBuffered{CP,RBD,RB,SBD,SB} <: AbstractCommManager
comm::MPI.Comm
pattern::CP
tag::Cint
Expand All @@ -86,8 +88,8 @@ struct CommManagerBuffered{CP,RBD,RB,SBD,SB} <: AbstractCommManager
sendrequests::MPI.UnsafeMultiRequest
end

struct CommManagerTripleBuffered{CP,RBC,RBH,RBD,RB,RS,SBC,SBH,SBD,SB,SS} <:
AbstractCommManager
mutable struct CommManagerTripleBuffered{CP,RBC,RBH,RBD,RB,RS,SBC,SBH,SBD,SB,SS} <:
AbstractCommManager
comm::MPI.Comm
pattern::CP
tag::Cint
Expand Down Expand Up @@ -184,7 +186,7 @@ function commmanager(
)
end

return if triplebuffer
cm = if triplebuffer
backend = get_backend(arraytype(pattern))
recvstream = Stream(backend)
sendstream = Stream(backend)
Expand Down Expand Up @@ -218,6 +220,18 @@ function commmanager(
sendrequests,
)
end

finalizer(cm) do cm
for reqs in (cm.recvrequests, cm.sendrequests)
for req in reqs
MPI.free(req)
end
end
end

push!(COMM_MANAGERS, WeakRef(cm))

return cm
end

get_backend(cm::AbstractCommManager) = get_backend(arraytype(cm.pattern))
Expand Down

0 comments on commit 67cd741

Please sign in to comment.