Skip to content

Commit 1797f14

Browse files
committed
fixed GEMM dimensions based on operand transpose flags in TP overlap
Signed-off-by: Alp Dener <[email protected]>
1 parent 5505867 commit 1797f14

File tree

1 file changed

+8
-8
lines changed

1 file changed

+8
-8
lines changed

transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -232,9 +232,9 @@ void CommOverlapBase::atomic_gemm_overlap_rs(TensorWrapper &A, bool transa, Tens
232232
_ub_comm->sms = _num_comm_sm;
233233
_ub_comm->cga_size = _cga_size;
234234
// Get GEMM dimensions
235-
size_t m = A.size(0);
236-
size_t k = A.size(1);
237-
size_t n = B.size(0);
235+
size_t m = (transa) ? A.size(0) : A.size(1);
236+
size_t k = (transa) ? A.size(1) : A.size(0);
237+
size_t n = (transb) ? B.size(1) : B.size(0);
238238
size_t m_chunk = m / _num_splits;
239239
size_t workspace_size_chunk = workspace.numel() / _stream_compute.size();
240240

@@ -332,9 +332,9 @@ void CommOverlapBase::split_overlap_rs(TensorWrapper &A, bool transa, TensorWrap
332332
_ub_comm->use_ce = _use_ce;
333333
_ub_comm->sms = _num_comm_sm;
334334
_ub_comm->cga_size = _cga_size;
335-
size_t m = A.size(0);
336-
size_t k = A.size(1);
337-
size_t n = B.size(0);
335+
size_t m = (transa) ? A.size(0) : A.size(1);
336+
size_t k = (transa) ? A.size(1) : A.size(0);
337+
size_t n = (transb) ? B.size(1) : B.size(0);
338338
size_t m_chunk = m / _num_splits;
339339
size_t input_a_chunk_size = m_chunk * k;
340340
size_t output_chunk_size = n * m_chunk;
@@ -930,8 +930,8 @@ void CommOverlapP2PBase::split_overlap_rs(TensorWrapper &A, bool transa, TensorW
930930
_ub_comm->use_ce = _use_ce;
931931
_ub_comm->sms = _num_comm_sm;
932932
_ub_comm->cga_size = _cga_size;
933-
size_t k = A.size(1);
934-
size_t n = B.size(0);
933+
size_t k = (transa) ? A.size(1) : A.size(0);
934+
size_t n = (transb) ? B.size(1) : B.size(0);
935935

936936
// Get communication and GEMM input chunk sizes
937937
size_t n_chunk = n / _tp_size;

0 commit comments

Comments
 (0)