@@ -237,6 +237,7 @@ public:
237237 DataType type,
238238 int group0,
239239 int group1,
240+ int cp_size,
240241 const int * local_token_nums,
241242 cudaStream_t stream) override
242243 {
@@ -252,9 +253,8 @@ public:
252253 NCCLCHECK (ncclCommCount (comm0, &tp0));
253254 NCCLCHECK (ncclCommCount (comm1, &tp1));
254255
255- const int inner_tp = std::min (tp0, tp1);
256-
257- FT_CHECK (tp0 % inner_tp == 0 && tp1 % inner_tp == 0 );
256+ FT_CHECK (std::max (tp0, tp1) % std::min (tp0, tp1) == 0 );
257+ const int inner_tp = std::min (tp0, tp1) * cp_size;
258258
259259 std::vector<std::tuple<int , int , int >> tasks;
260260 tasks.reserve (global_n_ranks_);
@@ -289,7 +289,18 @@ public:
289289 sync_check_cuda_error ();
290290 }
291291
292- if (tp1 > 1 ) {
292+ if (cp_size > 1 && tp0 > tp1) {
293+ NCCLCHECK (ncclGroupStart ());
294+ for (int i = 0 ; i < global_n_ranks_; ++i) {
295+ if (auto & [offset, first, num] = tasks[i]; num > 0 ) {
296+ char * buff = (char *)hidden + elem_size * (offset + first) * dim;
297+ NCCLCHECK (ncclBroadcast (buff, buff, (size_t )num * dim, nccl_type, i % tp0, comm0, stream));
298+ }
299+ }
300+ NCCLCHECK (ncclGroupEnd ());
301+ sync_check_cuda_error ();
302+ }
303+ else if (tp1 > 1 ) {
293304 NCCLCHECK (ncclGroupStart ());
294305 for (int i = 0 ; i < global_n_ranks_; ++i) {
295306 if (auto & [offset, first, num] = tasks[i]; num > 0 ) {
0 commit comments