From 116774e9cb1dec96a9bada150d3528d9b199a036 Mon Sep 17 00:00:00 2001 From: Gabriel Dos Santos Date: Fri, 19 Apr 2024 11:16:34 +0200 Subject: [PATCH] refactor: change lambda prototype to match MPI functions Removes reference-captures and adds explicit typing. --- src/impl/KokkosComm_isend.hpp | 27 +++++++++++++++++---------- src/impl/KokkosComm_send.hpp | 18 ++++++++++-------- 2 files changed, 27 insertions(+), 18 deletions(-) diff --git a/src/impl/KokkosComm_isend.hpp b/src/impl/KokkosComm_isend.hpp index 0899ebaa..e9f39ed4 100644 --- a/src/impl/KokkosComm_isend.hpp +++ b/src/impl/KokkosComm_isend.hpp @@ -42,18 +42,25 @@ KokkosComm::Req isend(const ExecSpace &space, const SendView &sv, int dest, using KCT = KokkosComm::Traits; using KCPT = KokkosComm::PackTraits; - auto isend_fn = [&](auto &&view, auto count, auto datatype, auto &&request) { + auto mpi_isend_fn = [](void *mpi_view, int mpi_count, + MPI_Datatype mpi_datatype, int mpi_dest, int mpi_tag, + MPI_Comm mpi_comm, MPI_Request *mpi_req) { if constexpr (SendMode == CommMode::Standard) { - MPI_Isend(view, count, datatype, dest, tag, comm, request); + MPI_Isend(mpi_view, mpi_count, mpi_datatype, mpi_dest, mpi_tag, mpi_comm, + mpi_req); } else if constexpr (SendMode == CommMode::Ready) { - MPI_Irsend(view, count, datatype, dest, tag, comm, request); + MPI_Irsend(mpi_view, mpi_count, mpi_datatype, mpi_dest, mpi_tag, mpi_comm, + mpi_req); } else if constexpr (SendMode == CommMode::Synchronous) { - MPI_Issend(view, count, datatype, dest, tag, comm, request); + MPI_Issend(mpi_view, mpi_count, mpi_datatype, mpi_dest, mpi_tag, mpi_comm, + mpi_req); } else if constexpr (SendMode == CommMode::Default) { #ifdef KOKKOSCOMM_FORCE_SYNCHRONOUS_MODE - MPI_Issend(view, count, datatype, dest, tag, comm, request); + MPI_Issend(mpi_view, mpi_count, mpi_datatype, mpi_dest, mpi_tag, mpi_comm, + mpi_req); #else - MPI_Isend(view, count, datatype, dest, tag, comm, request); + MPI_Isend(mpi_view, mpi_count, mpi_datatype, mpi_dest, mpi_tag, mpi_comm, + mpi_req); #endif } }; @@ -64,13 +71,13 @@ KokkosComm::Req isend(const ExecSpace &space, const SendView &sv, int dest, MpiArgs args = Packer::pack(space, sv); space.fence(); - isend_fn(KCT::data_handle(args.view), args.count, args.datatype, - &req.mpi_req()); + mpi_isend_fn(KCT::data_handle(args.view), args.count, args.datatype, dest, + tag, comm, &req.mpi_req()); req.keep_until_wait(args.view); } else { using SendScalar = typename SendView::value_type; - isend_fn(KCT::data_handle(sv), KCT::span(sv), mpi_type_v, - &req.mpi_req()); + mpi_isend_fn(KCT::data_handle(sv), KCT::span(sv), mpi_type_v, + dest, tag, comm, &req.mpi_req()); if (KCT::is_reference_counted()) { req.keep_until_wait(sv); } diff --git a/src/impl/KokkosComm_send.hpp b/src/impl/KokkosComm_send.hpp index 25b39fd4..8d88e236 100644 --- a/src/impl/KokkosComm_send.hpp +++ b/src/impl/KokkosComm_send.hpp @@ -34,18 +34,20 @@ void send(const ExecSpace &space, const SendView &sv, int dest, int tag, using Packer = typename KokkosComm::PackTraits::packer_type; - auto send_fn = [&](auto &&view, auto count, auto datatype) { + auto mpi_send_fn = [](void *mpi_view, int mpi_count, + MPI_Datatype mpi_datatype, int mpi_dest, int mpi_tag, + MPI_Comm mpi_comm) { if constexpr (SendMode == CommMode::Standard) { - MPI_Send(view, count, datatype, dest, tag, comm); + MPI_Send(mpi_view, mpi_count, mpi_datatype, mpi_dest, mpi_tag, mpi_comm); } else if constexpr (SendMode == CommMode::Ready) { - MPI_Rsend(view, count, datatype, dest, tag, comm); + MPI_Rsend(mpi_view, mpi_count, mpi_datatype, mpi_dest, mpi_tag, mpi_comm); } else if constexpr (SendMode == CommMode::Synchronous) { - MPI_Ssend(view, count, datatype, dest, tag, comm); + MPI_Ssend(mpi_view, mpi_count, mpi_datatype, mpi_dest, mpi_tag, mpi_comm); } else if constexpr (SendMode == CommMode::Default) { #ifdef KOKKOSCOMM_FORCE_SYNCHRONOUS_MODE - MPI_Ssend(view, count, datatype, dest, tag, comm); + MPI_Ssend(mpi_view, mpi_count, mpi_datatype, mpi_dest, mpi_tag, mpi_comm); #else - MPI_Send(view, count, datatype, dest, tag, comm); + MPI_Send(mpi_view, mpi_count, mpi_datatype, mpi_dest, mpi_tag, mpi_comm); #endif } }; @@ -53,10 +55,10 @@ void send(const ExecSpace &space, const SendView &sv, int dest, int tag, if (KokkosComm::PackTraits::needs_pack(sv)) { auto args = Packer::pack(space, sv); space.fence(); - send_fn(args.view.data(), args.count, args.datatype); + mpi_send_fn(args.view.data(), args.count, args.datatype, dest, tag, comm); } else { using SendScalar = typename SendView::value_type; - send_fn(sv.data(), sv.span(), mpi_type_v); + mpi_send_fn(sv.data(), sv.span(), mpi_type_v, dest, tag, comm); } Kokkos::Tools::popRegion();