diff --git a/src/communication.jl b/src/communication.jl index 7d62d12..d40bf17 100644 --- a/src/communication.jl +++ b/src/communication.jl @@ -299,21 +299,38 @@ function progress(::AbstractCommManager) end function start!(A, cm::CommManagerBuffered) - setbuffer!(cm.sendbufferdevice, A, cm.pattern.sendindices) - KernelAbstractions.synchronize(get_backend(cm)) + if !isempty(cm.sendrequests) + setbuffer!(cm.sendbufferdevice, A, cm.pattern.sendindices) + end + + if !isempty(cm.sendrequests) || !isempty(cm.recvrequests) + KernelAbstractions.synchronize(get_backend(cm)) + end + + if !isempty(cm.recvrequests) + MPI.Startall(cm.recvrequests) + end - MPI.Startall(cm.recvrequests) - MPI.Startall(cm.sendrequests) + if !isempty(cm.sendrequests) + MPI.Startall(cm.sendrequests) + end return end function finish!(A, cm::CommManagerBuffered) - MPI.Waitall(cm.recvrequests) - MPI.Waitall(cm.sendrequests) + if !isempty(cm.recvrequests) + MPI.Waitall(cm.recvrequests) + end + + if !isempty(cm.sendrequests) + MPI.Waitall(cm.sendrequests) + end - A = viewwithghosts(A) - getbuffer!(A, cm.recvbufferdevice, cm.pattern.recvindices) + if !isempty(cm.recvrequests) + A = viewwithghosts(A) + getbuffer!(A, cm.recvbufferdevice, cm.pattern.recvindices) + end return end