
Commit 92e2e2b

Add complicated distributed shared array test
1 parent 48f5ce8 commit 92e2e2b

File tree

2 files changed: 77 additions & 4 deletions


c++/mpi/mpi.hpp

Lines changed: 2 additions & 0 deletions
@@ -79,6 +79,8 @@ namespace mpi {
 
     [[nodiscard]] MPI_Comm get() const noexcept { return _com; }
 
+    [[nodiscard]] bool is_null() const noexcept { return _com == MPI_COMM_NULL; }
+
     [[nodiscard]] int rank() const {
       if (has_env) {
         int num = 0;

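The new is_null() accessor gives callers a direct way to detect ranks that were excluded from a split and therefore hold MPI_COMM_NULL, rather than comparing get() against MPI_COMM_NULL by hand. A minimal usage sketch (a hypothetical standalone snippet, assuming only the communicator API visible in this commit):

    mpi::communicator world;
    // Ranks that pass MPI_UNDEFINED to split() receive MPI_COMM_NULL back.
    auto sub = world.split(world.rank() % 2 == 0 ? 0 : MPI_UNDEFINED);
    if (!sub.is_null()) {
      // Only ranks that actually belong to sub may call its collectives.
      int n = sub.size();
      mpi::broadcast(n, sub);
    }
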
test/c++/mpi_window.cpp

Lines changed: 75 additions & 4 deletions
@@ -14,7 +14,8 @@
 //
 // Authors: Philipp Dumitrescu, Olivier Parcollet, Nils Wentzell
 
-#include "mpi/mpi.hpp"
+#include <mpi/mpi.hpp>
+#include <mpi/vector.hpp>
 #include <gtest/gtest.h>
 #include <numeric>
 
@@ -95,10 +96,10 @@ TEST(MPI_Window, SharedArray) {
   auto shm = world.split_shared();
   int const rank_shm = shm.rank();
 
-  constexpr int const size = 20;
+  constexpr int const array_size = 20;
   constexpr int const magic = 21;
 
-  mpi::shared_window<int> win{shm, rank_shm == 0 ? size : 0};
+  mpi::shared_window<int> win{shm, rank_shm == 0 ? array_size : 0};
   std::span array_view{win.base(0), static_cast<std::size_t>(win.size(0))};
 
   win.fence();
@@ -110,7 +111,77 @@ TEST(MPI_Window, SharedArray) {
   win.fence();
 
   int sum = std::accumulate(array_view.begin(), array_view.end(), int{0});
-  EXPECT_EQ(sum, size * magic);
+  EXPECT_EQ(sum, array_size * magic);
+}
+
+TEST(MPI_Window, DistributedSharedArray) {
+  mpi::communicator world;
+  auto shm = world.split_shared();
+
+  // Number of total array elements (prime number to make it a bit more exciting)
+  constexpr int const array_size_total = 197;
+
+  // Create a communicator between rank 0 of all shared memory islands ("head node")
+  auto head = world.split(shm.rank() == 0 ? 0 : MPI_UNDEFINED);
+
+  // Determine the number of shared memory islands and broadcast it to everyone
+  int head_size = (world.rank() == 0 ? head.size() : -1);
+  mpi::broadcast(head_size, world);
+
+  // Determine the rank in the head node communicator and broadcast it to all
+  // other ranks on the same shared memory island
+  int head_rank = (head.get() != MPI_COMM_NULL ? head.rank() : -1);
+  mpi::broadcast(head_rank, shm);
+
+  // Determine the number of ranks on each shared memory island and broadcast it to everyone
+  std::vector<int> shm_sizes(head_size, 0);
+  if (!head.is_null()) {
+    shm_sizes.at(head_rank) = shm.size();
+    shm_sizes = mpi::all_reduce(shm_sizes, head);
+  }
+  mpi::broadcast(shm_sizes, world);
+
+  // Chunk the total array such that each rank gets approximately the same number
+  // of array elements
+  std::vector<int> array_sizes(head_size, 0);
+  for (auto &&[shm_size, array_size] : itertools::zip(shm_sizes, array_sizes)) {
+    array_size = array_size_total / world.size() * shm_size;
+  }
+  // The last shared memory island gets the excess
+  array_sizes.back() += array_size_total % world.size();
+
+  // Determine the global index offset of the current shared memory island
+  auto begin = array_sizes.begin();
+  std::advance(begin, head_rank);
+  std::ptrdiff_t offset = std::accumulate(array_sizes.begin(), begin, std::ptrdiff_t{0});
+
+  // Allocate memory
+  mpi::shared_window<int> win{shm, shm.rank() == 0 ? array_sizes.at(head_rank) : 0};
+  std::span array_view{win.base(0), static_cast<std::size_t>(win.size(0))};
+
+  // Fill the array with the global index (= local index + global offset)
+  // We do this in parallel on each shared memory island by chunking its local range
+  win.fence();
+  auto slice = itertools::chunk_range(0, array_view.size(), shm.size(), shm.rank());
+  for (auto i = slice.first; i < slice.second; ++i) {
+    array_view[i] = i + offset;
+  }
+  win.fence();
+
+  // Calculate the partial sum on the head node of each shared memory island,
+  // all_reduce the partial sums into a total sum over the head node
+  // communicator, and broadcast the result to everyone
+  std::vector<int> partial_sum(head_size, 0);
+  int sum = 0;
+  if (!head.is_null()) {
+    partial_sum[head_rank] = std::accumulate(array_view.begin(), array_view.end(), int{0});
+    partial_sum = mpi::all_reduce(partial_sum, head);
+    sum = std::accumulate(partial_sum.begin(), partial_sum.end(), int{0});
+  }
+  mpi::broadcast(sum, world);
+
+  // The total sum is just the sum of the integers in the interval [0, array_size_total)
+  EXPECT_EQ(sum, (array_size_total * (array_size_total - 1)) / 2);
 }
 
 MPI_TEST_MAIN;
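
Sanity check on the chunking arithmetic above: each island receives (array_size_total / world.size()) * shm_size elements and the last island absorbs the remainder, so the per-island sizes always sum to array_size_total. A self-contained sketch with a made-up 4+4+3 island layout (plain C++, no MPI; all sizes here are illustrative assumptions):

    #include <cassert>
    #include <numeric>
    #include <vector>

    int main() {
      std::vector<int> const shm_sizes = {4, 4, 3}; // hypothetical islands
      int const world_size = 4 + 4 + 3;
      constexpr int array_size_total = 197;

      // Same arithmetic as in the test: a proportional share per island ...
      std::vector<int> array_sizes;
      for (int shm_size : shm_sizes) array_sizes.push_back(array_size_total / world_size * shm_size);
      // ... with the last island taking the excess.
      array_sizes.back() += array_size_total % world_size;

      // The islands cover the array exactly once: 68 + 68 + 61 == 197.
      assert(std::accumulate(array_sizes.begin(), array_sizes.end(), 0) == array_size_total);
    }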

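The parallel fill relies on itertools::chunk_range(first, last, n_chunks, rank), which, as used above, returns the half-open slice [slice.first, slice.second) owned by a given rank. A small sketch of how the 61 local elements of the last island in the layout above would be divided among 3 local ranks (assuming the itertools umbrella header <itertools/itertools.hpp>):

    #include <itertools/itertools.hpp>
    #include <cstdio>

    int main() {
      long const n_elements = 61; // local size of one island, purely illustrative
      long const n_ranks = 3;
      for (long r = 0; r < n_ranks; ++r) {
        // Each local rank writes only its own slice between the two fences.
        auto slice = itertools::chunk_range(0, n_elements, n_ranks, r);
        std::printf("rank %ld fills [%td, %td)\n", r, slice.first, slice.second);
      }
    }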