Skip to content

Commit fa23192

Browse files
committed
refactor: define a lambda to avoid logic duplication in pack vs. non-pack paths
1 parent b627c8f commit fa23192

File tree

2 files changed

+31
-56
lines changed

2 files changed

+31
-56
lines changed

src/impl/KokkosComm_isend.hpp

Lines changed: 18 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -42,52 +42,35 @@ KokkosComm::Req isend(const ExecSpace &space, const SendView &sv, int dest,
4242
using KCT = KokkosComm::Traits<SendView>;
4343
using KCPT = KokkosComm::PackTraits<SendView>;
4444

45-
if (KCPT::needs_pack(sv)) {
46-
using Packer = typename KCPT::packer_type;
47-
using MpiArgs = typename Packer::args_type;
48-
49-
MpiArgs args = Packer::pack(space, sv);
50-
space.fence();
51-
45+
auto isend_fn = [&](auto &&view, auto count, auto datatype, auto &&request) {
5246
if constexpr (SendMode == CommMode::Standard) {
53-
MPI_Isend(KCT::data_handle(args.view), args.count, args.datatype, dest,
54-
tag, comm, &req.mpi_req());
47+
MPI_Isend(view, count, datatype, dest, tag, comm, request);
5548
} else if constexpr (SendMode == CommMode::Ready) {
56-
MPI_Irsend(KCT::data_handle(args.view), args.count, args.datatype, dest,
57-
tag, comm, &req.mpi_req());
49+
MPI_Irsend(view, count, datatype, dest, tag, comm, request);
5850
} else if constexpr (SendMode == CommMode::Synchronous) {
59-
MPI_Issend(KCT::data_handle(args.view), args.count, args.datatype, dest,
60-
tag, comm, &req.mpi_req());
51+
MPI_Issend(view, count, datatype, dest, tag, comm, request);
6152
} else if constexpr (SendMode == CommMode::Default) {
6253
#ifdef KOKKOSCOMM_FORCE_SYNCHRONOUS_MODE
63-
MPI_Issend(KCT::data_handle(args.view), args.count, args.datatype, dest,
64-
tag, comm, &req.mpi_req());
54+
MPI_Issend(view, count, datatype, dest, tag, comm, request);
6555
#else
66-
MPI_Isend(KCT::data_handle(args.view), args.count, args.datatype, dest,
67-
tag, comm, &req.mpi_req());
56+
MPI_Isend(view, count, datatype, dest, tag, comm, request);
6857
#endif
6958
}
59+
};
60+
61+
if (KCPT::needs_pack(sv)) {
62+
using Packer = typename KCPT::packer_type;
63+
using MpiArgs = typename Packer::args_type;
64+
65+
MpiArgs args = Packer::pack(space, sv);
66+
space.fence();
67+
isend_fn(KCT::data_handle(args.view), args.count, args.datatype,
68+
&req.mpi_req());
7069
req.keep_until_wait(args.view);
7170
} else {
7271
using SendScalar = typename SendView::value_type;
73-
if constexpr (SendMode == CommMode::Standard) {
74-
MPI_Isend(KCT::data_handle(sv), KCT::span(sv), mpi_type_v<SendScalar>,
75-
dest, tag, comm, &req.mpi_req());
76-
} else if constexpr (SendMode == CommMode::Ready) {
77-
MPI_Irsend(KCT::data_handle(sv), KCT::span(sv), mpi_type_v<SendScalar>,
78-
dest, tag, comm, &req.mpi_req());
79-
} else if constexpr (SendMode == CommMode::Synchronous) {
80-
MPI_Issend(KCT::data_handle(sv), KCT::span(sv), mpi_type_v<SendScalar>,
81-
dest, tag, comm, &req.mpi_req());
82-
} else if constexpr (SendMode == CommMode::Default) {
83-
#ifdef KOKKOSCOMM_FORCE_SYNCHRONOUS_MODE
84-
MPI_Issend(KCT::data_handle(sv), KCT::span(sv), mpi_type_v<SendScalar>,
85-
dest, tag, comm, &req.mpi_req());
86-
#else
87-
MPI_Isend(KCT::data_handle(sv), KCT::span(sv), mpi_type_v<SendScalar>,
88-
dest, tag, comm, &req.mpi_req());
89-
#endif
90-
}
72+
isend_fn(KCT::data_handle(sv), KCT::span(sv), mpi_type_v<SendScalar>,
73+
&req.mpi_req());
9174
if (KCT::is_reference_counted()) {
9275
req.keep_until_wait(sv);
9376
}

src/impl/KokkosComm_send.hpp

Lines changed: 13 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -34,37 +34,29 @@ void send(const ExecSpace &space, const SendView &sv, int dest, int tag,
3434

3535
using Packer = typename KokkosComm::PackTraits<SendView>::packer_type;
3636

37-
if (KokkosComm::PackTraits<SendView>::needs_pack(sv)) {
38-
auto args = Packer::pack(space, sv);
39-
space.fence();
37+
auto send_fn = [&](auto &&view, auto count, auto datatype) {
4038
if constexpr (SendMode == CommMode::Standard) {
41-
MPI_Send(args.view.data(), args.count, args.datatype, dest, tag, comm);
39+
MPI_Send(view, count, datatype, dest, tag, comm);
4240
} else if constexpr (SendMode == CommMode::Ready) {
43-
MPI_Rsend(args.view.data(), args.count, args.datatype, dest, tag, comm);
41+
MPI_Rsend(view, count, datatype, dest, tag, comm);
4442
} else if constexpr (SendMode == CommMode::Synchronous) {
45-
MPI_Ssend(args.view.data(), args.count, args.datatype, dest, tag, comm);
43+
MPI_Ssend(view, count, datatype, dest, tag, comm);
4644
} else if constexpr (SendMode == CommMode::Default) {
4745
#ifdef KOKKOSCOMM_FORCE_SYNCHRONOUS_MODE
48-
MPI_Ssend(args.view.data(), args.count, args.datatype, dest, tag, comm);
46+
MPI_Ssend(view, count, datatype, dest, tag, comm);
4947
#else
50-
MPI_Send(args.view.data(), args.count, args.datatype, dest, tag, comm);
48+
MPI_Send(view, count, datatype, dest, tag, comm);
5149
#endif
5250
}
51+
};
52+
53+
if (KokkosComm::PackTraits<SendView>::needs_pack(sv)) {
54+
auto args = Packer::pack(space, sv);
55+
space.fence();
56+
send_fn(args.view.data(), args.count, args.datatype);
5357
} else {
5458
using SendScalar = typename SendView::value_type;
55-
if constexpr (SendMode == CommMode::Standard) {
56-
MPI_Send(sv.data(), sv.span(), mpi_type_v<SendScalar>, dest, tag, comm);
57-
} else if constexpr (SendMode == CommMode::Ready) {
58-
MPI_Rsend(sv.data(), sv.span(), mpi_type_v<SendScalar>, dest, tag, comm);
59-
} else if constexpr (SendMode == CommMode::Synchronous) {
60-
MPI_Ssend(sv.data(), sv.span(), mpi_type_v<SendScalar>, dest, tag, comm);
61-
} else if constexpr (SendMode == CommMode::Default) {
62-
#ifdef KOKKOSCOMM_FORCE_SYNCHRONOUS_MODE
63-
MPI_Ssend(sv.data(), sv.span(), mpi_type_v<SendScalar>, dest, tag, comm);
64-
#else
65-
MPI_Send(sv.data(), sv.span(), mpi_type_v<SendScalar>, dest, tag, comm);
66-
#endif
67-
}
59+
send_fn(sv.data(), sv.span(), mpi_type_v<SendScalar>);
6860
}
6961

7062
Kokkos::Tools::popRegion();

0 commit comments

Comments
 (0)