@@ -368,48 +368,103 @@ namespace mpi {
368368 window& operator =(window const &) = delete ;
369369 window& operator =(window &&) = delete ;
370370
371+ // / Create a window over an existing local memory buffer
371372 explicit window (communicator &c, BaseType *base, MPI_Aint size = 0 ) {
372373 MPI_Win_create (base, size * sizeof (BaseType), alignof (BaseType), MPI_INFO_NULL, c.get (), &win);
373374 }
374375
376+ // / Create a window and allocate memory for a local memory buffer
377+ explicit window (communicator &c, MPI_Aint size = 0 ) {
378+ void *baseptr = nullptr ;
379+ MPI_Win_allocate (size * sizeof (BaseType), alignof (BaseType), MPI_INFO_NULL, c.get (), &baseptr, &win);
380+ }
381+
375382 ~window () {
376383 if (win != MPI_WIN_NULL) {
377384 MPI_Win_free (&win);
378385 }
379386 }
380387
381- operator MPI_Win () const { return win; };
382- operator MPI_Win*() { return &win; };
388+ explicit operator MPI_Win () const { return win; };
389+ explicit operator MPI_Win*() { return &win; };
383390
391+ // / Synchronization routine in active target RMA. It opens and closes an access epoch.
384392 void fence (int assert = 0 ) const {
385393 MPI_Win_fence (assert , win);
386394 }
387395
396+ // / Complete all outstanding RMA operations at both the origin and the target
397+ void flush (int rank = -1 ) {
398+ if (rank < 0 ) {
399+ MPI_Win_flush_all (win);
400+ } else {
401+ MPI_Win_flush (rank, win);
402+ }
403+ }
404+
405+ // / Synchronize the private and public copies of the window
406+ void sync () {
407+ MPI_Win_sync (win);
408+ }
409+
410+ // / Starts an RMA access epoch locking access to a particular or all ranks in the window
411+ void lock (int rank = -1 , int lock_type = MPI_LOCK_SHARED, int assert = 0 ) {
412+ if (rank < 0 ) {
413+ MPI_Win_lock_all (assert , win);
414+ } else {
415+ MPI_Win_lock (lock_type, rank, assert , win);
416+ }
417+ }
418+
419+ // / Completes an RMA access epoch started by a call to lock()
420+ void unlock (int rank = -1 ) {
421+ if (rank < 0 ) {
422+ MPI_Win_unlock_all (win);
423+ } else {
424+ MPI_Win_unlock (rank, win);
425+ }
426+ }
427+
428+ // / Load data from a remote memory window.
388429 template <typename TargetType = BaseType, typename OriginType>
389430 std::enable_if_t <has_mpi_type<OriginType> && has_mpi_type<TargetType>, void >
390431 get (OriginType *origin_addr, int origin_count, int target_rank, MPI_Aint target_disp = 0 , int target_count = -1 ) const {
391- MPI_Datatype origin_datatype = mpi_type<OriginType>::get ();
392- MPI_Datatype target_datatype = mpi_type<TargetType>::get ();
393- int target_count_ = target_count < 0 ? origin_count : target_count;
394- MPI_Get (origin_addr, origin_count, origin_datatype, target_rank, target_disp, target_count_, target_datatype, win);
432+ MPI_Datatype origin_datatype = mpi_type<OriginType>::get ();
433+ MPI_Datatype target_datatype = mpi_type<TargetType>::get ();
434+ int target_count_ = target_count < 0 ? origin_count : target_count;
435+ MPI_Get (origin_addr, origin_count, origin_datatype, target_rank, target_disp, target_count_, target_datatype, win);
395436 };
396437
438+ // / Store data to a remote memory window.
397439 template <typename TargetType = BaseType, typename OriginType>
398440 std::enable_if_t <has_mpi_type<OriginType> && has_mpi_type<TargetType>, void >
399441 put (OriginType *origin_addr, int origin_count, int target_rank, MPI_Aint target_disp = 0 , int target_count = -1 ) const {
400- MPI_Datatype origin_datatype = mpi_type<OriginType>::get ();
401- MPI_Datatype target_datatype = mpi_type<TargetType>::get ();
402- int target_count_ = target_count < 0 ? origin_count : target_count;
403- MPI_Put (origin_addr, origin_count, origin_datatype, target_rank, target_disp, target_count_, target_datatype, win);
442+ MPI_Datatype origin_datatype = mpi_type<OriginType>::get ();
443+ MPI_Datatype target_datatype = mpi_type<TargetType>::get ();
444+ int target_count_ = target_count < 0 ? origin_count : target_count;
445+ MPI_Put (origin_addr, origin_count, origin_datatype, target_rank, target_disp, target_count_, target_datatype, win);
404446 };
405447
448+ // / Accumulate data into target process through remote memory access.
449+ template <typename TargetType = BaseType, typename OriginType>
450+ std::enable_if_t <has_mpi_type<OriginType> && has_mpi_type<TargetType>, void >
451+ accumulate (OriginType const *origin_addr, int origin_count, int target_rank, MPI_Aint target_disp = 0 , int target_count = -1 , MPI_Op op = MPI_SUM) {
452+ MPI_Datatype origin_datatype = mpi_type<OriginType>::get ();
453+ MPI_Datatype target_datatype = mpi_type<TargetType>::get ();
454+ int target_count_ = target_count < 0 ? origin_count : target_count;
455+ MPI_Accumulate (origin_addr, origin_count, origin_datatype, target_rank, target_disp, target_count_, target_datatype, op, win);
456+ }
457+
458+ // / Obtains the value of a window attribute.
406459 void * get_attr (int win_keyval) const {
407460 int flag;
408461 void *attribute_val;
409462 MPI_Win_get_attr (win, win_keyval, &attribute_val, &flag);
410463 assert (flag);
411464 return attribute_val;
412465 }
466+
467+ // Expose some commonly used attributes
413468 BaseType* base () const { return static_cast <BaseType*>(get_attr (MPI_WIN_BASE)); }
414469 MPI_Aint size () const { return *static_cast <MPI_Aint*>(get_attr (MPI_WIN_SIZE)); }
415470 int disp_unit () const { return *static_cast <int *>(get_attr (MPI_WIN_DISP_UNIT)); }
@@ -419,11 +474,13 @@ namespace mpi {
419474 template <class BaseType >
420475 class shared_window : public window <BaseType> {
421476 public:
477+ // / Create a window and allocate memory for a shared memory buffer
422478 shared_window (shared_communicator& c, MPI_Aint size) {
423479 void * baseptr = nullptr ;
424480 MPI_Win_allocate_shared (size * sizeof (BaseType), alignof (BaseType), MPI_INFO_NULL, c.get (), &baseptr, &(this ->win ));
425481 }
426482
483+ // / Query a shared memory window
427484 std::tuple<MPI_Aint, int , void *> query (int rank = MPI_PROC_NULL) const {
428485 MPI_Aint size = 0 ;
429486 int disp_unit = 0 ;
@@ -432,9 +489,10 @@ namespace mpi {
432489 return {size, disp_unit, baseptr};
433490 }
434491
492+ // Override the commonly used attributes of the window base class
493+ BaseType* base (int rank = MPI_PROC_NULL) const { return static_cast <BaseType*>(std::get<2 >(query (rank))); }
435494 MPI_Aint size (int rank = MPI_PROC_NULL) const { return std::get<0 >(query (rank)) / sizeof (BaseType); }
436495 int disp_unit (int rank = MPI_PROC_NULL) const { return std::get<1 >(query (rank)); }
437- BaseType* base (int rank = MPI_PROC_NULL) const { return static_cast <BaseType*>(std::get<2 >(query (rank))); }
438496 };
439497
440498 /* -----------------------------------------------------------
0 commit comments