Skip to content

Commit

Permalink
Fix messages larger than INT_MAX for mpi
Browse files Browse the repository at this point in the history
  • Loading branch information
JiakunYan committed Nov 24, 2024
1 parent 8257eff commit 646f848
Show file tree
Hide file tree
Showing 4 changed files with 130 additions and 104 deletions.
5 changes: 5 additions & 0 deletions libs/core/mpi_base/include/hpx/mpi_base/mpi_environment.hpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
// Copyright (c) 2013-2015 Thomas Heller
// Copyright (c) 2024 Jiakun Yan
//
// SPDX-License-Identifier: BSL-1.0
// Distributed under the Boost Software License, Version 1.0. (See accompanying
Expand Down Expand Up @@ -42,6 +43,10 @@ namespace hpx::util {

static std::string get_processor_name();

static MPI_Datatype type_contiguous(size_t nbytes);
static MPI_Request isend(void* address, size_t size, int rank, int tag);
static MPI_Request irecv(void* address, size_t size, int rank, int tag);

struct HPX_CORE_EXPORT scoped_lock
{
scoped_lock();
Expand Down
92 changes: 92 additions & 0 deletions libs/core/mpi_base/src/mpi_environment.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// Copyright (c) 2020 Google
// Copyright (c) 2022 Patrick Diehl
// Copyright (c) 2023 Hartmut Kaiser
// Copyright (c) 2024 Jiakun Yan
//
// SPDX-License-Identifier: BSL-1.0
// Distributed under the Boost Software License, Version 1.0. (See accompanying
Expand Down Expand Up @@ -467,6 +468,97 @@ namespace hpx::util {

report_error(sl, error_code);
}

// Acknowledgement: code adapted from github.com/jeffhammond/BigMPI
MPI_Datatype mpi_environment::type_contiguous(size_t nbytes)
{
size_t int_max = (std::numeric_limits<int>::max)();

size_t c = nbytes / int_max;
size_t r = nbytes % int_max;

HPX_ASSERT(c < int_max);
HPX_ASSERT(r < int_max);

MPI_Datatype chunks;
MPI_Type_vector(c, int_max, int_max, MPI_BYTE, &chunks);

MPI_Datatype remainder;
MPI_Type_contiguous(r, MPI_BYTE, &remainder);

MPI_Aint remdisp = (MPI_Aint) c * int_max;
int blocklengths[2] = {1, 1};
MPI_Aint displacements[2] = {0, remdisp};
MPI_Datatype types[2] = {chunks, remainder};
MPI_Datatype newtype;
MPI_Type_create_struct(2, blocklengths, displacements, types, &newtype);

MPI_Type_free(&chunks);
MPI_Type_free(&remainder);

return newtype;
}

MPI_Request mpi_environment::isend(
void* address, size_t size, int rank, int tag)
{
MPI_Request request;
MPI_Datatype datatype;
int length;
if (size > static_cast<size_t>((std::numeric_limits<int>::max)()))
{
datatype = type_contiguous(size);
MPI_Type_commit(&datatype);
length = 1;
}
else
{
datatype = MPI_BYTE;
length = static_cast<int>(size);
}

{
scoped_lock l;
int const ret = MPI_Isend(
address, length, datatype, rank, tag, communicator(), &request);
check_mpi_error(l, HPX_CURRENT_SOURCE_LOCATION(), ret);
}

if (datatype != MPI_BYTE)
MPI_Type_free(&datatype);
return request;
}

MPI_Request mpi_environment::irecv(
void* address, size_t size, int rank, int tag)
{
MPI_Request request;
MPI_Datatype datatype;
int length;
if (size > static_cast<size_t>((std::numeric_limits<int>::max)()))
{
datatype = type_contiguous(size);
MPI_Type_commit(&datatype);
length = 1;
}
else
{
datatype = MPI_BYTE;
length = static_cast<int>(size);
}

{
scoped_lock l;
int const ret = MPI_Irecv(
address, length, datatype, rank, tag, communicator(), &request);
check_mpi_error(l, HPX_CURRENT_SOURCE_LOCATION(), ret);
}

if (datatype != MPI_BYTE)
MPI_Type_free(&datatype);

return request;
}
} // namespace hpx::util

#endif
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// Copyright (c) 2014-2015 Thomas Heller
// Copyright (c) 2007-2024 Hartmut Kaiser
// Copyright (c) 2023 Jiakun Yan
// Copyright (c) 2023-2024 Jiakun Yan
//
// SPDX-License-Identifier: BSL-1.0
// Distributed under the Boost Software License, Version 1.0. (See accompanying
Expand Down Expand Up @@ -163,14 +163,11 @@ namespace hpx::parcelset::policies::mpi {
{
util::mpi_environment::scoped_lock l;

int const ret = MPI_Irecv(buffer_.transmission_chunks_.data(),
static_cast<int>(buffer_.transmission_chunks_.size() *
sizeof(buffer_type::transmission_chunk_type)),
MPI_BYTE, src_, tag_, util::mpi_environment::communicator(),
&request_);
util::mpi_environment::check_mpi_error(
l, HPX_CURRENT_SOURCE_LOCATION(), ret);

request_ = util::mpi_environment::irecv(
buffer_.transmission_chunks_.data(),
buffer_.transmission_chunks_.size() *
sizeof(buffer_type::transmission_chunk_type),
src_, tag_);
request_ptr_ = &request_;

state_ = connection_state::rcvd_transmission_chunks;
Expand Down Expand Up @@ -207,12 +204,8 @@ namespace hpx::parcelset::policies::mpi {

ack_ = static_cast<char>(
connection_state::acked_transmission_chunks);
int const ret =
MPI_Isend(&ack_, sizeof(ack_), MPI_BYTE, src_, ack_tag(),
util::mpi_environment::communicator(), &request_);
util::mpi_environment::check_mpi_error(
l, HPX_CURRENT_SOURCE_LOCATION(), ret);

request_ = util::mpi_environment::isend(
&ack_, sizeof(ack_), src_, ack_tag());
request_ptr_ = &request_;
}

Expand Down Expand Up @@ -241,14 +234,8 @@ namespace hpx::parcelset::policies::mpi {

if (need_recv_data)
{
util::mpi_environment::scoped_lock l;

int const ret = MPI_Irecv(buffer_.data_.data(),
static_cast<int>(buffer_.data_.size()), MPI_BYTE, src_,
tag_, util::mpi_environment::communicator(), &request_);
util::mpi_environment::check_mpi_error(
l, HPX_CURRENT_SOURCE_LOCATION(), ret);

request_ = util::mpi_environment::irecv(
buffer_.data_.data(), buffer_.data_.size(), src_, tag_);
request_ptr_ = &request_;

state_ = connection_state::rcvd_data;
Expand Down Expand Up @@ -276,15 +263,8 @@ namespace hpx::parcelset::policies::mpi {
HPX_ASSERT(request_ptr_ == nullptr);

{
util::mpi_environment::scoped_lock l;

ack_ = static_cast<char>(connection_state::acked_data);
int const ret =
MPI_Isend(&ack_, sizeof(ack_), MPI_BYTE, src_, ack_tag(),
util::mpi_environment::communicator(), &request_);
util::mpi_environment::check_mpi_error(
l, HPX_CURRENT_SOURCE_LOCATION(), ret);

request_ = util::mpi_environment::isend(
&ack_, sizeof(ack_), src_, ack_tag());
request_ptr_ = &request_;
}

Expand Down Expand Up @@ -372,17 +352,9 @@ namespace hpx::parcelset::policies::mpi {
"zero-copy chunk buffers should have been initialized "
"during de-serialization");

{
util::mpi_environment::scoped_lock l;

int const ret = MPI_Irecv(c.data(),
static_cast<int>(chunk_size), MPI_BYTE, src_, tag_,
util::mpi_environment::communicator(), &request_);
util::mpi_environment::check_mpi_error(
l, HPX_CURRENT_SOURCE_LOCATION(), ret);

request_ptr_ = &request_;
}
request_ = util::mpi_environment::irecv(
c.data(), chunk_size, src_, tag_);
request_ptr_ = &request_;
}
HPX_ASSERT_MSG(
zero_copy_chunks_idx_ == buffer_.num_chunks_.first,
Expand Down Expand Up @@ -412,14 +384,8 @@ namespace hpx::parcelset::policies::mpi {
c.data(), chunk_size);

{
util::mpi_environment::scoped_lock l;

int const ret = MPI_Irecv(c.data(),
static_cast<int>(c.size()), MPI_BYTE, src_, tag_,
util::mpi_environment::communicator(), &request_);
util::mpi_environment::check_mpi_error(
l, HPX_CURRENT_SOURCE_LOCATION(), ret);

request_ = util::mpi_environment::irecv(
c.data(), c.size(), src_, tag_);
request_ptr_ = &request_;
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// Copyright (c) 2007-2024 Hartmut Kaiser
// Copyright (c) 2014-2015 Thomas Heller
// Copyright (c) 2023 Jiakun Yan
// Copyright (c) 2023-2024 Jiakun Yan
//
// SPDX-License-Identifier: BSL-1.0
// Distributed under the Boost Software License, Version 1.0. (See accompanying
Expand Down Expand Up @@ -177,17 +177,9 @@ namespace hpx::parcelset::policies::mpi {
HPX_ASSERT(state_ == connection_state::initialized);
HPX_ASSERT(request_ptr_ == nullptr);

{
util::mpi_environment::scoped_lock l;

int const ret = MPI_Isend(header_buffer.data(),
static_cast<int>(header_buffer.size()), MPI_BYTE, dst_, 0,
util::mpi_environment::communicator(), &request_);
util::mpi_environment::check_mpi_error(
l, HPX_CURRENT_SOURCE_LOCATION(), ret);

request_ptr_ = &request_;
}
request_ = util::mpi_environment::isend(
header_buffer.data(), header_buffer.size(), dst_, 0);
request_ptr_ = &request_;

state_ = connection_state::sent_header;
return send_transmission_chunks();
Expand All @@ -206,16 +198,10 @@ namespace hpx::parcelset::policies::mpi {
auto const& chunks = buffer_.transmission_chunks_;
if (!chunks.empty() && !header_.piggy_back_tchunk())
{
util::mpi_environment::scoped_lock l;

int const ret = MPI_Isend(chunks.data(),
static_cast<int>(chunks.size() *
sizeof(parcel_buffer_type::transmission_chunk_type)),
MPI_BYTE, dst_, tag_, util::mpi_environment::communicator(),
&request_);
util::mpi_environment::check_mpi_error(
l, HPX_CURRENT_SOURCE_LOCATION(), ret);

request_ = util::mpi_environment::isend(
const_cast<void*>(
reinterpret_cast<const void*>(chunks.data())),
chunks.size(), dst_, tag_);
request_ptr_ = &request_;

state_ = connection_state::sent_transmission_chunks;
Expand Down Expand Up @@ -250,14 +236,8 @@ namespace hpx::parcelset::policies::mpi {
HPX_ASSERT(request_ptr_ == nullptr);

{
util::mpi_environment::scoped_lock l;

int const ret =
MPI_Irecv(&ack_, sizeof(ack_), MPI_BYTE, dst_, ack_tag(),
util::mpi_environment::communicator(), &request_);
util::mpi_environment::check_mpi_error(
l, HPX_CURRENT_SOURCE_LOCATION(), ret);

request_ = util::mpi_environment::irecv(
&ack_, sizeof(ack_), dst_, ack_tag());
request_ptr_ = &request_;
}

Expand All @@ -283,12 +263,8 @@ namespace hpx::parcelset::policies::mpi {
{
util::mpi_environment::scoped_lock l;

int const ret = MPI_Isend(buffer_.data_.data(),
static_cast<int>(buffer_.data_.size()), MPI_BYTE, dst_,
tag_, util::mpi_environment::communicator(), &request_);
util::mpi_environment::check_mpi_error(
l, HPX_CURRENT_SOURCE_LOCATION(), ret);

request_ = util::mpi_environment::isend(
buffer_.data_.data(), buffer_.data_.size(), dst_, tag_);
request_ptr_ = &request_;

state_ = connection_state::sent_data;
Expand Down Expand Up @@ -321,14 +297,8 @@ namespace hpx::parcelset::policies::mpi {
HPX_ASSERT(request_ptr_ == nullptr);

{
util::mpi_environment::scoped_lock l;

int const ret =
MPI_Irecv(&ack_, sizeof(ack_), MPI_BYTE, dst_, ack_tag(),
util::mpi_environment::communicator(), &request_);
util::mpi_environment::check_mpi_error(
l, HPX_CURRENT_SOURCE_LOCATION(), ret);

request_ = util::mpi_environment::irecv(
&ack_, sizeof(ack_), dst_, ack_tag());
request_ptr_ = &request_;
}

Expand All @@ -352,15 +322,8 @@ namespace hpx::parcelset::policies::mpi {
return false;
}
HPX_ASSERT(request_ptr_ == nullptr);

util::mpi_environment::scoped_lock l;

int const ret = MPI_Isend(c.data_.cpos_,
static_cast<int>(c.size_), MPI_BYTE, dst_, tag_,
util::mpi_environment::communicator(), &request_);
util::mpi_environment::check_mpi_error(
l, HPX_CURRENT_SOURCE_LOCATION(), ret);

request_ = util::mpi_environment::isend(
const_cast<void*>(c.data()), c.size(), dst_, tag_);
request_ptr_ = &request_;
}

Expand Down

0 comments on commit 646f848

Please sign in to comment.