Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ucx process bootstrap #598

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions cpp/src/cylon/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ if (CYLON_UCX)
net/ucx/ucx_communicator.cpp
net/ucx/ucx_operations.hpp
net/ucx/ucx_operations.cpp
net/ucx/ucx_ucc_oob_contexts.hpp
net/ucx/ucx_ucc_oob_contexts.cpp
)
if (CYLON_UCC)
set(UCC_CYLON_FILES
Expand Down Expand Up @@ -232,6 +234,20 @@ if (CYLON_UCX)
if (CYLON_UCC)
target_link_libraries(cylon ucc)
endif ()
# <------------ add hiredis dependency --------------->
# Find the directory that contains an entry named `hiredis` and add it to
# cylon's PUBLIC include directories, so targets linking cylon inherit it too.
# NOTE(review): the result is not validated; when the lookup fails CMake sets
# the variable to HIREDIS_HEADER-NOTFOUND -- consider checking before use.
find_path(HIREDIS_HEADER hiredis)
target_include_directories(cylon PUBLIC ${HIREDIS_HEADER})

# Find the hiredis library itself and link it into cylon.
# NOTE(review): same caveat as above -- HIREDIS_LIB is used without a found-check.
find_library(HIREDIS_LIB hiredis)
target_link_libraries(cylon ${HIREDIS_LIB})

# <------------ add redis-plus-plus dependency -------------->
# NOTE: this should be *sw* NOT *redis++*
# (presumably because redis-plus-plus installs its headers under `sw/redis++/`,
# so the include root is the directory that contains `sw` -- TODO confirm)
find_path(REDIS_PLUS_PLUS_HEADER sw)
target_include_directories(cylon PUBLIC ${REDIS_PLUS_PLUS_HEADER})

# The library, unlike the header directory, is looked up by the name `redis++`.
# NOTE(review): REDIS_PLUS_PLUS_LIB is also used without a found-check.
find_library(REDIS_PLUS_PLUS_LIB redis++)
target_link_libraries(cylon ${REDIS_PLUS_PLUS_LIB})
endif ()

if (CYLON_GLOO)
Expand Down
15 changes: 15 additions & 0 deletions cpp/src/cylon/ctx/cylon_context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ Status CylonContext::InitDistributed(const std::shared_ptr<cylon::net::CommConfi
*ctx = std::make_shared<CylonContext>(true);
auto pool = (*ctx)->GetMemoryPool();
#ifdef BUILD_CYLON_UCC
// use UCC if we can
return net::UCXUCCCommunicator::Make(config, pool, &(*ctx)->communicator);
#else
return net::UCXCommunicator::Make(config, pool, &(*ctx)->communicator);
Expand All @@ -63,6 +64,20 @@ Status CylonContext::InitDistributed(const std::shared_ptr<cylon::net::CommConfi
#endif
}

case net::UCC: {
#ifdef BUILD_CYLON_UCX
#ifdef BUILD_CYLON_UCC
*ctx = std::make_shared<CylonContext>(true);
auto pool = (*ctx)->GetMemoryPool();
return net::UCXUCCCommunicator::Make(config, pool, &(*ctx)->communicator);
#else
return {Code::NotImplemented, "UCX communication not implemented"};
#endif
#else
return {Code::NotImplemented, "UCX communication not implemented"};
#endif
}

case net::TCP:return {Code::NotImplemented, "TCP communication not implemented"};

case net::GLOO: {
Expand Down
3 changes: 2 additions & 1 deletion cpp/src/cylon/net/comm_type.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@ enum CommType {
MPI = 1,
TCP = 2,
UCX = 3,
GLOO = 4
GLOO = 4,
UCC = 5
};
} // namespace net
} // namespace cylon
Expand Down
45 changes: 42 additions & 3 deletions cpp/src/cylon/net/ucc/ucc_operations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

#include "cylon/net/ucc/ucc_operations.hpp"
#include "cylon/util/macros.hpp"
#include "cylon/net/utils.hpp"

namespace cylon {
namespace ucc {
Expand Down Expand Up @@ -251,6 +252,8 @@ UccTableGatherImpl::UccTableGatherImpl(ucc_team_h ucc_team,
void UccTableGatherImpl::Init(int32_t num_buffers) {
this->requests_.resize(num_buffers);
this->args_.resize(num_buffers);
this->displacements_ = new std::vector<std::vector<int>>(num_buffers);
this->all_recv_counts_ = new std::vector<std::vector<int>>(num_buffers);
Comment on lines +255 to +256
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think you need to malloc for a vector object. You can simply do this:

this->displacements_ = std::vector<std::vector<int>>(num_buffers);
this->all_recv_counts_ = std::vector<std::vector<int>>(num_buffers);

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh I think I understand why you did this. This is because of the const in GatherBufferSizes method, right?

}

Status UccTableGatherImpl::GatherBufferSizes(const int32_t *send_data, int32_t num_buffers,
Expand All @@ -259,17 +262,25 @@ Status UccTableGatherImpl::GatherBufferSizes(const int32_t *send_data, int32_t n
ucc_coll_req_h req;

args.mask = 0;
args.coll_type = UCC_COLL_TYPE_GATHER;
args.coll_type = UCC_COLL_TYPE_ALLGATHER;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Put a TODO and refer to the GitHub issue.
Add a comment explaining why we are using allgather.

args.root = gather_root;

args.src.info.buffer = const_cast<int32_t *>(send_data);
args.src.info.count = static_cast<uint64_t>(num_buffers);
args.src.info.datatype = UCC_DT_INT32;
args.src.info.mem_type = UCC_MEMORY_TYPE_HOST;

int32_t total_sz = num_buffers * world_size;
std::vector<int32_t> all_buffer_sizes;
if(rank == gather_root) {
args.dst.info.buffer = rcv_data;
args.dst.info.count = num_buffers * world_size;
args.dst.info.count = total_sz;
args.dst.info.datatype = UCC_DT_INT32;
args.dst.info.mem_type = UCC_MEMORY_TYPE_HOST;
} else {
all_buffer_sizes.resize(total_sz);
args.dst.info.buffer = all_buffer_sizes.data();
args.dst.info.count = total_sz;
args.dst.info.datatype = UCC_DT_INT32;
args.dst.info.mem_type = UCC_MEMORY_TYPE_HOST;
}
Comment on lines 275 to 286
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we do this?

  if(rank == gather_root) {
    args.dst.info.buffer = rcv_data;
  } else {
    all_buffer_sizes.resize(total_sz);
    args.dst.info.buffer = all_buffer_sizes.data();
  }
args.dst.info.count = total_sz;
args.dst.info.datatype = UCC_DT_INT32;
args.dst.info.mem_type = UCC_MEMORY_TYPE_HOST;

Expand All @@ -287,6 +298,15 @@ Status UccTableGatherImpl::GatherBufferSizes(const int32_t *send_data, int32_t n
}

RETURN_CYLON_STATUS_IF_UCC_FAILED(ucc_collective_finalize(req));

if(rank != gather_root) {
for (int32_t i = 0; i < num_buffers; ++i) {
(*all_recv_counts_)[i] = cylon::net::receiveCounts(all_buffer_sizes, i,
num_buffers, world_size);
(*displacements_)[i] = std::move(cylon::net::displacementsPerBuffer(all_buffer_sizes, i,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

don't think you need the move here.

num_buffers, world_size));
}
}
return Status::OK();
}

Expand All @@ -297,7 +317,7 @@ Status UccTableGatherImpl::IgatherBufferData(
ucc_coll_args_t &args = args_[buf_idx];

args.mask = 0;
args.coll_type = UCC_COLL_TYPE_GATHERV;
args.coll_type = UCC_COLL_TYPE_ALLGATHERV;
args.root = gather_root;

args.src.info.buffer = const_cast<uint8_t *>(send_data);
Expand All @@ -313,6 +333,19 @@ Status UccTableGatherImpl::IgatherBufferData(
(ucc_aint_t *)displacements.data();
args.dst.info_v.datatype = UCC_DT_UINT8;
args.dst.info_v.mem_type = UCC_MEMORY_TYPE_HOST;
} else {
int sum = 0;
auto& recv_counts_ = (*all_recv_counts_)[buf_idx];
for(auto count: recv_counts_) {
sum += count;
}
Comment on lines +337 to +341
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

use std::accumulate

recv_data_placeholder = new uint8_t[sum];
args.dst.info_v.buffer = recv_data_placeholder;

args.dst.info_v.counts = (ucc_count_t *)(*all_recv_counts_)[buf_idx].data();
args.dst.info_v.displacements = (ucc_aint_t *)(*displacements_)[buf_idx].data();
args.dst.info_v.datatype = UCC_DT_UINT8;
args.dst.info_v.mem_type = UCC_MEMORY_TYPE_HOST;
Comment on lines +345 to +348
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I actually think we don't need these! We just need to pass the recv_data_placeholder buffer and make sure everyone (other than the root) copies data to this dummy buffer. So you might not even need to track all_recv_counts_ and displacements_.

}

RETURN_CYLON_STATUS_IF_UCC_FAILED(
Expand All @@ -327,6 +360,12 @@ Status UccTableGatherImpl::WaitAll(int32_t num_buffers) {
return WaitAllHelper(requests_, ucc_context_);
}

// Releases the helper state used to mimic gather with allgather: the two
// heap-allocated count/displacement tables and the scratch receive buffer.
//
// recv_data_placeholder is assigned from `new uint8_t[sum]` in
// IgatherBufferData, so it must be released with the array form `delete[]`;
// pairing an array allocation with scalar `delete` is undefined behaviour.
//
// NOTE(review): IgatherBufferData assigns a fresh allocation to
// recv_data_placeholder on each call that takes that branch, so only the most
// recent buffer appears to be released here -- confirm earlier ones are not
// leaked. The three pointers should also start out as nullptr so that
// destroying an object on which Init()/IgatherBufferData() never ran is safe
// -- TODO confirm in the constructor/header.
UccTableGatherImpl::~UccTableGatherImpl() {
  delete displacements_;
  delete all_recv_counts_;
  delete[] recv_data_placeholder;
}

UccTableBcastImpl::UccTableBcastImpl(ucc_team_h ucc_team, ucc_context_h ucc_context)
: ucc_team_(ucc_team), ucc_context_(ucc_context) {}

Expand Down
7 changes: 6 additions & 1 deletion cpp/src/cylon/net/ucc/ucc_operations.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ class UccTableGatherImpl : public net::TableGatherImpl {
UccTableGatherImpl(ucc_team_h ucc_team, ucc_context_h ucc_context, int rank,
int world_size);

~UccTableGatherImpl() override = default;
~UccTableGatherImpl() override;

void Init(int32_t num_buffers) override;

Expand All @@ -70,6 +70,11 @@ class UccTableGatherImpl : public net::TableGatherImpl {
Status WaitAll(int32_t num_buffers) override;

private:
// the following three are to mimic gather using allgather
std::vector<std::vector<int>>* displacements_;
std::vector<std::vector<int>>* all_recv_counts_;
Comment on lines +74 to +75
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have a feeling that you might not even need these. I've explained it in a previous comment.

uint8_t* recv_data_placeholder;

std::vector<ucc_coll_req_h> requests_;
std::vector<ucc_coll_args_t> args_;
ucc_team_h ucc_team_;
Expand Down
Loading