Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add safety check so that TransposeBigMLFloat16 test passes #77

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 21 additions & 2 deletions onnxruntime/core/providers/rocm/fpgeneric.cu
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,26 @@

} // namespace

rocblas_status rocblasTransposeHelper(hipStream_t stream, rocblas_handle, rocblas_operation , rocblas_operation , int m, int n, const half*, const half* A, int, const half*, const half*, int, half* C, int) {
dim3 rocblasTransposeHelperDimGrid(int m, int n) {
return dim3((n + TRANS_TILE_DIM - 1) / TRANS_TILE_DIM, (m + TRANS_TILE_DIM - 1) / TRANS_TILE_DIM, 1);
}

// rocblasTransposeHelper can only be used if it won't overflow the maxGridSize y dimension size
__host__ bool CanUse_rocblasTransposeHelper_MLFloat16(int m, int n) {
dim3 dimGrid = rocblasTransposeHelperDimGrid(m, n);

int deviceId;
hipError_t hipError = hipGetDevice(&deviceId);
if (hipError != 0) return false;

hipDeviceProp_t deviceProp;
hipError = hipGetDeviceProperties(&deviceProp, deviceId);
if (hipError != 0) return false;

return dimGrid.y < deviceProp.maxGridSize[1];
}

rocblas_status rocblasTransposeHelper(hipStream_t stream, rocblas_handle, rocblas_operation, rocblas_operation, int m, int n, const half*, const half* A, int, const half*, const half*, int, half* C, int) {

Check warning on line 75 in onnxruntime/core/providers/rocm/fpgeneric.cu

View workflow job for this annotation

GitHub Actions / cpplint

[cpplint] onnxruntime/core/providers/rocm/fpgeneric.cu#L75

Lines should be <= 120 characters long [whitespace/line_length] [2]
Raw output
onnxruntime/core/providers/rocm/fpgeneric.cu:75:  Lines should be <= 120 characters long  [whitespace/line_length] [2]
if (C != A) {
dim3 dimGrid((n + TRANS_TILE_DIM - 1) / TRANS_TILE_DIM, (m + TRANS_TILE_DIM - 1) / TRANS_TILE_DIM, 1);
dim3 dimBlock(TRANS_TILE_DIM, BLOCK_ROWS, 1);
Expand All @@ -73,7 +92,7 @@
}

rocblas_status rocblasCopyHelper(hipStream_t stream, rocblas_handle, int n, const onnxruntime::BFloat16* x, int incx,
onnxruntime::BFloat16* y, int incy) {
onnxruntime::BFloat16* y, int incy) {
dim3 dimGrid((unsigned int)(n + COPY_BLOCK_DIM - 1) / COPY_BLOCK_DIM, 1, 1);
dim3 dimBlock(COPY_BLOCK_DIM, 1, 1);
CopyVectorBFloat16<<<dimGrid, dimBlock, 0, stream>>>(x, incx, y, incy, n);
Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/core/providers/rocm/shared_inc/fpgeneric.h
Original file line number Diff line number Diff line change
Expand Up @@ -470,7 +470,7 @@ inline rocblas_status rocblasTransposeHelper(hipStream_t /*stream*/, rocblas_han
return rocblas_dgeam(handle, transa, transb, m, n, alpha, A, lda, beta, B, ldb, C, ldc);
}

inline bool CanUse_rocblasTransposeHelper_MLFloat16(int /*m*/, int /*n*/) { return true; } // CUDA has a limited grid size of 65536, ROCm has higher limits.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why did you remove the inline here?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I moved the implementation of the function inside fpgeneric.cu to mirror the way it was done inside CUDA EP.

bool CanUse_rocblasTransposeHelper_MLFloat16(int m, int n);
rocblas_status rocblasTransposeHelper(hipStream_t stream, rocblas_handle, rocblas_operation, rocblas_operation, int m, int n, const half*, const half* A, int, const half*, const half*, int, half* C, int);

// copy
Expand Down
Loading