diff --git a/src/Raven.jl b/src/Raven.jl index d2321b8..2dbb57c 100644 --- a/src/Raven.jl +++ b/src/Raven.jl @@ -65,8 +65,8 @@ if !isdefined(Base, :get_extension) using Requires end -@static if !isdefined(Base, :get_extension) - function __init__() +function __init__() + @static if !isdefined(Base, :get_extension) @require CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" include( "../ext/RavenCUDAExt.jl", ) @@ -74,6 +74,12 @@ end "../ext/RavenWriteVTKExt.jl", ) end + + MPI.add_finalize_hook!() do + for cm in COMM_MANAGERS + finalize(cm.value) + end + end end end # module Raven diff --git a/src/communication.jl b/src/communication.jl index 20be420..a865d59 100644 --- a/src/communication.jl +++ b/src/communication.jl @@ -1,3 +1,5 @@ +const COMM_MANAGERS = WeakRef[] + struct CommPattern{AT,RI,RR,RRI,SI,SR,SRI} recvindices::RI recvranks::RR @@ -74,7 +76,7 @@ end abstract type AbstractCommManager end -struct CommManagerBuffered{CP,RBD,RB,SBD,SB} <: AbstractCommManager +mutable struct CommManagerBuffered{CP,RBD,RB,SBD,SB} <: AbstractCommManager comm::MPI.Comm pattern::CP tag::Cint @@ -86,8 +88,8 @@ struct CommManagerBuffered{CP,RBD,RB,SBD,SB} <: AbstractCommManager sendrequests::MPI.UnsafeMultiRequest end -struct CommManagerTripleBuffered{CP,RBC,RBH,RBD,RB,RS,SBC,SBH,SBD,SB,SS} <: - AbstractCommManager +mutable struct CommManagerTripleBuffered{CP,RBC,RBH,RBD,RB,RS,SBC,SBH,SBD,SB,SS} <: + AbstractCommManager comm::MPI.Comm pattern::CP tag::Cint @@ -184,7 +186,7 @@ function commmanager( ) end - return if triplebuffer + cm = if triplebuffer backend = get_backend(arraytype(pattern)) recvstream = Stream(backend) sendstream = Stream(backend) @@ -218,6 +220,18 @@ function commmanager( sendrequests, ) end + + finalizer(cm) do cm + for reqs in (cm.recvrequests, cm.sendrequests) + for req in reqs + MPI.free(req) + end + end + end + + push!(COMM_MANAGERS, WeakRef(cm)) + + return cm end get_backend(cm::AbstractCommManager) = get_backend(arraytype(cm.pattern))