@@ -1366,6 +1366,28 @@ __global__ void __launch_bounds__(MAX_THREADS)
1366
1366
cfg.attrs = attribute_ub; \
1367
1367
cfg.numAttrs = comm->sm_arch >= 9 ? 2 : 1 ;
1368
1368
1369
+ #if (CUDART_VERSION >= 12030)
1370
+ #define ADD_LAUNCH_COMPLETION_EVENT (attribute_ub, comm_launch_event ) \
1371
+ attribute_ub[2 ].id = cudaLaunchAttributeLaunchCompletionEvent; \
1372
+ attribute_ub[2 ].val.launchCompletionEvent.event = comm_launch_event;
1373
+ #define NUM_LAUNCH_ATTRIBUTE_FOR_FDL_LAUNCH 3
1374
+ #else
1375
+ #define ADD_LAUNCH_COMPLETION_EVENT (attribute_ub, comm_launch_event )
1376
+ #define NUM_LAUNCH_ATTRIBUTE_FOR_FDL_LAUNCH 2
1377
+ #endif
1378
+
1379
+ #define SETUP_LAUNCH_CONFIG_WITH_COMPLETION_EVENT (sms, threads, stream, comm_launch_event ) \
1380
+ cudaLaunchConfig_t cfg = {sms, threads, 0 , stream, NULL , 0 }; \
1381
+ cudaLaunchAttribute attribute_ub[NUM_LAUNCH_ATTRIBUTE_FOR_FDL_LAUNCH] = {}; \
1382
+ ADD_LAUNCH_COMPLETION_EVENT (attribute_ub, comm_launch_event) \
1383
+ attribute_ub[1].id = cudaLaunchAttributeClusterDimension; \
1384
+ attribute_ub[1 ].val.clusterDim.x = sms % comm->cga_size == 0 ? comm->cga_size : 1 ; \
1385
+ attribute_ub[1 ].val.clusterDim.y = 1 ; \
1386
+ attribute_ub[1 ].val.clusterDim.z = 1 ; \
1387
+ attribute_ub[0 ].id = cudaLaunchAttributeCooperative; \
1388
+ cfg.attrs = attribute_ub; \
1389
+ cfg.numAttrs = NUM_LAUNCH_ATTRIBUTE_FOR_FDL_LAUNCH;
1390
+
1369
1391
#define callranks_ag (x ) \
1370
1392
if (ar_nvsize == x) { \
1371
1393
int arg1 = op - NVTE_MAX_OPS, \
@@ -1753,7 +1775,8 @@ void reducescatter2_userbuff_strided_multiatomic(void *output, const int handler
1753
1775
}
1754
1776
1755
1777
void allgather2_userbuff_inplace (const int handler, const int offset, const int elements,
1756
- communicator *comm, cudaStream_t stream) {
1778
+ communicator *comm, cudaStream_t stream,
1779
+ cudaEvent_t comm_launch_event) {
1757
1780
const int op = userbuffers_allreduceop_nonsharp2;
1758
1781
const int ar_firstgpu =
1759
1782
op == userbuffers_allreduceop_nonsharp ? comm->ar_firstgpu : comm->ar2_firstgpu ;
@@ -1766,11 +1789,20 @@ void allgather2_userbuff_inplace(const int handler, const int offset, const int
1766
1789
int warps = comm->threads / 32 ;
1767
1790
if (warps < ar_nvsize) warps = ar_nvsize;
1768
1791
1769
- SETUP_LAUNCH_CONFIG (sms, warps * 32 , stream);
1770
- if (comm->use_mc && (comm->memflags [handler] & UB_MEM_MC_CREATED)) {
1771
- callranks_agMC (2 ) callranks_agMC (4 ) callranks_agMC (8 )
1792
+ if (comm_launch_event) {
1793
+ SETUP_LAUNCH_CONFIG_WITH_COMPLETION_EVENT (sms, warps * 32 , stream, comm_launch_event);
1794
+ if (comm->use_mc && (comm->memflags [handler] & UB_MEM_MC_CREATED)) {
1795
+ callranks_agMC (2 ) callranks_agMC (4 ) callranks_agMC (8 )
1796
+ } else {
1797
+ callranks_ag (2 ) callranks_ag (4 ) callranks_ag (8 )
1798
+ }
1772
1799
} else {
1773
- callranks_ag (2 ) callranks_ag (4 ) callranks_ag (8 )
1800
+ SETUP_LAUNCH_CONFIG (sms, warps * 32 , stream);
1801
+ if (comm->use_mc && (comm->memflags [handler] & UB_MEM_MC_CREATED)) {
1802
+ callranks_agMC (2 ) callranks_agMC (4 ) callranks_agMC (8 )
1803
+ } else {
1804
+ callranks_ag (2 ) callranks_ag (4 ) callranks_ag (8 )
1805
+ }
1774
1806
}
1775
1807
}
1776
1808
@@ -1790,7 +1822,8 @@ void allgather2_userbuff_inplace_sliced(const int handler, const int offset, con
1790
1822
}
1791
1823
1792
1824
void reducescatter2_userbuff_inplace (const int handler, const int offset, const int elements,
1793
- communicator *comm, cudaStream_t stream) {
1825
+ communicator *comm, cudaStream_t stream,
1826
+ cudaEvent_t comm_launch_event) {
1794
1827
const int op = userbuffers_allreduceop_nonsharp2;
1795
1828
const int ar_firstgpu =
1796
1829
op == userbuffers_allreduceop_nonsharp ? comm->ar_firstgpu : comm->ar2_firstgpu ;
@@ -1803,17 +1836,26 @@ void reducescatter2_userbuff_inplace(const int handler, const int offset, const
1803
1836
int warps = comm->threads / 32 ;
1804
1837
if (warps < ar_nvsize) warps = ar_nvsize;
1805
1838
1806
- SETUP_LAUNCH_CONFIG (sms, warps * 32 , stream);
1807
- if (comm->use_mc && (comm->memflags [handler] & UB_MEM_MC_CREATED)) {
1808
- callranks_rsMC (2 ) callranks_rsMC (4 ) callranks_rsMC (8 )
1839
+ if (comm_launch_event) {
1840
+ SETUP_LAUNCH_CONFIG_WITH_COMPLETION_EVENT (sms, warps * 32 , stream, comm_launch_event);
1841
+ if (comm->use_mc && (comm->memflags [handler] & UB_MEM_MC_CREATED)) {
1842
+ callranks_rsMC (2 ) callranks_rsMC (4 ) callranks_rsMC (8 )
1843
+ } else {
1844
+ callranks_rs (2 ) callranks_rs (4 ) callranks_rs (8 )
1845
+ }
1809
1846
} else {
1810
- callranks_rs (2 ) callranks_rs (4 ) callranks_rs (8 )
1847
+ SETUP_LAUNCH_CONFIG (sms, warps * 32 , stream);
1848
+ if (comm->use_mc && (comm->memflags [handler] & UB_MEM_MC_CREATED)) {
1849
+ callranks_rsMC (2 ) callranks_rsMC (4 ) callranks_rsMC (8 )
1850
+ } else {
1851
+ callranks_rs (2 ) callranks_rs (4 ) callranks_rs (8 )
1852
+ }
1811
1853
}
1812
1854
}
1813
1855
void reducescatter2_userbuff_stridedoutput (void *output, const int handler, const int offset,
1814
1856
const int rowelements, const int colelements,
1815
1857
const int strideelements, communicator *comm,
1816
- cudaStream_t stream) {
1858
+ cudaStream_t stream, cudaEvent_t comm_launch_event ) {
1817
1859
const int elements = rowelements * colelements;
1818
1860
const int op = userbuffers_allreduceop_nonsharp2;
1819
1861
const int ar_firstgpu =
@@ -1827,23 +1869,35 @@ void reducescatter2_userbuff_stridedoutput(void *output, const int handler, cons
1827
1869
int warps = comm->threads / 32 ;
1828
1870
if (warps < ar_nvsize) warps = ar_nvsize;
1829
1871
1830
- SETUP_LAUNCH_CONFIG (sms, warps * 32 , stream);
1831
- if (comm->use_mc && (comm->memflags [handler] & UB_MEM_MC_CREATED)) {
1832
- callranks_rs_oopMC (2 ) callranks_rs_oopMC (4 ) callranks_rs_oopMC (8 )
1872
+ if (comm_launch_event) {
1873
+ SETUP_LAUNCH_CONFIG_WITH_COMPLETION_EVENT (sms, warps * 32 , stream, comm_launch_event);
1874
+ if (comm->use_mc && (comm->memflags [handler] & UB_MEM_MC_CREATED)) {
1875
+ callranks_rs_oopMC (2 ) callranks_rs_oopMC (4 ) callranks_rs_oopMC (8 )
1876
+ } else {
1877
+ callranks_rs_oop (2 ) callranks_rs_oop (4 ) callranks_rs_oop (8 )
1878
+ }
1833
1879
} else {
1834
- callranks_rs_oop (2 ) callranks_rs_oop (4 ) callranks_rs_oop (8 )
1880
+ SETUP_LAUNCH_CONFIG (sms, warps * 32 , stream);
1881
+ if (comm->use_mc && (comm->memflags [handler] & UB_MEM_MC_CREATED)) {
1882
+ callranks_rs_oopMC (2 ) callranks_rs_oopMC (4 ) callranks_rs_oopMC (8 )
1883
+ } else {
1884
+ callranks_rs_oop (2 ) callranks_rs_oop (4 ) callranks_rs_oop (8 )
1885
+ }
1835
1886
}
1836
1887
}
1837
1888
void reducescatter2_userbuff (void *output, const int handler, const int offset, const int elements,
1838
- communicator *comm, cudaStream_t stream) {
1839
- reducescatter2_userbuff_stridedoutput (output, handler, offset, elements, 1 , 0 , comm, stream);
1889
+ communicator *comm, cudaStream_t stream,
1890
+ cudaEvent_t comm_launch_event) {
1891
+ reducescatter2_userbuff_stridedoutput (output, handler, offset, elements, 1 , 0 , comm, stream,
1892
+ comm_launch_event);
1840
1893
}
1841
1894
1842
1895
template <typename fp8type>
1843
1896
void reducescatter2_userbuff_stridedoutput_fp8 (void *output, float *scale, const int handler,
1844
1897
const int offset, const int rowelements,
1845
1898
const int colelements, const int strideelements,
1846
- communicator *comm, cudaStream_t stream) {
1899
+ communicator *comm, cudaStream_t stream,
1900
+ cudaEvent_t comm_launch_event) {
1847
1901
const int elements = rowelements * colelements;
1848
1902
const int op = userbuffers_allreduceop_nonsharp2;
1849
1903
const int ar_firstgpu =
@@ -1857,33 +1911,43 @@ void reducescatter2_userbuff_stridedoutput_fp8(void *output, float *scale, const
1857
1911
int warps = comm->threads / 32 ;
1858
1912
if (warps < ar_nvsize) warps = ar_nvsize;
1859
1913
1860
- SETUP_LAUNCH_CONFIG (sms, warps * 32 , stream);
1861
- callranks_rs_oop_fp8 (2 ) callranks_rs_oop_fp8 (4 ) callranks_rs_oop_fp8 (8 )
1914
+ if (comm_launch_event) {
1915
+ SETUP_LAUNCH_CONFIG_WITH_COMPLETION_EVENT (sms, warps * 32 , stream, comm_launch_event);
1916
+ callranks_rs_oop_fp8 (2 ) callranks_rs_oop_fp8 (4 ) callranks_rs_oop_fp8 (8 )
1917
+ } else {
1918
+ SETUP_LAUNCH_CONFIG (sms, warps * 32 , stream);
1919
+ callranks_rs_oop_fp8 (2 ) callranks_rs_oop_fp8 (4 ) callranks_rs_oop_fp8 (8 )
1920
+ }
1862
1921
}
1863
1922
1864
1923
template void reducescatter2_userbuff_stridedoutput_fp8<__nv_fp8_e5m2>(
1865
1924
void *output, float *scale, const int handler, const int offset, const int rowelements,
1866
- const int colelements, const int strideelements, communicator *comm, cudaStream_t stream);
1925
+ const int colelements, const int strideelements, communicator *comm, cudaStream_t stream,
1926
+ cudaEvent_t comm_launch_event);
1867
1927
1868
1928
template void reducescatter2_userbuff_stridedoutput_fp8<__nv_fp8_e4m3>(
1869
1929
void *output, float *scale, const int handler, const int offset, const int rowelements,
1870
- const int colelements, const int strideelements, communicator *comm, cudaStream_t stream);
1930
+ const int colelements, const int strideelements, communicator *comm, cudaStream_t stream,
1931
+ cudaEvent_t comm_launch_event);
1871
1932
1872
1933
template <typename fp8type>
1873
1934
void reducescatter2_userbuff_fp8 (void *output, float *scale, const int handler, const int offset,
1874
- const int elements, communicator *comm, cudaStream_t stream) {
1935
+ const int elements, communicator *comm, cudaStream_t stream,
1936
+ cudaEvent_t comm_launch_event) {
1875
1937
reducescatter2_userbuff_stridedoutput_fp8<fp8type>(output, scale, handler, offset, elements, 1 , 0 ,
1876
- comm, stream);
1938
+ comm, stream, comm_launch_event );
1877
1939
}
1878
1940
1879
1941
template void reducescatter2_userbuff_fp8<__nv_fp8_e5m2>(void *output, float *scale,
1880
1942
const int handler, const int offset,
1881
1943
const int elements, communicator *comm,
1882
- cudaStream_t stream);
1944
+ cudaStream_t stream,
1945
+ cudaEvent_t comm_launch_event);
1883
1946
template void reducescatter2_userbuff_fp8<__nv_fp8_e4m3>(void *output, float *scale,
1884
1947
const int handler, const int offset,
1885
1948
const int elements, communicator *comm,
1886
- cudaStream_t stream);
1949
+ cudaStream_t stream,
1950
+ cudaEvent_t comm_launch_event);
1887
1951
1888
1952
template void reducescatter2_userbuff_strided_atomic_fp8<__nv_fp8_e4m3>(
1889
1953
void *output, float *scale, const int handler, const int offset, const int rowelements,
0 commit comments