NOTE(review): this patch was re-broken into lines from a newline-stripped copy.
  - Indentation of context/removed lines is reconstructed (2-space style assumed);
    confirm against the target tree before applying.
  - The whitespace-only change in rocblasCopyHelper is assumed to be a 32 -> 33
    space continuation indent; the original widths were not recoverable.
  - The kernel launch configuration of CopyVectorBFloat16 was stripped to "<<>>";
    "<<<dimGrid, dimBlock, 0, stream>>>" is presumed - verify.
  - The post-image blob id on the fpgeneric.cu "index" line predates the edits
    made to the added lines below and will not match the new content.

diff --git a/onnxruntime/core/providers/rocm/fpgeneric.cu b/onnxruntime/core/providers/rocm/fpgeneric.cu
index d130758bec084..2e0fd57144380 100644
--- a/onnxruntime/core/providers/rocm/fpgeneric.cu
+++ b/onnxruntime/core/providers/rocm/fpgeneric.cu
@@ -53,7 +53,33 @@ __global__ void CopyVectorBFloat16(const onnxruntime::BFloat16* x, int incx, onn
 
 }  // namespace
 
-rocblas_status rocblasTransposeHelper(hipStream_t stream, rocblas_handle, rocblas_operation , rocblas_operation , int m, int n, const half*, const half* A, int, const half*, const half*, int, half* C, int) {
+// Launch grid for the half-precision transpose of an m x n matrix:
+// ceil(n / TRANS_TILE_DIM) blocks in x and ceil(m / TRANS_TILE_DIM) blocks in y.
+dim3 rocblasTransposeHelperDimGrid(int m, int n) {
+  return dim3((n + TRANS_TILE_DIM - 1) / TRANS_TILE_DIM, (m + TRANS_TILE_DIM - 1) / TRANS_TILE_DIM, 1);
+}
+
+// rocblasTransposeHelper can only be used if it won't overflow the maxGridSize y dimension size.
+// Returns true when the y extent of the launch grid for an m x n transpose is
+// below the current device's maximum grid y dimension. Returns false when the
+// current device or its grid limit cannot be queried.
+__host__ bool CanUse_rocblasTransposeHelper_MLFloat16(int m, int n) {
+  int deviceId = 0;
+  if (hipGetDevice(&deviceId) != hipSuccess) return false;
+
+  // Only the y limit is needed, so query that single attribute rather than
+  // filling a whole hipDeviceProp_t on every call.
+  int maxGridDimY = 0;
+  if (hipDeviceGetAttribute(&maxGridDimY, hipDeviceAttributeMaxGridDimY, deviceId) != hipSuccess) return false;
+  if (maxGridDimY <= 0) return false;
+
+  // dim3 members are unsigned; convert the (positive) limit explicitly so the
+  // comparison is not a silent signed/unsigned mix.
+  const dim3 dimGrid = rocblasTransposeHelperDimGrid(m, n);
+  return dimGrid.y < static_cast<unsigned int>(maxGridDimY);
+}
+
+rocblas_status rocblasTransposeHelper(hipStream_t stream, rocblas_handle, rocblas_operation, rocblas_operation, int m, int n, const half*, const half* A, int, const half*, const half*, int, half* C, int) {
   if (C != A) {
     dim3 dimGrid((n + TRANS_TILE_DIM - 1) / TRANS_TILE_DIM, (m + TRANS_TILE_DIM - 1) / TRANS_TILE_DIM, 1);
     dim3 dimBlock(TRANS_TILE_DIM, BLOCK_ROWS, 1);
@@ -73,7 +99,7 @@ rocblas_status rocblasCopyHelper(hipStream_t stream, rocblas_handle, int n, cons
 }
 
 rocblas_status rocblasCopyHelper(hipStream_t stream, rocblas_handle, int n, const onnxruntime::BFloat16* x, int incx,
-                                onnxruntime::BFloat16* y, int incy) {
+                                 onnxruntime::BFloat16* y, int incy) {
   dim3 dimGrid((unsigned int)(n + COPY_BLOCK_DIM - 1) / COPY_BLOCK_DIM, 1, 1);
   dim3 dimBlock(COPY_BLOCK_DIM, 1, 1);
   CopyVectorBFloat16<<<dimGrid, dimBlock, 0, stream>>>(x, incx, y, incy, n);
diff --git a/onnxruntime/core/providers/rocm/shared_inc/fpgeneric.h b/onnxruntime/core/providers/rocm/shared_inc/fpgeneric.h
index d93f70785c093..c165158f7e461 100644
--- a/onnxruntime/core/providers/rocm/shared_inc/fpgeneric.h
+++ b/onnxruntime/core/providers/rocm/shared_inc/fpgeneric.h
@@ -470,7 +470,7 @@ inline rocblas_status rocblasTransposeHelper(hipStream_t /*stream*/, rocblas_han
   return rocblas_dgeam(handle, transa, transb, m, n, alpha, A, lda, beta, B, ldb, C, ldc);
 }
 
-inline bool CanUse_rocblasTransposeHelper_MLFloat16(int /*m*/, int /*n*/) { return true; }  // CUDA has a limited grid size of 65536, ROCm has higher limits.
+bool CanUse_rocblasTransposeHelper_MLFloat16(int m, int n);
 rocblas_status rocblasTransposeHelper(hipStream_t stream, rocblas_handle, rocblas_operation, rocblas_operation, int m, int n, const half*, const half* A, int, const half*, const half*, int, half* C, int);
 
 // copy