-
Notifications
You must be signed in to change notification settings - Fork 47
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Ucx process bootstrap #598
base: main
Are you sure you want to change the base?
Changes from all commits
fbadb97
fbff6d5
c09d345
70937a7
a9b71ee
4c580f1
7494fed
6cefdcc
19d17b7
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -14,6 +14,7 @@ | |
|
||
#include "cylon/net/ucc/ucc_operations.hpp" | ||
#include "cylon/util/macros.hpp" | ||
#include "cylon/net/utils.hpp" | ||
|
||
namespace cylon { | ||
namespace ucc { | ||
|
@@ -251,6 +252,8 @@ UccTableGatherImpl::UccTableGatherImpl(ucc_team_h ucc_team, | |
void UccTableGatherImpl::Init(int32_t num_buffers) { | ||
this->requests_.resize(num_buffers); | ||
this->args_.resize(num_buffers); | ||
this->displacements_ = new std::vector<std::vector<int>>(num_buffers); | ||
this->all_recv_counts_ = new std::vector<std::vector<int>>(num_buffers); | ||
} | ||
|
||
Status UccTableGatherImpl::GatherBufferSizes(const int32_t *send_data, int32_t num_buffers, | ||
|
@@ -259,17 +262,25 @@ Status UccTableGatherImpl::GatherBufferSizes(const int32_t *send_data, int32_t n | |
ucc_coll_req_h req; | ||
|
||
args.mask = 0; | ||
args.coll_type = UCC_COLL_TYPE_GATHER; | ||
args.coll_type = UCC_COLL_TYPE_ALLGATHER; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. put a |
||
args.root = gather_root; | ||
|
||
args.src.info.buffer = const_cast<int32_t *>(send_data); | ||
args.src.info.count = static_cast<uint64_t>(num_buffers); | ||
args.src.info.datatype = UCC_DT_INT32; | ||
args.src.info.mem_type = UCC_MEMORY_TYPE_HOST; | ||
|
||
int32_t total_sz = num_buffers * world_size; | ||
std::vector<int32_t> all_buffer_sizes; | ||
if(rank == gather_root) { | ||
args.dst.info.buffer = rcv_data; | ||
args.dst.info.count = num_buffers * world_size; | ||
args.dst.info.count = total_sz; | ||
args.dst.info.datatype = UCC_DT_INT32; | ||
args.dst.info.mem_type = UCC_MEMORY_TYPE_HOST; | ||
} else { | ||
all_buffer_sizes.resize(total_sz); | ||
args.dst.info.buffer = all_buffer_sizes.data(); | ||
args.dst.info.count = total_sz; | ||
args.dst.info.datatype = UCC_DT_INT32; | ||
args.dst.info.mem_type = UCC_MEMORY_TYPE_HOST; | ||
} | ||
Comment on lines
275
to
286
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. can we do this? if(rank == gather_root) {
args.dst.info.buffer = rcv_data;
} else {
all_buffer_sizes.resize(total_sz);
args.dst.info.buffer = all_buffer_sizes.data();
}
args.dst.info.count = total_sz;
args.dst.info.datatype = UCC_DT_INT32;
args.dst.info.mem_type = UCC_MEMORY_TYPE_HOST; |
||
|
@@ -287,6 +298,15 @@ Status UccTableGatherImpl::GatherBufferSizes(const int32_t *send_data, int32_t n | |
} | ||
|
||
RETURN_CYLON_STATUS_IF_UCC_FAILED(ucc_collective_finalize(req)); | ||
|
||
if(rank != gather_root) { | ||
for (int32_t i = 0; i < num_buffers; ++i) { | ||
(*all_recv_counts_)[i] = cylon::net::receiveCounts(all_buffer_sizes, i, | ||
num_buffers, world_size); | ||
(*displacements_)[i] = std::move(cylon::net::displacementsPerBuffer(all_buffer_sizes, i, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. don't think you need the |
||
num_buffers, world_size)); | ||
} | ||
} | ||
return Status::OK(); | ||
} | ||
|
||
|
@@ -297,7 +317,7 @@ Status UccTableGatherImpl::IgatherBufferData( | |
ucc_coll_args_t &args = args_[buf_idx]; | ||
|
||
args.mask = 0; | ||
args.coll_type = UCC_COLL_TYPE_GATHERV; | ||
args.coll_type = UCC_COLL_TYPE_ALLGATHERV; | ||
args.root = gather_root; | ||
|
||
args.src.info.buffer = const_cast<uint8_t *>(send_data); | ||
|
@@ -313,6 +333,19 @@ Status UccTableGatherImpl::IgatherBufferData( | |
(ucc_aint_t *)displacements.data(); | ||
args.dst.info_v.datatype = UCC_DT_UINT8; | ||
args.dst.info_v.mem_type = UCC_MEMORY_TYPE_HOST; | ||
} else { | ||
int sum = 0; | ||
auto& recv_counts_ = (*all_recv_counts_)[buf_idx]; | ||
for(auto count: recv_counts_) { | ||
sum += count; | ||
} | ||
Comment on lines
+337
to
+341
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. use |
||
recv_data_placeholder = new uint8_t[sum]; | ||
args.dst.info_v.buffer = recv_data_placeholder; | ||
|
||
args.dst.info_v.counts = (ucc_count_t *)(*all_recv_counts_)[buf_idx].data(); | ||
args.dst.info_v.displacements = (ucc_aint_t *)(*displacements_)[buf_idx].data(); | ||
args.dst.info_v.datatype = UCC_DT_UINT8; | ||
args.dst.info_v.mem_type = UCC_MEMORY_TYPE_HOST; | ||
Comment on lines
+345
to
+348
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I actually think we dont need these! we just need to pass |
||
} | ||
|
||
RETURN_CYLON_STATUS_IF_UCC_FAILED( | ||
|
@@ -327,6 +360,12 @@ Status UccTableGatherImpl::WaitAll(int32_t num_buffers) { | |
return WaitAllHelper(requests_, ucc_context_); | ||
} | ||
|
||
UccTableGatherImpl::~UccTableGatherImpl() { | ||
delete displacements_; | ||
delete all_recv_counts_; | ||
delete recv_data_placeholder; | ||
} | ||
|
||
UccTableBcastImpl::UccTableBcastImpl(ucc_team_h ucc_team, ucc_context_h ucc_context) | ||
: ucc_team_(ucc_team), ucc_context_(ucc_context) {} | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -53,7 +53,7 @@ class UccTableGatherImpl : public net::TableGatherImpl { | |
UccTableGatherImpl(ucc_team_h ucc_team, ucc_context_h ucc_context, int rank, | ||
int world_size); | ||
|
||
~UccTableGatherImpl() override = default; | ||
~UccTableGatherImpl() override; | ||
|
||
void Init(int32_t num_buffers) override; | ||
|
||
|
@@ -70,6 +70,11 @@ class UccTableGatherImpl : public net::TableGatherImpl { | |
Status WaitAll(int32_t num_buffers) override; | ||
|
||
private: | ||
// the following three are to mimic gather using allgather | ||
std::vector<std::vector<int>>* displacements_; | ||
std::vector<std::vector<int>>* all_recv_counts_; | ||
Comment on lines
+74
to
+75
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I have a feeling that you might not even need these. I've explained it in a previous comment |
||
uint8_t* recv_data_placeholder; | ||
|
||
std::vector<ucc_coll_req_h> requests_; | ||
std::vector<ucc_coll_args_t> args_; | ||
ucc_team_h ucc_team_; | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I dont think you need to malloc for a vector object. You can simply do this,
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Oh I think I understand why you did this. This is because of the
const
inGatherBufferSizes
method, right?