From 0784d32abf9e8679b34d93faa684eece4716aee8 Mon Sep 17 00:00:00 2001 From: Evgeny Mankov Date: Mon, 23 Oct 2023 20:43:26 +0200 Subject: [PATCH] [HIPIFY][6.0.0][rocBLAS] Support for ROCm HIP 6.0.0 - Step 21 - functions `rocblas_(hsh|hss|tst|tss)gemv_strided_batched` + Updated synthetic tests and the regenerated hipify-perl and BLAS docs --- bin/hipify-perl | 8 ++--- .../CUBLAS_API_supported_by_HIP_and_ROC.md | 8 ++--- docs/tables/CUBLAS_API_supported_by_ROC.md | 8 ++--- src/CUDA2HIP_BLAS_API_functions.cpp | 12 ++++--- .../synthetic/libraries/cublas2rocblas.cu | 31 +++++++++++++++++++ 5 files changed, 51 insertions(+), 16 deletions(-) diff --git a/bin/hipify-perl b/bin/hipify-perl index c51d73d5..b79a7863 100755 --- a/bin/hipify-perl +++ b/bin/hipify-perl @@ -1467,7 +1467,9 @@ sub rocSubstitutions { subst("cublasGetVector", "rocblas_get_vector", "library"); subst("cublasGetVectorAsync", "rocblas_get_vector_async", "library"); subst("cublasHSHgemvBatched", "rocblas_hshgemv_batched", "library"); + subst("cublasHSHgemvStridedBatched", "rocblas_hshgemv_strided_batched", "library"); subst("cublasHSSgemvBatched", "rocblas_hssgemv_batched", "library"); + subst("cublasHSSgemvStridedBatched", "rocblas_hssgemv_strided_batched", "library"); subst("cublasHgemm", "rocblas_hgemm", "library"); subst("cublasHgemmBatched", "rocblas_hgemm_batched", "library"); subst("cublasHgemmStridedBatched", "rocblas_hgemm_strided_batched", "library"); @@ -1578,7 +1580,9 @@ sub rocSubstitutions { subst("cublasStrsv", "rocblas_strsv", "library"); subst("cublasStrsv_v2", "rocblas_strsv", "library"); subst("cublasTSSgemvBatched", "rocblas_tssgemv_batched", "library"); + subst("cublasTSSgemvStridedBatched", "rocblas_tssgemv_strided_batched", "library"); subst("cublasTSTgemvBatched", "rocblas_tstgemv_batched", "library"); + subst("cublasTSTgemvStridedBatched", "rocblas_tstgemv_strided_batched", "library"); subst("cublasZaxpy", "rocblas_zaxpy", "library"); subst("cublasZaxpy_v2", "rocblas_zaxpy", "library"); subst("cublasZcopy", "rocblas_zcopy", "library"); @@ -9961,10 +9965,8 @@ sub warnRocOnlyUnsupportedFunctions { "cublasXerbla", "cublasUint8gemmBias", "cublasTSTgemvStridedBatched_64", - "cublasTSTgemvStridedBatched", "cublasTSTgemvBatched_64", "cublasTSSgemvStridedBatched_64", - "cublasTSSgemvStridedBatched", "cublasTSSgemvBatched_64", "cublasSwapEx_64", "cublasSwapEx", @@ -10096,10 +10098,8 @@ sub warnRocOnlyUnsupportedFunctions { "cublasHgemmStridedBatched_64", "cublasHgemmBatched_64", "cublasHSSgemvStridedBatched_64", - "cublasHSSgemvStridedBatched", "cublasHSSgemvBatched_64", "cublasHSHgemvStridedBatched_64", - "cublasHSHgemvStridedBatched", "cublasHSHgemvBatched_64", "cublasGetVersion_v2", "cublasGetVersion", diff --git a/docs/tables/CUBLAS_API_supported_by_HIP_and_ROC.md b/docs/tables/CUBLAS_API_supported_by_HIP_and_ROC.md index c22872bd..9ff2ce05 100644 --- a/docs/tables/CUBLAS_API_supported_by_HIP_and_ROC.md +++ b/docs/tables/CUBLAS_API_supported_by_HIP_and_ROC.md @@ -799,11 +799,11 @@ |`cublasDtrsm_v2_64`|12.0| | | | | | | | | | | | | | | |`cublasHSHgemvBatched`|11.6| | | | | | | | |`rocblas_hshgemv_batched`|6.0.0| | | |6.0.0| |`cublasHSHgemvBatched_64`|12.0| | | | | | | | | | | | | | | -|`cublasHSHgemvStridedBatched`|11.6| | | | | | | | | | | | | | | +|`cublasHSHgemvStridedBatched`|11.6| | | | | | | | |`rocblas_hshgemv_strided_batched`|6.0.0| | | |6.0.0| |`cublasHSHgemvStridedBatched_64`|12.0| | | | | | | | | | | | | | | |`cublasHSSgemvBatched`|11.6| | | | | | | | |`rocblas_hssgemv_batched`|6.0.0| | | |6.0.0| |`cublasHSSgemvBatched_64`|12.0| | | | | | | | | | | | | | | -|`cublasHSSgemvStridedBatched`|11.6| | | | | | | | | | | | | | | +|`cublasHSSgemvStridedBatched`|11.6| | | | | | | | |`rocblas_hssgemv_strided_batched`|6.0.0| | | |6.0.0| |`cublasHSSgemvStridedBatched_64`|12.0| | | | | | | | | | | | | | | |`cublasHgemm`|7.5| | |`hipblasHgemm`|1.8.2| | | | |`rocblas_hgemm`|1.5.0| | | | | |`cublasHgemmBatched`|9.0| | |`hipblasHgemmBatched`|3.0.0| | | | |`rocblas_hgemm_batched`|3.5.0| | | | | @@ -847,11 +847,11 @@ |`cublasStrsm_v2_64`|12.0| | | | | | | | | | | | | | | |`cublasTSSgemvBatched`|11.6| | | | | | | | |`rocblas_tssgemv_batched`|6.0.0| | | |6.0.0| |`cublasTSSgemvBatched_64`|12.0| | | | | | | | | | | | | | | -|`cublasTSSgemvStridedBatched`|11.6| | | | | | | | | | | | | | | +|`cublasTSSgemvStridedBatched`|11.6| | | | | | | | |`rocblas_tssgemv_strided_batched`|6.0.0| | | |6.0.0| |`cublasTSSgemvStridedBatched_64`|12.0| | | | | | | | | | | | | | | |`cublasTSTgemvBatched`|11.6| | | | | | | | |`rocblas_tstgemv_batched`|6.0.0| | | |6.0.0| |`cublasTSTgemvBatched_64`|12.0| | | | | | | | | | | | | | | -|`cublasTSTgemvStridedBatched`|11.6| | | | | | | | | | | | | | | +|`cublasTSTgemvStridedBatched`|11.6| | | | | | | | |`rocblas_tstgemv_strided_batched`|6.0.0| | | |6.0.0| |`cublasTSTgemvStridedBatched_64`|12.0| | | | | | | | | | | | | | | |`cublasZgemm`| | | |`hipblasZgemm_v2`|6.0.0| | | |6.0.0|`rocblas_zgemm`|1.5.0| | | | | |`cublasZgemm3m`|8.0| | | | | | | | | | | | | | | diff --git a/docs/tables/CUBLAS_API_supported_by_ROC.md b/docs/tables/CUBLAS_API_supported_by_ROC.md index b2663305..f0091016 100644 --- a/docs/tables/CUBLAS_API_supported_by_ROC.md +++ b/docs/tables/CUBLAS_API_supported_by_ROC.md @@ -799,11 +799,11 @@ |`cublasDtrsm_v2_64`|12.0| | | | | | | | | |`cublasHSHgemvBatched`|11.6| | |`rocblas_hshgemv_batched`|6.0.0| | | |6.0.0| |`cublasHSHgemvBatched_64`|12.0| | | | | | | | | -|`cublasHSHgemvStridedBatched`|11.6| | | | | | | | | +|`cublasHSHgemvStridedBatched`|11.6| | |`rocblas_hshgemv_strided_batched`|6.0.0| | | |6.0.0| |`cublasHSHgemvStridedBatched_64`|12.0| | | | | | | | | |`cublasHSSgemvBatched`|11.6| | |`rocblas_hssgemv_batched`|6.0.0| | | |6.0.0| |`cublasHSSgemvBatched_64`|12.0| | | | | | | | | -|`cublasHSSgemvStridedBatched`|11.6| | | | | | | | | +|`cublasHSSgemvStridedBatched`|11.6| | |`rocblas_hssgemv_strided_batched`|6.0.0| | | |6.0.0| |`cublasHSSgemvStridedBatched_64`|12.0| | | | | | | | | |`cublasHgemm`|7.5| | |`rocblas_hgemm`|1.5.0| | | | | |`cublasHgemmBatched`|9.0| | |`rocblas_hgemm_batched`|3.5.0| | | | | @@ -847,11 +847,11 @@ |`cublasStrsm_v2_64`|12.0| | | | | | | | | |`cublasTSSgemvBatched`|11.6| | |`rocblas_tssgemv_batched`|6.0.0| | | |6.0.0| |`cublasTSSgemvBatched_64`|12.0| | | | | | | | | -|`cublasTSSgemvStridedBatched`|11.6| | | | | | | | | +|`cublasTSSgemvStridedBatched`|11.6| | |`rocblas_tssgemv_strided_batched`|6.0.0| | | |6.0.0| |`cublasTSSgemvStridedBatched_64`|12.0| | | | | | | | | |`cublasTSTgemvBatched`|11.6| | |`rocblas_tstgemv_batched`|6.0.0| | | |6.0.0| |`cublasTSTgemvBatched_64`|12.0| | | | | | | | | -|`cublasTSTgemvStridedBatched`|11.6| | | | | | | | | +|`cublasTSTgemvStridedBatched`|11.6| | |`rocblas_tstgemv_strided_batched`|6.0.0| | | |6.0.0| |`cublasTSTgemvStridedBatched_64`|12.0| | | | | | | | | |`cublasZgemm`| | | |`rocblas_zgemm`|1.5.0| | | | | |`cublasZgemm3m`|8.0| | | | | | | | | diff --git a/src/CUDA2HIP_BLAS_API_functions.cpp b/src/CUDA2HIP_BLAS_API_functions.cpp index 652b943b..f6444e1b 100644 --- a/src/CUDA2HIP_BLAS_API_functions.cpp +++ b/src/CUDA2HIP_BLAS_API_functions.cpp @@ -458,13 +458,13 @@ const std::map CUDA_BLAS_FUNCTION_MAP { {"cublasCgemvStridedBatched_64", {"hipblasCgemvStridedBatched_64", "", CONV_LIB_FUNC, API_BLAS, 7, UNSUPPORTED}}, {"cublasZgemvStridedBatched", {"hipblasZgemvStridedBatched_v2", "rocblas_zgemv_strided_batched", CONV_LIB_FUNC, API_BLAS, 7}}, {"cublasZgemvStridedBatched_64", {"hipblasZgemvStridedBatched_64", "", CONV_LIB_FUNC, API_BLAS, 7, UNSUPPORTED}}, - {"cublasHSHgemvStridedBatched", {"hipblasHSHgemvStridedBatched", "", CONV_LIB_FUNC, API_BLAS, 7, UNSUPPORTED}}, + {"cublasHSHgemvStridedBatched", {"hipblasHSHgemvStridedBatched", "rocblas_hshgemv_strided_batched", CONV_LIB_FUNC, API_BLAS, 7, HIP_UNSUPPORTED}}, {"cublasHSHgemvStridedBatched_64", {"hipblasHSHgemvStridedBatched_64", "", CONV_LIB_FUNC, API_BLAS, 7, UNSUPPORTED}}, - {"cublasHSSgemvStridedBatched", {"hipblasHSSgemvStridedBatched", "", CONV_LIB_FUNC, API_BLAS, 7, UNSUPPORTED}}, + {"cublasHSSgemvStridedBatched", {"hipblasHSSgemvStridedBatched", "rocblas_hssgemv_strided_batched", CONV_LIB_FUNC, API_BLAS, 7, HIP_UNSUPPORTED}}, {"cublasHSSgemvStridedBatched_64", {"hipblasHSSgemvStridedBatched_64", "", CONV_LIB_FUNC, API_BLAS, 7, UNSUPPORTED}}, - {"cublasTSTgemvStridedBatched", {"hipblasTSTgemvStridedBatched", "", CONV_LIB_FUNC, API_BLAS, 7, UNSUPPORTED}}, + {"cublasTSTgemvStridedBatched", {"hipblasTSTgemvStridedBatched", "rocblas_tstgemv_strided_batched", CONV_LIB_FUNC, API_BLAS, 7, HIP_UNSUPPORTED}}, {"cublasTSTgemvStridedBatched_64", {"hipblasTSTgemvStridedBatched_64", "", CONV_LIB_FUNC, API_BLAS, 7, UNSUPPORTED}}, - {"cublasTSSgemvStridedBatched", {"hipblasTSSgemvStridedBatched", "", CONV_LIB_FUNC, API_BLAS, 7, UNSUPPORTED}}, + {"cublasTSSgemvStridedBatched", {"hipblasTSSgemvStridedBatched", "rocblas_tssgemv_strided_batched", CONV_LIB_FUNC, API_BLAS, 7, HIP_UNSUPPORTED}}, {"cublasTSSgemvStridedBatched_64", {"hipblasTSSgemvStridedBatched_64", "", CONV_LIB_FUNC, API_BLAS, 7, UNSUPPORTED}}, // SYRK @@ -2100,6 +2100,10 @@ const std::map HIP_BLAS_FUNCTION_VER_MAP { {"rocblas_hssgemv_batched", {HIP_6000, HIP_0, HIP_0, HIP_LATEST}}, {"rocblas_tstgemv_batched", {HIP_6000, HIP_0, HIP_0, HIP_LATEST}}, {"rocblas_tssgemv_batched", {HIP_6000, HIP_0, HIP_0, HIP_LATEST}}, + {"rocblas_hshgemv_strided_batched", {HIP_6000, HIP_0, HIP_0, HIP_LATEST}}, + {"rocblas_hssgemv_strided_batched", {HIP_6000, HIP_0, HIP_0, HIP_LATEST}}, + {"rocblas_tstgemv_strided_batched", {HIP_6000, HIP_0, HIP_0, HIP_LATEST}}, + {"rocblas_tssgemv_strided_batched", {HIP_6000, HIP_0, HIP_0, HIP_LATEST}}, }; const std::map HIP_BLAS_FUNCTION_CHANGED_VER_MAP { diff --git a/tests/unit_tests/synthetic/libraries/cublas2rocblas.cu b/tests/unit_tests/synthetic/libraries/cublas2rocblas.cu index a2bfd2d5..ec9ed03a 100644 --- a/tests/unit_tests/synthetic/libraries/cublas2rocblas.cu +++ b/tests/unit_tests/synthetic/libraries/cublas2rocblas.cu @@ -254,6 +254,10 @@ int main() { __half* hc = 0; // CHECK: rocblas_half* hC = 0; __half* hC = 0; + // CHECK: rocblas_half* hx = 0; + __half* hx = 0; + // CHECK: rocblas_half* hy = 0; + __half* hy = 0; // CHECK: rocblas_half** hAarray = 0; __half** hAarray = 0; @@ -274,6 +278,13 @@ int main() { // CHECK: rocblas_half** hyarray = 0; __half** hyarray = 0; + // CHECK: rocblas_bfloat16* bf16A = 0; + __nv_bfloat16* bf16A = 0; + // CHECK: rocblas_bfloat16* bf16x = 0; + __nv_bfloat16* bf16x = 0; + // CHECK: rocblas_bfloat16* bf16y = 0; + __nv_bfloat16* bf16y = 0; + // CHECK: rocblas_bfloat16** bf16Aarray = 0; __nv_bfloat16** bf16Aarray = 0; // CHECK: const rocblas_bfloat16** const bf16Aarray_const = const_cast(bf16Aarray); @@ -1838,6 +1849,26 @@ int main() { // ROC: ROCBLAS_EXPORT rocblas_status rocblas_tssgemv_batched(rocblas_handle handle, rocblas_operation trans, rocblas_int m, rocblas_int n, const float* alpha, const rocblas_bfloat16* const A[], rocblas_int lda, const rocblas_bfloat16* const x[], rocblas_int incx, const float* beta, float* const y[], rocblas_int incy, rocblas_int batch_count); // CHECK: blasStatus = rocblas_tssgemv_batched(blasHandle, blasOperation, m, n, &fa, bf16Aarray_const, lda, bf16xarray_const, incx, &fb, fyarray, incy, batchCount); blasStatus = cublasTSSgemvBatched(blasHandle, blasOperation, m, n, &fa, bf16Aarray_const, lda, bf16xarray_const, incx, &fb, fyarray, incy, batchCount); + + // CUDA: CUBLASAPI cublasStatus_t CUBLASWINAPI cublasHSHgemvStridedBatched(cublasHandle_t handle, cublasOperation_t trans, int m, int n, const float* alpha, const __half* A, int lda, long long int strideA, const __half* x, int incx, long long int stridex, const float* beta, __half* y, int incy, long long int stridey, int batchCount); + // ROC: ROCBLAS_EXPORT rocblas_status rocblas_hshgemv_strided_batched(rocblas_handle handle, rocblas_operation transA, rocblas_int m, rocblas_int n, const float* alpha, const rocblas_half* A, rocblas_int lda, rocblas_stride strideA, const rocblas_half* x, rocblas_int incx, rocblas_stride stridex, const float* beta, rocblas_half* y, rocblas_int incy, rocblas_stride stridey, rocblas_int batch_count); + // CHECK: blasStatus = rocblas_hshgemv_strided_batched(blasHandle, blasOperation, m, n, &fa, hA, lda, strideA, hx, incx, stridex, &fb, hy, incy, stridey, batchCount); + blasStatus = cublasHSHgemvStridedBatched(blasHandle, blasOperation, m, n, &fa, hA, lda, strideA, hx, incx, stridex, &fb, hy, incy, stridey, batchCount); + + // CUDA: CUBLASAPI cublasStatus_t CUBLASWINAPI cublasHSSgemvStridedBatched(cublasHandle_t handle, cublasOperation_t trans, int m, int n, const float* alpha, const __half* A, int lda, long long int strideA, const __half* x, int incx, long long int stridex, const float* beta, float* y, int incy, long long int stridey, int batchCount); + // ROC: ROCBLAS_EXPORT rocblas_status rocblas_hssgemv_strided_batched(rocblas_handle handle, rocblas_operation transA, rocblas_int m, rocblas_int n, const float* alpha, const rocblas_half* A, rocblas_int lda, rocblas_stride strideA, const rocblas_half* x, rocblas_int incx, rocblas_stride stridex, const float* beta, float* y, rocblas_int incy, rocblas_stride stridey, rocblas_int batch_count); + // CHECK: blasStatus = rocblas_hssgemv_strided_batched(blasHandle, blasOperation, m, n, &fa, hA, lda, strideA, hx, incx, stridex, &fb, &fy, incy, stridey, batchCount); + blasStatus = cublasHSSgemvStridedBatched(blasHandle, blasOperation, m, n, &fa, hA, lda, strideA, hx, incx, stridex, &fb, &fy, incy, stridey, batchCount); + + // CUDA: CUBLASAPI cublasStatus_t CUBLASWINAPI cublasTSTgemvStridedBatched(cublasHandle_t handle, cublasOperation_t trans, int m, int n, const float* alpha, const __nv_bfloat16* A, int lda, long long int strideA, const __nv_bfloat16* x, int incx, long long int stridex, const float* beta, __nv_bfloat16* y, int incy, long long int stridey, int batchCount); + // ROC: ROCBLAS_EXPORT rocblas_status rocblas_tstgemv_strided_batched(rocblas_handle handle, rocblas_operation transA, rocblas_int m, rocblas_int n, const float* alpha, const rocblas_bfloat16* A, rocblas_int lda, rocblas_stride strideA, const rocblas_bfloat16* x, rocblas_int incx, rocblas_stride stridex, const float* beta, rocblas_bfloat16* y, rocblas_int incy, rocblas_stride stridey, rocblas_int batch_count); + // CHECK: blasStatus = rocblas_tstgemv_strided_batched(blasHandle, blasOperation, m, n, &fa, bf16A, lda, strideA, bf16x, incx, stridex, &fb, bf16y, incy, stridey, batchCount); + blasStatus = cublasTSTgemvStridedBatched(blasHandle, blasOperation, m, n, &fa, bf16A, lda, strideA, bf16x, incx, stridex, &fb, bf16y, incy, stridey, batchCount); + + // CUDA: CUBLASAPI cublasStatus_t CUBLASWINAPI cublasTSSgemvStridedBatched(cublasHandle_t handle, cublasOperation_t trans, int m, int n, const float* alpha, const __nv_bfloat16* A, int lda, long long int strideA, const __nv_bfloat16* x, int incx, long long int stridex, const float* beta, float* y, int incy, long long int stridey, int batchCount); + // ROC: ROCBLAS_EXPORT rocblas_status rocblas_tssgemv_strided_batched(rocblas_handle handle, rocblas_operation transA, rocblas_int m, rocblas_int n, const float* alpha, const rocblas_bfloat16* A, rocblas_int lda, rocblas_stride strideA, const rocblas_bfloat16* x, rocblas_int incx, rocblas_stride stridex, const float* beta, float* y, rocblas_int incy, rocblas_stride stridey, rocblas_int batch_count); + // CHECK: blasStatus = rocblas_tssgemv_strided_batched(blasHandle, blasOperation, m, n, &fa, bf16A, lda, strideA, bf16x, incx, stridex, &fb, &fy, incy, stridey, batchCount); + blasStatus = cublasTSSgemvStridedBatched(blasHandle, blasOperation, m, n, &fa, bf16A, lda, strideA, bf16x, incx, stridex, &fb, &fy, incy, stridey, batchCount); #endif return 0;