1414//
1515// Authors: Philipp Dumitrescu, Olivier Parcollet, Nils Wentzell
1616
17- #include " mpi/mpi.hpp"
17+ #include < mpi/mpi.hpp>
18+ #include < mpi/vector.hpp>
1819#include < gtest/gtest.h>
1920#include < numeric>
2021
@@ -95,10 +96,10 @@ TEST(MPI_Window, SharedArray) {
9596 auto shm = world.split_shared ();
9697 int const rank_shm = shm.rank ();
9798
98- constexpr int const size = 20 ;
99+ constexpr int const array_size = 20 ;
99100 constexpr int const magic = 21 ;
100101
101- mpi::shared_window<int > win{shm, rank_shm == 0 ? size : 0 };
102+ mpi::shared_window<int > win{shm, rank_shm == 0 ? array_size : 0 };
102103 std::span array_view{win.base (0 ), static_cast <std::size_t >(win.size (0 ))};
103104
104105 win.fence ();
@@ -110,7 +111,77 @@ TEST(MPI_Window, SharedArray) {
110111 win.fence ();
111112
112113 int sum = std::accumulate (array_view.begin (), array_view.end (), int {0 });
113- EXPECT_EQ (sum, size * magic);
114+ EXPECT_EQ (sum, array_size * magic);
115+ }
116+
117+ TEST (MPI_Window, DistributedSharedArray) {
118+ mpi::communicator world;
119+ auto shm = world.split_shared ();
120+
121+ // Number of total array elements (prime number to make it a bit more exciting)
122+ constexpr int const array_size_total = 197 ;
123+
124+ // Create a communicator between rank0 of all shared memory islands ("head node")
125+ auto head = world.split (shm.rank () == 0 ? 0 : MPI_UNDEFINED);
126+
127+ // Determine number of shared memory islands and broadcast to everyone
128+ int head_size = (world.rank () == 0 ? head.size (): -1 );
129+ mpi::broadcast (head_size, world);
130+
131+ // Determine rank in head node communicator and broadcast to all other ranks
132+ // on the same shared memory island
133+ int head_rank = (head.get () != MPI_COMM_NULL ? head.rank () : -1 );
134+ mpi::broadcast (head_rank, shm);
135+
136+ // Determine number of ranks on each shared memory island and broadcast to everyone
137+ std::vector<int > shm_sizes (head_size, 0 );
138+ if (!head.is_null ()) {
139+ shm_sizes.at (head_rank) = shm.size ();
140+ shm_sizes = mpi::all_reduce (shm_sizes, head);
141+ }
142+ mpi::broadcast (shm_sizes, world);
143+
144+ // Chunk the total array such that each rank has approximately the same number
145+ // of array elements
146+ std::vector<int > array_sizes (head_size, 0 );
147+ for (auto &&[shm_size, array_size]: itertools::zip (shm_sizes, array_sizes)) {
148+ array_size = array_size_total / world.size () * shm_size;
149+ }
150+ // Last shared memory island will get the excess
151+ array_sizes.back () += array_size_total % world.size ();
152+
153+ // Determine the global index offset on the current shared memory island
154+ auto begin = array_sizes.begin ();
155+ std::advance (begin, head_rank);
156+ std::ptrdiff_t offset = std::accumulate (array_sizes.begin (), begin, std::ptrdiff_t {0 });
157+
158+ // Allocate memory
159+ mpi::shared_window<int > win{shm, shm.rank () == 0 ? array_sizes.at (head_rank) : 0 };
160+ std::span array_view{win.base (0 ), static_cast <std::size_t >(win.size (0 ))};
161+
162+ // Fill array with global index (= local index + global offset)
163+ // We do this in parallel on each shared memory island by chunking the total range
164+ win.fence ();
165+ auto slice = itertools::chunk_range (0 , array_view.size (), shm.size (), shm.rank ());
166+ for (auto i = slice.first ; i < slice.second ; ++i) {
167+ array_view[i] = i + offset;
168+ }
169+ win.fence ();
170+
171+ // Calculate partial sum on head node of each shared memory island and
172+ // all_reduce the partial sums into a total sum over the head node
173+ // communicator and broadcast result to everyone
174+ std::vector<int > partial_sum (head_size, 0 );
175+ int sum = 0 ;
176+ if (!head.is_null ()) {
177+ partial_sum[head_rank] = std::accumulate (array_view.begin (), array_view.end (), int {0 });
178+ partial_sum = mpi::all_reduce (partial_sum, head);
179+ sum = std::accumulate (partial_sum.begin (), partial_sum.end (), int {0 });
180+ }
181+ mpi::broadcast (sum, world);
182+
183+ // Total sum is just sum of numbers in interval [0, array_size_total)
184+ EXPECT_EQ (sum, (array_size_total * (array_size_total - 1 )) / 2 );
114185}
115186
116187MPI_TEST_MAIN;
0 commit comments