@@ -232,9 +232,9 @@ void CommOverlapBase::atomic_gemm_overlap_rs(TensorWrapper &A, bool transa, Tens
232
232
_ub_comm->sms = _num_comm_sm;
233
233
_ub_comm->cga_size = _cga_size;
234
234
// Get GEMM dimensions
235
- size_t m = A.size (0 );
236
- size_t k = A.size (1 );
237
- size_t n = B.size (0 );
235
+ size_t m = (transa) ? A.size (0 ) : A. size ( 1 );
236
+ size_t k = (transa) ? A.size (1 ) : A. size ( 0 );
237
+ size_t n = (transb) ? B. size ( 1 ) : B.size (0 );
238
238
size_t m_chunk = m / _num_splits;
239
239
size_t workspace_size_chunk = workspace.numel () / _stream_compute.size ();
240
240
@@ -332,9 +332,9 @@ void CommOverlapBase::split_overlap_rs(TensorWrapper &A, bool transa, TensorWrap
332
332
_ub_comm->use_ce = _use_ce;
333
333
_ub_comm->sms = _num_comm_sm;
334
334
_ub_comm->cga_size = _cga_size;
335
- size_t m = A.size (0 );
336
- size_t k = A.size (1 );
337
- size_t n = B.size (0 );
335
+ size_t m = (transa) ? A.size (0 ) : A. size ( 1 );
336
+ size_t k = (transa) ? A.size (1 ) : A. size ( 0 );
337
+ size_t n = (transb) ? B. size ( 1 ) : B.size (0 );
338
338
size_t m_chunk = m / _num_splits;
339
339
size_t input_a_chunk_size = m_chunk * k;
340
340
size_t output_chunk_size = n * m_chunk;
@@ -930,8 +930,8 @@ void CommOverlapP2PBase::split_overlap_rs(TensorWrapper &A, bool transa, TensorW
930
930
_ub_comm->use_ce = _use_ce;
931
931
_ub_comm->sms = _num_comm_sm;
932
932
_ub_comm->cga_size = _cga_size;
933
- size_t k = A.size (1 );
934
- size_t n = B.size (0 );
933
+ size_t k = (transa) ? A.size (1 ) : A. size ( 0 );
934
+ size_t n = (transb) ? B. size ( 1 ) : B.size (0 );
935
935
936
936
// Get communication and GEMM input chunk sizes
937
937
size_t n_chunk = n / _tp_size;
0 commit comments