From 3bc64fe5c7336b850eecfb6de8a201eb26fab754 Mon Sep 17 00:00:00 2001 From: Lucas C Wilcox Date: Tue, 6 Aug 2024 12:38:53 -0700 Subject: [PATCH] Avoid GPU synchronization in serial MPI runs --- src/communication.jl | 33 +++++++++++++++++++++++++-------- 1 file changed, 25 insertions(+), 8 deletions(-) diff --git a/src/communication.jl b/src/communication.jl index 7d62d12..d40bf17 100644 --- a/src/communication.jl +++ b/src/communication.jl @@ -299,21 +299,38 @@ function progress(::AbstractCommManager) end function start!(A, cm::CommManagerBuffered) - setbuffer!(cm.sendbufferdevice, A, cm.pattern.sendindices) - KernelAbstractions.synchronize(get_backend(cm)) + if !isempty(cm.sendrequests) + setbuffer!(cm.sendbufferdevice, A, cm.pattern.sendindices) + end + + if !isempty(cm.sendrequests) || !isempty(cm.recvrequests) + KernelAbstractions.synchronize(get_backend(cm)) + end + + if !isempty(cm.recvrequests) + MPI.Startall(cm.recvrequests) + end - MPI.Startall(cm.recvrequests) - MPI.Startall(cm.sendrequests) + if !isempty(cm.sendrequests) + MPI.Startall(cm.sendrequests) + end return end function finish!(A, cm::CommManagerBuffered) - MPI.Waitall(cm.recvrequests) - MPI.Waitall(cm.sendrequests) + if !isempty(cm.recvrequests) + MPI.Waitall(cm.recvrequests) + end + + if !isempty(cm.sendrequests) + MPI.Waitall(cm.sendrequests) + end - A = viewwithghosts(A) - getbuffer!(A, cm.recvbufferdevice, cm.pattern.recvindices) + if !isempty(cm.recvrequests) + A = viewwithghosts(A) + getbuffer!(A, cm.recvbufferdevice, cm.pattern.recvindices) + end return end