Skip to content

Commit

Permalink
refactor: change lambda prototype to match MPI functions
Browse files Browse the repository at this point in the history
Removes reference-captures and adds explicit typing.
  • Loading branch information
dssgabriel committed Apr 19, 2024
1 parent fa23192 commit 116774e
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 18 deletions.
27 changes: 17 additions & 10 deletions src/impl/KokkosComm_isend.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,18 +42,25 @@ KokkosComm::Req isend(const ExecSpace &space, const SendView &sv, int dest,
using KCT = KokkosComm::Traits<SendView>;
using KCPT = KokkosComm::PackTraits<SendView>;

auto isend_fn = [&](auto &&view, auto count, auto datatype, auto &&request) {
auto mpi_isend_fn = [](void *mpi_view, int mpi_count,
MPI_Datatype mpi_datatype, int mpi_dest, int mpi_tag,
MPI_Comm mpi_comm, MPI_Request *mpi_req) {
if constexpr (SendMode == CommMode::Standard) {
MPI_Isend(view, count, datatype, dest, tag, comm, request);
MPI_Isend(mpi_view, mpi_count, mpi_datatype, mpi_dest, mpi_tag, mpi_comm,
mpi_req);
} else if constexpr (SendMode == CommMode::Ready) {
MPI_Irsend(view, count, datatype, dest, tag, comm, request);
MPI_Irsend(mpi_view, mpi_count, mpi_datatype, mpi_dest, mpi_tag, mpi_comm,
mpi_req);
} else if constexpr (SendMode == CommMode::Synchronous) {
MPI_Issend(view, count, datatype, dest, tag, comm, request);
MPI_Issend(mpi_view, mpi_count, mpi_datatype, mpi_dest, mpi_tag, mpi_comm,
mpi_req);
} else if constexpr (SendMode == CommMode::Default) {
#ifdef KOKKOSCOMM_FORCE_SYNCHRONOUS_MODE
MPI_Issend(view, count, datatype, dest, tag, comm, request);
MPI_Issend(mpi_view, mpi_count, mpi_datatype, mpi_dest, mpi_tag, mpi_comm,
mpi_req);
#else
MPI_Isend(view, count, datatype, dest, tag, comm, request);
MPI_Isend(mpi_view, mpi_count, mpi_datatype, mpi_dest, mpi_tag, mpi_comm,
mpi_req);
#endif
}
};
Expand All @@ -64,13 +71,13 @@ KokkosComm::Req isend(const ExecSpace &space, const SendView &sv, int dest,

MpiArgs args = Packer::pack(space, sv);
space.fence();
isend_fn(KCT::data_handle(args.view), args.count, args.datatype,
&req.mpi_req());
mpi_isend_fn(KCT::data_handle(args.view), args.count, args.datatype, dest,
tag, comm, &req.mpi_req());
req.keep_until_wait(args.view);
} else {
using SendScalar = typename SendView::value_type;
isend_fn(KCT::data_handle(sv), KCT::span(sv), mpi_type_v<SendScalar>,
&req.mpi_req());
mpi_isend_fn(KCT::data_handle(sv), KCT::span(sv), mpi_type_v<SendScalar>,
dest, tag, comm, &req.mpi_req());
if (KCT::is_reference_counted()) {
req.keep_until_wait(sv);
}
Expand Down
18 changes: 10 additions & 8 deletions src/impl/KokkosComm_send.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,29 +34,31 @@ void send(const ExecSpace &space, const SendView &sv, int dest, int tag,

using Packer = typename KokkosComm::PackTraits<SendView>::packer_type;

auto send_fn = [&](auto &&view, auto count, auto datatype) {
auto mpi_send_fn = [](void *mpi_view, int mpi_count,
MPI_Datatype mpi_datatype, int mpi_dest, int mpi_tag,
MPI_Comm mpi_comm) {
if constexpr (SendMode == CommMode::Standard) {
MPI_Send(view, count, datatype, dest, tag, comm);
MPI_Send(mpi_view, mpi_count, mpi_datatype, mpi_dest, mpi_tag, mpi_comm);
} else if constexpr (SendMode == CommMode::Ready) {
MPI_Rsend(view, count, datatype, dest, tag, comm);
MPI_Rsend(mpi_view, mpi_count, mpi_datatype, mpi_dest, mpi_tag, mpi_comm);
} else if constexpr (SendMode == CommMode::Synchronous) {
MPI_Ssend(view, count, datatype, dest, tag, comm);
MPI_Ssend(mpi_view, mpi_count, mpi_datatype, mpi_dest, mpi_tag, mpi_comm);
} else if constexpr (SendMode == CommMode::Default) {
#ifdef KOKKOSCOMM_FORCE_SYNCHRONOUS_MODE
MPI_Ssend(view, count, datatype, dest, tag, comm);
MPI_Ssend(mpi_view, mpi_count, mpi_datatype, mpi_dest, mpi_tag, mpi_comm);
#else
MPI_Send(view, count, datatype, dest, tag, comm);
MPI_Send(mpi_view, mpi_count, mpi_datatype, mpi_dest, mpi_tag, mpi_comm);
#endif
}
};

if (KokkosComm::PackTraits<SendView>::needs_pack(sv)) {
auto args = Packer::pack(space, sv);
space.fence();
send_fn(args.view.data(), args.count, args.datatype);
mpi_send_fn(args.view.data(), args.count, args.datatype, dest, tag, comm);
} else {
using SendScalar = typename SendView::value_type;
send_fn(sv.data(), sv.span(), mpi_type_v<SendScalar>);
mpi_send_fn(sv.data(), sv.span(), mpi_type_v<SendScalar>, dest, tag, comm);
}

Kokkos::Tools::popRegion();
Expand Down

0 comments on commit 116774e

Please sign in to comment.