
Commit 8f599ce

youngeunkwon0405 authored and pre-commit-ci[bot] committed
Improving communication overlap for the case of multi kernel queue usage (#1308)
* draft implementation
* compile error fix
* fix compile error
* remove print
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* Edit comments
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* edit the bulk-overlap test case
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* add version guard
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* add runtime version guard
* fix the version guard

Signed-off-by: Youngeun Kwon <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 1d1c3a6 commit 8f599ce

File tree

5 files changed: +157 -43 lines


tests/pytorch/distributed/test_comm_gemm_overlap.py

Lines changed: 27 additions & 7 deletions
@@ -209,19 +209,39 @@ def test_atomic_gemm_overlaps(ag_type, rs_type, p2p, fp8_out):
 
 
 @pytest.mark.parametrize(
-    "comm_type,fp8",
+    "comm_type, fp8, connections",
     [
-        ("AG", False),
-        ("RS", False),
-        ("RS", True),
+        ("AG", False, 1),
+        ("RS", False, 1),
+        ("RS", True, 1),
+        ("AG", False, 8),
+        ("RS", False, 8),
+        ("RS", True, 8),
+    ],
+    ids=[
+        "ALL-GATHER - BF16 - 1 connections",
+        "REDUCE-SCATTER - BF16 - 1 connections",
+        "REDUCE-SCATTER - FP8 - 1 connections",
+        "ALL-GATHER - BF16 - 8 connections",
+        "REDUCE-SCATTER - BF16 - 8 connections",
+        "REDUCE-SCATTER - FP8 - 8 connections",
     ],
-    ids=[" ALL-GATHER - BF16 ", " REDUCE-SCATTER - BF16 ", " REDUCE-SCATTER - FP8 "],
 )
-def test_bulk_overlaps(comm_type, fp8):
+def test_bulk_overlaps(comm_type, fp8, connections):
     """
     Test bulk overlaps with direct calls to te.cpp_extensions.gemm or te.cpp_extensions.fp8_gemm.
     """
-    _run_gemm_with_overlap(comm_type, True, False, False, fp8, False, False)
+    if connections == 8:
+        if torch.cuda.get_device_properties(0).major != 9:
+            pytest.skip(
+                "CUDA_DEVICE_MAX_CONNECTIONS=8 test only applies to devices with compute capability"
+                " 9.0 (HOPPER ARCH)."
+            )
+        os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "8"
+        _run_gemm_with_overlap(comm_type, True, False, False, fp8, False, False)
+        os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1"
+    else:
+        _run_gemm_with_overlap(comm_type, True, False, False, fp8, False, False)
 
 
 @pytest.mark.parametrize(
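The 8-connection variants exercise the launch-ordering path added by this commit, which is gated on Hopper (compute capability 9.0) and CUDA 12.3+. The same eligibility check can be probed standalone with standard CUDA runtime calls; the following is a minimal sketch, not part of this commit, where the fallback of 8 queues mirrors the getenv default used in the constructor change below:

// eligibility_probe.cu: minimal sketch (not part of this commit) that reports
// whether the multi-queue launch-ordering path would be enabled on device 0.
#include <cstdio>
#include <cstdlib>
#include <cuda_runtime.h>

int main() {
  int runtime_version = 0;
  cudaRuntimeGetVersion(&runtime_version);  // 12030 corresponds to CUDA 12.3

  cudaDeviceProp prop;
  cudaGetDeviceProperties(&prop, 0);

  // A default of 8 mirrors the getenv<int>(..., 8) fallback in CommOverlapCore.
  const char *env = std::getenv("CUDA_DEVICE_MAX_CONNECTIONS");
  const int max_connections = env ? std::atoi(env) : 8;

  const bool eligible = runtime_version >= 12030 && prop.major == 9 && max_connections > 1;
  std::printf("runtime=%d sm_%d%d connections=%d -> launch ordering %s\n", runtime_version,
              prop.major, prop.minor, max_connections, eligible ? "enabled" : "disabled");
  return 0;
}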

transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp

Lines changed: 27 additions & 3 deletions
@@ -90,13 +90,31 @@ CommOverlapCore::CommOverlapCore(int myrank, int numranks, int mylocal, int numl
   cudaEventCreateWithFlags(&_stop_compute, 0);
   cudaEventCreateWithFlags(&_start_comm, 0);
   cudaEventCreateWithFlags(&_stop_comm, 0);
+
+  /*
+    Defining the launcher order between the communication and GEMM kernels
+    using Fast Dependent Launch when CUDA_DEVICE_MAX_CONNECTIONS>1.
+    The event is used to schedule the communication kernel before the GEMM.
+    This is needed only for Hopper, which uses persistent CTA execution.
+  */
+  int max_connection = transformer_engine::getenv<int>("CUDA_DEVICE_MAX_CONNECTIONS", 8);
+  int runtime_version = 0;
+  cudaRuntimeGetVersion(&runtime_version);
+  cudaDeviceProp deviceProp;
+  cudaGetDeviceProperties(&deviceProp, 0);
+  if (runtime_version >= 12030 && deviceProp.major == 9 && max_connection > 1) {
+    cudaEventCreateWithFlags(&_comm_launch_event, cudaEventDisableTiming);
+  } else {
+    _comm_launch_event = 0;
+  }
 }
 
 CommOverlapCore::~CommOverlapCore() {
   cudaEventDestroy(_stop_comm);
   cudaEventDestroy(_start_comm);
   cudaEventDestroy(_stop_compute);
   cudaEventDestroy(_start_compute);
+  if (_comm_launch_event) cudaEventDestroy(_comm_launch_event);
 
   if (_atomic_gemm) cudaFree(_counter.dptr());
 
@@ -168,7 +186,8 @@ void CommOverlapBase::bulk_overlap(TensorWrapper &A, bool transa, TensorWrapper
   // Communication: AG and RS
   int comm_elements = (_ubuf.numel() / 2) * _ubuf.element_size();  // UBUF uses 2Byte element size
   if (comm_type == CommOverlapType::AG) {
-    allgather2_userbuff_inplace(_ub_reg, 0, comm_elements, _ub_comm, _stream_comm);
+    allgather2_userbuff_inplace(_ub_reg, 0, comm_elements, _ub_comm, _stream_comm,
+                                (cudaEvent_t)_comm_launch_event);
   } else {
     if (_ubuf.element_size() == 1) {
       assert(_ubuf_scale_inv_initialized);
@@ -178,13 +197,18 @@ void CommOverlapBase::bulk_overlap(TensorWrapper &A, bool transa, TensorWrapper
       assert(rs_output.element_size() == 2);
       char *rs_output_ptr = reinterpret_cast<char *>(rs_output.dptr());
       reducescatter2_userbuff_fp8<__nv_fp8_e5m2>(rs_output_ptr, _ubuf_scale_inv, _ub_reg, 0,
-                                                 comm_elements, _ub_comm, _stream_comm);
+                                                 comm_elements, _ub_comm, _stream_comm,
+                                                 (cudaEvent_t)_comm_launch_event);
     } else {
-      reducescatter2_userbuff_inplace(_ub_reg, 0, comm_elements, _ub_comm, _stream_comm);
+      reducescatter2_userbuff_inplace(_ub_reg, 0, comm_elements, _ub_comm, _stream_comm,
+                                      (cudaEvent_t)_comm_launch_event);
     }
   }
 
   assert(pre_gelu_out.numel() == 0);
+  // When the kernel launch order is defined, enforce the GEMM kernel launch to wait for the
+  // communication kernel launch.
+  if (_comm_launch_event)
+    NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _comm_launch_event, 0));
   nvte_cublas_gemm(A.data(), B.data(), D.data(), bias.data(), pre_gelu_out.data(), transa, transb,
                    grad, workspace.data(), accumulate, use_split_accumulator, _math_sms,
                    stream_main);
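The ordering mechanism here is plain CUDA events: the communication kernel is launched with a launch-completion event attribute (see the userbuffers.cu changes below), and stream_main waits on that event, so the GEMM cannot enter a hardware queue ahead of the communication kernel. Below is a self-contained sketch of the same pattern, assuming CUDA 12.3+, with dummy kernels standing in for the userbuffers communication kernel and the cuBLAS GEMM:

// launch_order_demo.cu: sketch of ordering two streams by launch completion,
// not by kernel completion (dummy kernels; error checking omitted for brevity).
#include <cstdio>
#include <cuda_runtime.h>

__global__ void comm_kernel(int *flag) { *flag = 1; }
__global__ void gemm_kernel(const int *flag, int *out) { *out = *flag; }

int main() {
  int *flag, *out;
  cudaMalloc(&flag, sizeof(int));
  cudaMalloc(&out, sizeof(int));
  cudaMemset(flag, 0, sizeof(int));

  cudaStream_t stream_comm, stream_main;
  cudaStreamCreate(&stream_comm);
  cudaStreamCreate(&stream_main);

  // Like _comm_launch_event: ordering only, no timing.
  cudaEvent_t comm_launch_event;
  cudaEventCreateWithFlags(&comm_launch_event, cudaEventDisableTiming);

  // Launch the "comm" kernel with a launch-completion event. The event is
  // recorded once all blocks of comm_kernel have begun executing, not when
  // the kernel finishes.
  cudaLaunchConfig_t cfg = {};
  cfg.gridDim = 1;
  cfg.blockDim = 1;
  cfg.stream = stream_comm;
  cudaLaunchAttribute attr[1];
  attr[0].id = cudaLaunchAttributeLaunchCompletionEvent;
  attr[0].val.launchCompletionEvent.event = comm_launch_event;
  cfg.attrs = attr;
  cfg.numAttrs = 1;
  cudaLaunchKernelEx(&cfg, comm_kernel, flag);

  // The "GEMM" stream waits on the launch event, so its kernel is issued to
  // the hardware queues after the comm kernel, even when
  // CUDA_DEVICE_MAX_CONNECTIONS > 1 provides several independent queues.
  cudaStreamWaitEvent(stream_main, comm_launch_event, 0);
  gemm_kernel<<<1, 1, 0, stream_main>>>(flag, out);

  // Note: only the launch order is constrained; the two kernels may still run
  // concurrently, so out can legitimately read 0 or 1 here.
  cudaDeviceSynchronize();
  int host_out = 0;
  cudaMemcpy(&host_out, out, sizeof(int), cudaMemcpyDeviceToHost);
  std::printf("out = %d\n", host_out);
  return 0;
}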

transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.cu

Lines changed: 90 additions & 26 deletions
@@ -1366,6 +1366,28 @@ __global__ void __launch_bounds__(MAX_THREADS)
   cfg.attrs = attribute_ub; \
   cfg.numAttrs = comm->sm_arch >= 9 ? 2 : 1;
 
+#if (CUDART_VERSION >= 12030)
+#define ADD_LAUNCH_COMPLETION_EVENT(attribute_ub, comm_launch_event)   \
+  attribute_ub[2].id = cudaLaunchAttributeLaunchCompletionEvent;       \
+  attribute_ub[2].val.launchCompletionEvent.event = comm_launch_event;
+#define NUM_LAUNCH_ATTRIBUTE_FOR_FDL_LAUNCH 3
+#else
+#define ADD_LAUNCH_COMPLETION_EVENT(attribute_ub, comm_launch_event)
+#define NUM_LAUNCH_ATTRIBUTE_FOR_FDL_LAUNCH 2
+#endif
+
+#define SETUP_LAUNCH_CONFIG_WITH_COMPLETION_EVENT(sms, threads, stream, comm_launch_event) \
+  cudaLaunchConfig_t cfg = {sms, threads, 0, stream, NULL, 0};                             \
+  cudaLaunchAttribute attribute_ub[NUM_LAUNCH_ATTRIBUTE_FOR_FDL_LAUNCH] = {};              \
+  ADD_LAUNCH_COMPLETION_EVENT(attribute_ub, comm_launch_event)                             \
+  attribute_ub[1].id = cudaLaunchAttributeClusterDimension;                                \
+  attribute_ub[1].val.clusterDim.x = sms % comm->cga_size == 0 ? comm->cga_size : 1;       \
+  attribute_ub[1].val.clusterDim.y = 1;                                                    \
+  attribute_ub[1].val.clusterDim.z = 1;                                                    \
+  attribute_ub[0].id = cudaLaunchAttributeCooperative;                                     \
+  cfg.attrs = attribute_ub;                                                                \
+  cfg.numAttrs = NUM_LAUNCH_ATTRIBUTE_FOR_FDL_LAUNCH;
+
 #define callranks_ag(x) \
   if (ar_nvsize == x) { \
     int arg1 = op - NVTE_MAX_OPS, \
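Since the macro chain is dense, here is a hand-expanded form of SETUP_LAUNCH_CONFIG_WITH_COMPLETION_EVENT(sms, threads, stream, comm_launch_event) for the CUDART_VERSION >= 12030 case, shown purely as an illustration (the macros above are the authoritative definition):

// Expansion sketch for CUDA 12.3+, attribute indices as in the macro:
cudaLaunchConfig_t cfg = {sms, threads, 0, stream, NULL, 0};
cudaLaunchAttribute attribute_ub[3] = {};
// [2] Launch-completion event: recorded once the kernel launch has completed.
attribute_ub[2].id = cudaLaunchAttributeLaunchCompletionEvent;
attribute_ub[2].val.launchCompletionEvent.event = comm_launch_event;
// [1] Thread-block cluster shape, identical to the pre-existing SETUP_LAUNCH_CONFIG.
attribute_ub[1].id = cudaLaunchAttributeClusterDimension;
attribute_ub[1].val.clusterDim.x = sms % comm->cga_size == 0 ? comm->cga_size : 1;
attribute_ub[1].val.clusterDim.y = 1;
attribute_ub[1].val.clusterDim.z = 1;
// [0] Cooperative launch, identical to the pre-existing SETUP_LAUNCH_CONFIG.
attribute_ub[0].id = cudaLaunchAttributeCooperative;
cfg.attrs = attribute_ub;
cfg.numAttrs = 3;  // NUM_LAUNCH_ATTRIBUTE_FOR_FDL_LAUNCH
// On older toolkits the event attribute is compiled out and numAttrs is 2.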
@@ -1753,7 +1775,8 @@ void reducescatter2_userbuff_strided_multiatomic(void *output, const int handler
 }
 
 void allgather2_userbuff_inplace(const int handler, const int offset, const int elements,
-                                 communicator *comm, cudaStream_t stream) {
+                                 communicator *comm, cudaStream_t stream,
+                                 cudaEvent_t comm_launch_event) {
   const int op = userbuffers_allreduceop_nonsharp2;
   const int ar_firstgpu =
       op == userbuffers_allreduceop_nonsharp ? comm->ar_firstgpu : comm->ar2_firstgpu;
@@ -1766,11 +1789,20 @@ void allgather2_userbuff_inplace(const int handler, const int offset, const int
   int warps = comm->threads / 32;
   if (warps < ar_nvsize) warps = ar_nvsize;
 
-  SETUP_LAUNCH_CONFIG(sms, warps * 32, stream);
-  if (comm->use_mc && (comm->memflags[handler] & UB_MEM_MC_CREATED)) {
-    callranks_agMC(2) callranks_agMC(4) callranks_agMC(8)
+  if (comm_launch_event) {
+    SETUP_LAUNCH_CONFIG_WITH_COMPLETION_EVENT(sms, warps * 32, stream, comm_launch_event);
+    if (comm->use_mc && (comm->memflags[handler] & UB_MEM_MC_CREATED)) {
+      callranks_agMC(2) callranks_agMC(4) callranks_agMC(8)
+    } else {
+      callranks_ag(2) callranks_ag(4) callranks_ag(8)
+    }
   } else {
-    callranks_ag(2) callranks_ag(4) callranks_ag(8)
+    SETUP_LAUNCH_CONFIG(sms, warps * 32, stream);
+    if (comm->use_mc && (comm->memflags[handler] & UB_MEM_MC_CREATED)) {
+      callranks_agMC(2) callranks_agMC(4) callranks_agMC(8)
+    } else {
+      callranks_ag(2) callranks_ag(4) callranks_ag(8)
+    }
   }
 }
 
@@ -1790,7 +1822,8 @@ void allgather2_userbuff_inplace_sliced(const int handler, const int offset, con
 }
 
 void reducescatter2_userbuff_inplace(const int handler, const int offset, const int elements,
-                                     communicator *comm, cudaStream_t stream) {
+                                     communicator *comm, cudaStream_t stream,
+                                     cudaEvent_t comm_launch_event) {
   const int op = userbuffers_allreduceop_nonsharp2;
   const int ar_firstgpu =
       op == userbuffers_allreduceop_nonsharp ? comm->ar_firstgpu : comm->ar2_firstgpu;
@@ -1803,17 +1836,26 @@ void reducescatter2_userbuff_inplace(const int handler, const int offset, const
   int warps = comm->threads / 32;
   if (warps < ar_nvsize) warps = ar_nvsize;
 
-  SETUP_LAUNCH_CONFIG(sms, warps * 32, stream);
-  if (comm->use_mc && (comm->memflags[handler] & UB_MEM_MC_CREATED)) {
-    callranks_rsMC(2) callranks_rsMC(4) callranks_rsMC(8)
+  if (comm_launch_event) {
+    SETUP_LAUNCH_CONFIG_WITH_COMPLETION_EVENT(sms, warps * 32, stream, comm_launch_event);
+    if (comm->use_mc && (comm->memflags[handler] & UB_MEM_MC_CREATED)) {
+      callranks_rsMC(2) callranks_rsMC(4) callranks_rsMC(8)
+    } else {
+      callranks_rs(2) callranks_rs(4) callranks_rs(8)
+    }
   } else {
-    callranks_rs(2) callranks_rs(4) callranks_rs(8)
+    SETUP_LAUNCH_CONFIG(sms, warps * 32, stream);
+    if (comm->use_mc && (comm->memflags[handler] & UB_MEM_MC_CREATED)) {
+      callranks_rsMC(2) callranks_rsMC(4) callranks_rsMC(8)
+    } else {
+      callranks_rs(2) callranks_rs(4) callranks_rs(8)
+    }
   }
 }
 void reducescatter2_userbuff_stridedoutput(void *output, const int handler, const int offset,
                                            const int rowelements, const int colelements,
                                            const int strideelements, communicator *comm,
-                                           cudaStream_t stream) {
+                                           cudaStream_t stream, cudaEvent_t comm_launch_event) {
   const int elements = rowelements * colelements;
   const int op = userbuffers_allreduceop_nonsharp2;
   const int ar_firstgpu =
@@ -1827,23 +1869,35 @@ void reducescatter2_userbuff_stridedoutput(void *output, const int handler, cons
   int warps = comm->threads / 32;
   if (warps < ar_nvsize) warps = ar_nvsize;
 
-  SETUP_LAUNCH_CONFIG(sms, warps * 32, stream);
-  if (comm->use_mc && (comm->memflags[handler] & UB_MEM_MC_CREATED)) {
-    callranks_rs_oopMC(2) callranks_rs_oopMC(4) callranks_rs_oopMC(8)
+  if (comm_launch_event) {
+    SETUP_LAUNCH_CONFIG_WITH_COMPLETION_EVENT(sms, warps * 32, stream, comm_launch_event);
+    if (comm->use_mc && (comm->memflags[handler] & UB_MEM_MC_CREATED)) {
+      callranks_rs_oopMC(2) callranks_rs_oopMC(4) callranks_rs_oopMC(8)
+    } else {
+      callranks_rs_oop(2) callranks_rs_oop(4) callranks_rs_oop(8)
+    }
   } else {
-    callranks_rs_oop(2) callranks_rs_oop(4) callranks_rs_oop(8)
+    SETUP_LAUNCH_CONFIG(sms, warps * 32, stream);
+    if (comm->use_mc && (comm->memflags[handler] & UB_MEM_MC_CREATED)) {
+      callranks_rs_oopMC(2) callranks_rs_oopMC(4) callranks_rs_oopMC(8)
+    } else {
+      callranks_rs_oop(2) callranks_rs_oop(4) callranks_rs_oop(8)
+    }
   }
 }
 void reducescatter2_userbuff(void *output, const int handler, const int offset, const int elements,
-                             communicator *comm, cudaStream_t stream) {
-  reducescatter2_userbuff_stridedoutput(output, handler, offset, elements, 1, 0, comm, stream);
+                             communicator *comm, cudaStream_t stream,
+                             cudaEvent_t comm_launch_event) {
+  reducescatter2_userbuff_stridedoutput(output, handler, offset, elements, 1, 0, comm, stream,
+                                        comm_launch_event);
 }
 
 template <typename fp8type>
 void reducescatter2_userbuff_stridedoutput_fp8(void *output, float *scale, const int handler,
                                                const int offset, const int rowelements,
                                                const int colelements, const int strideelements,
-                                               communicator *comm, cudaStream_t stream) {
+                                               communicator *comm, cudaStream_t stream,
+                                               cudaEvent_t comm_launch_event) {
   const int elements = rowelements * colelements;
   const int op = userbuffers_allreduceop_nonsharp2;
   const int ar_firstgpu =
@@ -1857,33 +1911,43 @@ void reducescatter2_userbuff_stridedoutput_fp8(void *output, float *scale, const
   int warps = comm->threads / 32;
   if (warps < ar_nvsize) warps = ar_nvsize;
 
-  SETUP_LAUNCH_CONFIG(sms, warps * 32, stream);
-  callranks_rs_oop_fp8(2) callranks_rs_oop_fp8(4) callranks_rs_oop_fp8(8)
+  if (comm_launch_event) {
+    SETUP_LAUNCH_CONFIG_WITH_COMPLETION_EVENT(sms, warps * 32, stream, comm_launch_event);
+    callranks_rs_oop_fp8(2) callranks_rs_oop_fp8(4) callranks_rs_oop_fp8(8)
+  } else {
+    SETUP_LAUNCH_CONFIG(sms, warps * 32, stream);
+    callranks_rs_oop_fp8(2) callranks_rs_oop_fp8(4) callranks_rs_oop_fp8(8)
+  }
 }
 
 template void reducescatter2_userbuff_stridedoutput_fp8<__nv_fp8_e5m2>(
     void *output, float *scale, const int handler, const int offset, const int rowelements,
-    const int colelements, const int strideelements, communicator *comm, cudaStream_t stream);
+    const int colelements, const int strideelements, communicator *comm, cudaStream_t stream,
+    cudaEvent_t comm_launch_event);
 
 template void reducescatter2_userbuff_stridedoutput_fp8<__nv_fp8_e4m3>(
     void *output, float *scale, const int handler, const int offset, const int rowelements,
-    const int colelements, const int strideelements, communicator *comm, cudaStream_t stream);
+    const int colelements, const int strideelements, communicator *comm, cudaStream_t stream,
+    cudaEvent_t comm_launch_event);
 
 template <typename fp8type>
 void reducescatter2_userbuff_fp8(void *output, float *scale, const int handler, const int offset,
-                                 const int elements, communicator *comm, cudaStream_t stream) {
+                                 const int elements, communicator *comm, cudaStream_t stream,
+                                 cudaEvent_t comm_launch_event) {
   reducescatter2_userbuff_stridedoutput_fp8<fp8type>(output, scale, handler, offset, elements, 1, 0,
-                                                     comm, stream);
+                                                     comm, stream, comm_launch_event);
 }
 
 template void reducescatter2_userbuff_fp8<__nv_fp8_e5m2>(void *output, float *scale,
                                                          const int handler, const int offset,
                                                          const int elements, communicator *comm,
-                                                         cudaStream_t stream);
+                                                         cudaStream_t stream,
+                                                         cudaEvent_t comm_launch_event);
 template void reducescatter2_userbuff_fp8<__nv_fp8_e4m3>(void *output, float *scale,
                                                          const int handler, const int offset,
                                                          const int elements, communicator *comm,
-                                                         cudaStream_t stream);
+                                                         cudaStream_t stream,
+                                                         cudaEvent_t comm_launch_event);
 
 template void reducescatter2_userbuff_strided_atomic_fp8<__nv_fp8_e4m3>(
     void *output, float *scale, const int handler, const int offset, const int rowelements,

transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.h

Lines changed: 12 additions & 6 deletions
@@ -213,7 +213,8 @@ int register_user_buffer_collective(void **gpubuff, size_t bytes, communicator *
 
 // for TP-parallelism, only single node is implemented
 void allgather2_userbuff_inplace(const int handler, const int offset, const int elements,
-                                 communicator *comm, cudaStream_t stream = 0);
+                                 communicator *comm, cudaStream_t stream = 0,
+                                 cudaEvent_t comm_launch_event = 0);
 /*
 each Rank input is
 allgather2_userbuff_inplace: offset+myrank*elements
@@ -228,21 +229,26 @@ for(int slice=0;slice<ncslices;slice++)
 allgather2_userbuff_inplace(hndl,offset, elements*nslices,comm,stream);
 */
 void reducescatter2_userbuff_inplace(const int handler, const int offset, const int elements,
-                                     communicator *comm, cudaStream_t stream = 0);
+                                     communicator *comm, cudaStream_t stream = 0,
+                                     cudaEvent_t comm_launch_event = 0);
 void reducescatter2_userbuff(void *output, const int handler, const int offset, const int elements,
-                             communicator *comm, cudaStream_t stream = 0);
+                             communicator *comm, cudaStream_t stream = 0,
+                             cudaEvent_t comm_launch_event = 0);
 void reducescatter2_userbuff_stridedoutput(void *output, const int handler, const int offset,
                                            const int rowelements, const int colelements,
                                            const int strideelements, communicator *comm,
-                                           cudaStream_t stream = 0);
+                                           cudaStream_t stream = 0,
+                                           cudaEvent_t comm_launch_event = 0);
 template <typename fp8type>
 void reducescatter2_userbuff_stridedoutput_fp8(void *output, float *scale, const int handler,
                                                const int offset, const int rowelements,
                                                const int colelements, const int strideelements,
-                                               communicator *comm, cudaStream_t stream = 0);
+                                               communicator *comm, cudaStream_t stream = 0,
+                                               cudaEvent_t comm_launch_event = 0);
 template <typename fp8type>
 void reducescatter2_userbuff_fp8(void *output, float *scale, const int handler, const int offset,
-                                 const int elements, communicator *comm, cudaStream_t stream = 0);
+                                 const int elements, communicator *comm, cudaStream_t stream = 0,
+                                 cudaEvent_t comm_launch_event = 0);
 template <typename fp8type>
 void reducescatter2_userbuff_strided_atomic_fp8(void *output, float *scale, const int handler,
                                                 const int offset, const int rowelements,
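Each new parameter defaults to 0, so existing call sites compile and behave exactly as before; only callers that pass an event (currently the bulk-overlap path in comm_gemm_overlap.cpp) opt in to launch ordering. A call-site sketch, assuming handler, offset, elements, comm, stream, and evt have been set up by the caller:

// Without an event: identical to the pre-commit behavior.
allgather2_userbuff_inplace(handler, offset, elements, comm, stream);
// With an event: evt is recorded when the all-gather launch completes, so a
// consumer stream can cudaStreamWaitEvent(consumer, evt, 0) before launching.
allgather2_userbuff_inplace(handler, offset, elements, comm, stream, evt);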

transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h

Lines changed: 1 addition & 1 deletion
@@ -62,7 +62,7 @@ class CommOverlapCore {
   bool _ubuf_scale_inv_initialized{false};
 
   std::vector<cudaStream_t> _stream_compute;
-  cudaEvent_t _start_compute, _stop_compute, _start_comm, _stop_comm;
+  cudaEvent_t _start_compute, _stop_compute, _start_comm, _stop_comm, _comm_launch_event;
 
  public:
  CommOverlapCore(int myrank, int numranks, int mylocal, int numlocal, int mynode, int numnodes,
