Skip to content

Commit

Permalink
Added maximum gridDim.y overflow check before calling transposeNoOverl…
Browse files Browse the repository at this point in the history
…ap kernel so that TransposeBigMLFloat16 test passes
  • Loading branch information
sstamenk committed Nov 15, 2024
1 parent d906a82 commit 0e45a53
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 1 deletion.
19 changes: 19 additions & 0 deletions onnxruntime/core/providers/rocm/fpgeneric.cu
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,25 @@ __global__ void CopyVectorBFloat16(const onnxruntime::BFloat16* x, int incx, onn

} // namespace

// Builds the launch grid used for an m x n transpose split into TRANS_TILE_DIM-sized tiles.
// For non-negative m and n the result is
//   x = ceil(n / TRANS_TILE_DIM), y = ceil(m / TRANS_TILE_DIM), z = 1.
dim3 rocblasTransposeHelperDimGrid(int m, int n) {
  // Do the round-up in 64-bit: with plain int, (m + TRANS_TILE_DIM - 1) overflows (undefined
  // behaviour) once m or n is within TRANS_TILE_DIM - 1 of INT_MAX.
  const long long tile = TRANS_TILE_DIM;
  const long long tilesAlongN = (n + tile - 1) / tile;
  const long long tilesAlongM = (m + tile - 1) / tile;
  // The explicit narrowing matches the implicit int -> unsigned conversion the original relied on.
  return dim3(static_cast<unsigned int>(tilesAlongN), static_cast<unsigned int>(tilesAlongM), 1);
}

// Reports whether the half-precision rocblasTransposeHelper may be used for an m x n transpose.
// It may only be used when the y extent of the grid returned by
// rocblasTransposeHelperDimGrid(m, n) stays below maxGridSize[1] of the current device.
// Returns false when the current device or its properties cannot be queried.
__host__ bool CanUse_rocblasTransposeHelper_MLFloat16(int m, int n) {
  int deviceId = 0;
  if (hipGetDevice(&deviceId) != hipSuccess) return false;

  hipDeviceProp_t deviceProp;
  if (hipGetDeviceProperties(&deviceProp, deviceId) != hipSuccess) return false;

  const dim3 dimGrid = rocblasTransposeHelperDimGrid(m, n);
  // dimGrid.y is unsigned while maxGridSize[] holds ints; the cast spells out the conversion the
  // comparison already performed implicitly. The strict '<' bound is kept unchanged.
  return dimGrid.y < static_cast<unsigned int>(deviceProp.maxGridSize[1]);
}

rocblas_status rocblasTransposeHelper(hipStream_t stream, rocblas_handle, rocblas_operation , rocblas_operation , int m, int n, const half*, const half* A, int, const half*, const half*, int, half* C, int) {
if (C != A) {
dim3 dimGrid((n + TRANS_TILE_DIM - 1) / TRANS_TILE_DIM, (m + TRANS_TILE_DIM - 1) / TRANS_TILE_DIM, 1);
Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/core/providers/rocm/shared_inc/fpgeneric.h
Original file line number Diff line number Diff line change
Expand Up @@ -470,7 +470,7 @@ inline rocblas_status rocblasTransposeHelper(hipStream_t /*stream*/, rocblas_han
return rocblas_dgeam(handle, transa, transb, m, n, alpha, A, lda, beta, B, ldb, C, ldc);
}

// Header-only variant: ignores both dimensions and always reports the helper as usable.
// CUDA has a limited grid size of 65536, ROCm has higher limits.
inline bool CanUse_rocblasTransposeHelper_MLFloat16(int /*m*/, int /*n*/) {
  return true;
}
bool CanUse_rocblasTransposeHelper_MLFloat16(int m, int n);
rocblas_status rocblasTransposeHelper(hipStream_t stream, rocblas_handle, rocblas_operation, rocblas_operation, int m, int n, const half*, const half* A, int, const half*, const half*, int, half* C, int);

// copy
Expand Down

0 comments on commit 0e45a53

Please sign in to comment.