From 06aa7049c672b77a63be73c13e2f6dfa39f39d74 Mon Sep 17 00:00:00 2001 From: Jeff Daily Date: Fri, 11 Oct 2024 19:03:16 +0000 Subject: [PATCH] lint --- .../contrib_ops/rocm/bert/attention_impl.h | 44 +- .../core/providers/rocm/rocm_stream_handle.cc | 2 +- .../providers/rocm/shared_inc/fpgeneric.h | 551 +++++++++--------- tools/ci_build/amd_hipify.py | 8 +- 4 files changed, 300 insertions(+), 305 deletions(-) diff --git a/onnxruntime/contrib_ops/rocm/bert/attention_impl.h b/onnxruntime/contrib_ops/rocm/bert/attention_impl.h index bfdf34b30a5d5..6c2e36b596d32 100644 --- a/onnxruntime/contrib_ops/rocm/bert/attention_impl.h +++ b/onnxruntime/contrib_ops/rocm/bert/attention_impl.h @@ -71,28 +71,28 @@ Status LaunchConcatTensorToTensor(hipStream_t stream, half* tensor_out); inline hipblasStatus_t _compat_hipblas_gemm_strided_batched_ex(hipblasHandle_t handle, - hipblasOperation_t transa, - hipblasOperation_t transb, - int m, - int n, - int k, - const void* alpha, - const void* A, - hipDataType a_type, - int lda, - hipblasStride stride_A, - const void* b, - hipDataType b_type, - int ldb, - hipblasStride stride_b, - const void* beta, - void* c, - hipDataType c_type, - int ldc, - hipblasStride stride_c, - int batch_count, - hipblasComputeType_t compute_type, - hipblasGemmAlgo_t algo) { + hipblasOperation_t transa, + hipblasOperation_t transb, + int m, + int n, + int k, + const void* alpha, + const void* A, + hipDataType a_type, + int lda, + hipblasStride stride_A, + const void* b, + hipDataType b_type, + int ldb, + hipblasStride stride_b, + const void* beta, + void* c, + hipDataType c_type, + int ldc, + hipblasStride stride_c, + int batch_count, + hipblasComputeType_t compute_type, + hipblasGemmAlgo_t algo) { return hipblasGemmStridedBatchedEx(handle, transa, transb, diff --git a/onnxruntime/core/providers/rocm/rocm_stream_handle.cc b/onnxruntime/core/providers/rocm/rocm_stream_handle.cc index 7931a23f761b0..c175252df3efc 100644 --- a/onnxruntime/core/providers/rocm/rocm_stream_handle.cc +++ b/onnxruntime/core/providers/rocm/rocm_stream_handle.cc @@ -45,7 +45,7 @@ RocmStream::RocmStream(hipStream_t stream, hipblasHandle_t external_hipblas_handle) : Stream(stream, device), own_stream_(own_flag), cpu_allocator_(cpu_allocator), - release_cpu_buffer_on_rocm_stream_(release_cpu_buffer_on_rocm_stream) { + release_cpu_buffer_on_rocm_stream_(release_cpu_buffer_on_rocm_stream) { if (own_flag) { HIPBLAS_CALL_THROW(hipblasCreate(&hipblas_handle_)); HIPBLAS_CALL_THROW(hipblasSetStream(hipblas_handle_, stream)); diff --git a/onnxruntime/core/providers/rocm/shared_inc/fpgeneric.h b/onnxruntime/core/providers/rocm/shared_inc/fpgeneric.h index 764562ac9a720..675b30612065b 100644 --- a/onnxruntime/core/providers/rocm/shared_inc/fpgeneric.h +++ b/onnxruntime/core/providers/rocm/shared_inc/fpgeneric.h @@ -13,43 +13,39 @@ #define FLAG 0 #endif // needed to work around calling rocblas API instead of hipblas API -static rocblas_operation hipOperationToRocOperation(hipblasOperation_t op) -{ - switch(op) - { +static rocblas_operation hipOperationToRocOperation(hipblasOperation_t op) { + switch (op) { case HIPBLAS_OP_N: - return rocblas_operation_none; + return rocblas_operation_none; case HIPBLAS_OP_T: - return rocblas_operation_transpose; + return rocblas_operation_transpose; case HIPBLAS_OP_C: - return rocblas_operation_conjugate_transpose; - } - assert(0 && "HIPBLAS_STATUS_INVALID_ENUM"); -} -static hipblasStatus_t rocBLASStatusToHIPStatus(rocblas_status error) -{ - switch(error) - { + return rocblas_operation_conjugate_transpose; + } + assert(0 && "HIPBLAS_STATUS_INVALID_ENUM"); +} +static hipblasStatus_t rocBLASStatusToHIPStatus(rocblas_status error) { + switch (error) { case rocblas_status_size_unchanged: case rocblas_status_size_increased: case rocblas_status_success: - return HIPBLAS_STATUS_SUCCESS; + return HIPBLAS_STATUS_SUCCESS; case rocblas_status_invalid_handle: - return HIPBLAS_STATUS_NOT_INITIALIZED; + return HIPBLAS_STATUS_NOT_INITIALIZED; case rocblas_status_not_implemented: - return HIPBLAS_STATUS_NOT_SUPPORTED; + return HIPBLAS_STATUS_NOT_SUPPORTED; case rocblas_status_invalid_pointer: case rocblas_status_invalid_size: case rocblas_status_invalid_value: - return HIPBLAS_STATUS_INVALID_VALUE; + return HIPBLAS_STATUS_INVALID_VALUE; case rocblas_status_memory_error: - return HIPBLAS_STATUS_ALLOC_FAILED; + return HIPBLAS_STATUS_ALLOC_FAILED; case rocblas_status_internal_error: - return HIPBLAS_STATUS_INTERNAL_ERROR; + return HIPBLAS_STATUS_INTERNAL_ERROR; default: - assert(0 && "ROCBLAS_STATUS_INVALID_ENUM"); - return HIPBLAS_STATUS_INTERNAL_ERROR; - } + assert(0 && "ROCBLAS_STATUS_INVALID_ENUM"); + return HIPBLAS_STATUS_INTERNAL_ERROR; + } } using namespace onnxruntime; @@ -74,89 +70,89 @@ inline hipblasStatus_t hipblasGemmHelper(hipblasHandle_t handle, const float* beta, float* C, int ldc) { return hipblasGemmEx(handle, - transa, - transb, - m, n, k, - alpha, - A, HIP_R_32F, lda, - B, HIP_R_32F, ldb, - beta, - C, HIP_R_32F, ldc, - HIPBLAS_COMPUTE_32F, - HIPBLAS_GEMM_DEFAULT); + transa, + transb, + m, n, k, + alpha, + A, HIP_R_32F, lda, + B, HIP_R_32F, ldb, + beta, + C, HIP_R_32F, ldc, + HIPBLAS_COMPUTE_32F, + HIPBLAS_GEMM_DEFAULT); } inline hipblasStatus_t hipblasGemmHelper(hipblasHandle_t handle, - hipblasOperation_t transa, - hipblasOperation_t transb, - int m, int n, int k, - const double* alpha, - const double* A, int lda, - const double* B, int ldb, - const double* beta, - double* C, int ldc) { + hipblasOperation_t transa, + hipblasOperation_t transb, + int m, int n, int k, + const double* alpha, + const double* A, int lda, + const double* B, int ldb, + const double* beta, + double* C, int ldc) { return hipblasDgemm(handle, transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc); } inline hipblasStatus_t hipblasGemmHelper(hipblasHandle_t handle, - hipblasOperation_t transa, - hipblasOperation_t transb, - int m, int n, int k, - const half* alpha, - const half* A, int lda, - const half* B, int ldb, - const half* beta, - half* C, int ldc) { + hipblasOperation_t transa, + hipblasOperation_t transb, + int m, int n, int k, + const half* alpha, + const half* A, int lda, + const half* B, int ldb, + const half* beta, + half* C, int ldc) { float h_a = onnxruntime::math::halfToFloat(*reinterpret_cast(alpha)); float h_b = onnxruntime::math::halfToFloat(*reinterpret_cast(beta)); return rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle)handle, - hipOperationToRocOperation(transa), - hipOperationToRocOperation(transb), - m, n, k, - &h_a, - A, rocblas_datatype_f16_r, lda, - B, rocblas_datatype_f16_r, ldb, - &h_b, - C, rocblas_datatype_f16_r, ldc, - C, rocblas_datatype_f16_r, ldc, - rocblas_datatype_f32_r, - rocblas_gemm_algo_standard, 0, get_flag())); + hipOperationToRocOperation(transa), + hipOperationToRocOperation(transb), + m, n, k, + &h_a, + A, rocblas_datatype_f16_r, lda, + B, rocblas_datatype_f16_r, ldb, + &h_b, + C, rocblas_datatype_f16_r, ldc, + C, rocblas_datatype_f16_r, ldc, + rocblas_datatype_f32_r, + rocblas_gemm_algo_standard, 0, get_flag())); } inline hipblasStatus_t hipblasGemmHelper(hipblasHandle_t handle, - hipblasOperation_t transa, - hipblasOperation_t transb, - int m, int n, int k, - const float* alpha, - const half* A, int lda, - const half* B, int ldb, - const float* beta, - half* C, int ldc) { + hipblasOperation_t transa, + hipblasOperation_t transb, + int m, int n, int k, + const float* alpha, + const half* A, int lda, + const half* B, int ldb, + const float* beta, + half* C, int ldc) { return rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle)handle, - hipOperationToRocOperation(transa), - hipOperationToRocOperation(transb), - m, n, k, - alpha, - A, rocblas_datatype_f16_r, lda, - B, rocblas_datatype_f16_r, ldb, - beta, - C, rocblas_datatype_f16_r, ldc, - C, rocblas_datatype_f16_r, ldc, - rocblas_datatype_f32_r, - rocblas_gemm_algo_standard, 0, get_flag())); + hipOperationToRocOperation(transa), + hipOperationToRocOperation(transb), + m, n, k, + alpha, + A, rocblas_datatype_f16_r, lda, + B, rocblas_datatype_f16_r, ldb, + beta, + C, rocblas_datatype_f16_r, ldc, + C, rocblas_datatype_f16_r, ldc, + rocblas_datatype_f32_r, + rocblas_gemm_algo_standard, 0, get_flag())); } inline hipblasStatus_t hipblasGemmHelper(hipblasHandle_t handle, - hipblasOperation_t transa, - hipblasOperation_t transb, - int m, int n, int k, - const float* alpha, - const half* A, int lda, - const half* B, int ldb, - const float* beta, - half* C, int ldc, - const hipDeviceProp_t&, - bool /*use_tf32*/) { + hipblasOperation_t transa, + hipblasOperation_t transb, + int m, int n, int k, + const float* alpha, + const half* A, int lda, + const half* B, int ldb, + const float* beta, + half* C, int ldc, + const hipDeviceProp_t&, + bool /*use_tf32*/) { return hipblasGemmHelper(handle, transa, transb, @@ -169,44 +165,44 @@ inline hipblasStatus_t hipblasGemmHelper(hipblasHandle_t handle, } inline hipblasStatus_t hipblasGemmHelper(hipblasHandle_t handle, - hipblasOperation_t transa, - hipblasOperation_t transb, - int m, int n, int k, - const BFloat16* alpha, - const BFloat16* A, int lda, - const BFloat16* B, int ldb, - const BFloat16* beta, - BFloat16* C, int ldc) { + hipblasOperation_t transa, + hipblasOperation_t transb, + int m, int n, int k, + const BFloat16* alpha, + const BFloat16* A, int lda, + const BFloat16* B, int ldb, + const BFloat16* beta, + BFloat16* C, int ldc) { float h_a = alpha->ToFloat(); float h_b = beta->ToFloat(); // accumulating in FP32 return hipblasGemmEx(handle, - transa, - transb, - m, n, k, - &h_a, - A, HIP_R_16BF, lda, - B, HIP_R_16BF, ldb, - &h_b, - C, HIP_R_16BF, ldc, - HIPBLAS_COMPUTE_32F, - HIPBLAS_GEMM_DEFAULT); + transa, + transb, + m, n, k, + &h_a, + A, HIP_R_16BF, lda, + B, HIP_R_16BF, ldb, + &h_b, + C, HIP_R_16BF, ldc, + HIPBLAS_COMPUTE_32F, + HIPBLAS_GEMM_DEFAULT); } // Compatible for function call with extra arguments (see cublasGemmHelper) template hipblasStatus_t hipblasGemmHelper(hipblasHandle_t handle, - hipblasOperation_t transa, - hipblasOperation_t transb, - int m, int n, int k, - const Scalar* alpha, - const Scalar* A, int lda, - const Scalar* B, int ldb, - const Scalar* beta, - Scalar* C, int ldc, - const hipDeviceProp_t&, - bool /*use_tf32*/) { + hipblasOperation_t transa, + hipblasOperation_t transb, + int m, int n, int k, + const Scalar* alpha, + const Scalar* A, int lda, + const Scalar* B, int ldb, + const Scalar* beta, + Scalar* C, int ldc, + const hipDeviceProp_t&, + bool /*use_tf32*/) { return hipblasGemmHelper(handle, transa, transb, @@ -220,15 +216,15 @@ hipblasStatus_t hipblasGemmHelper(hipblasHandle_t handle, // batched gemm inline hipblasStatus_t hipblasGemmBatchedHelper(hipblasHandle_t handle, - hipblasOperation_t transa, - hipblasOperation_t transb, - int m, int n, int k, - const float* alpha, - const float* Aarray[], int lda, - const float* Barray[], int ldb, - const float* beta, - float* Carray[], int ldc, - int batchCount) { + hipblasOperation_t transa, + hipblasOperation_t transb, + int m, int n, int k, + const float* alpha, + const float* Aarray[], int lda, + const float* Barray[], int ldb, + const float* beta, + float* Carray[], int ldc, + int batchCount) { return hipblasGemmBatchedEx(handle, transa, transb, @@ -243,54 +239,54 @@ inline hipblasStatus_t hipblasGemmBatchedHelper(hipblasHandle_t handle, HIPBLAS_GEMM_DEFAULT); } inline hipblasStatus_t hipblasGemmBatchedHelper(hipblasHandle_t handle, - hipblasOperation_t transa, - hipblasOperation_t transb, - int m, int n, int k, - const double* alpha, - const double* Aarray[], int lda, - const double* Barray[], int ldb, - const double* beta, - double* Carray[], int ldc, - int batchCount) { + hipblasOperation_t transa, + hipblasOperation_t transb, + int m, int n, int k, + const double* alpha, + const double* Aarray[], int lda, + const double* Barray[], int ldb, + const double* beta, + double* Carray[], int ldc, + int batchCount) { return hipblasDgemmBatched(handle, transa, transb, m, n, k, alpha, Aarray, lda, Barray, ldb, beta, Carray, ldc, batchCount); } inline hipblasStatus_t hipblasGemmBatchedHelper(hipblasHandle_t handle, - hipblasOperation_t transa, - hipblasOperation_t transb, - int m, int n, int k, - const half* alpha, - const half* Aarray[], int lda, - const half* Barray[], int ldb, - const half* beta, - half* Carray[], int ldc, - int batchCount) { + hipblasOperation_t transa, + hipblasOperation_t transb, + int m, int n, int k, + const half* alpha, + const half* Aarray[], int lda, + const half* Barray[], int ldb, + const half* beta, + half* Carray[], int ldc, + int batchCount) { float h_a = onnxruntime::math::halfToFloat(*reinterpret_cast(alpha)); float h_b = onnxruntime::math::halfToFloat(*reinterpret_cast(beta)); return rocBLASStatusToHIPStatus(rocblas_gemm_batched_ex((rocblas_handle)handle, - hipOperationToRocOperation(transa), - hipOperationToRocOperation(transb), - m, n, k, - &h_a, - (const void**)Aarray, rocblas_datatype_f16_r, lda, - (const void**)Barray, rocblas_datatype_f16_r, ldb, - &h_b, - (void**)Carray, rocblas_datatype_f16_r, ldc, - (void**)Carray, rocblas_datatype_f16_r, ldc, - batchCount, - rocblas_datatype_f32_r, - rocblas_gemm_algo_standard, 0, get_flag())); + hipOperationToRocOperation(transa), + hipOperationToRocOperation(transb), + m, n, k, + &h_a, + (const void**)Aarray, rocblas_datatype_f16_r, lda, + (const void**)Barray, rocblas_datatype_f16_r, ldb, + &h_b, + (void**)Carray, rocblas_datatype_f16_r, ldc, + (void**)Carray, rocblas_datatype_f16_r, ldc, + batchCount, + rocblas_datatype_f32_r, + rocblas_gemm_algo_standard, 0, get_flag())); } inline hipblasStatus_t hipblasGemmBatchedHelper(hipblasHandle_t handle, - hipblasOperation_t transa, - hipblasOperation_t transb, - int m, int n, int k, - const BFloat16* alpha, - const BFloat16* Aarray[], int lda, - const BFloat16* Barray[], int ldb, - const BFloat16* beta, - BFloat16* Carray[], int ldc, - int batch_count) { + hipblasOperation_t transa, + hipblasOperation_t transb, + int m, int n, int k, + const BFloat16* alpha, + const BFloat16* Aarray[], int lda, + const BFloat16* Barray[], int ldb, + const BFloat16* beta, + BFloat16* Carray[], int ldc, + int batch_count) { float h_a = alpha->ToFloat(); float h_b = beta->ToFloat(); @@ -311,18 +307,18 @@ inline hipblasStatus_t hipblasGemmBatchedHelper(hipblasHandle_t handle, // strided batched gemm inline hipblasStatus_t hipblasGemmStridedBatchedHelper(hipblasHandle_t handle, - hipblasOperation_t transa, - hipblasOperation_t transb, - int m, int n, int k, - const float* alpha, - const float* A, int lda, - long long int strideA, - const float* B, int ldb, - long long int strideB, - const float* beta, - float* C, int ldc, - long long int strideC, - int batchCount) { + hipblasOperation_t transa, + hipblasOperation_t transb, + int m, int n, int k, + const float* alpha, + const float* A, int lda, + long long int strideA, + const float* B, int ldb, + long long int strideB, + const float* beta, + float* C, int ldc, + long long int strideC, + int batchCount) { return hipblasGemmStridedBatchedEx(handle, transa, transb, @@ -338,92 +334,92 @@ inline hipblasStatus_t hipblasGemmStridedBatchedHelper(hipblasHandle_t handle, } inline hipblasStatus_t hipblasGemmStridedBatchedHelper(hipblasHandle_t handle, - hipblasOperation_t transa, - hipblasOperation_t transb, - int m, int n, int k, - const double* alpha, - const double* A, int lda, - long long int strideA, - const double* B, int ldb, - long long int strideB, - const double* beta, - double* C, int ldc, - long long int strideC, - int batchCount) { + hipblasOperation_t transa, + hipblasOperation_t transb, + int m, int n, int k, + const double* alpha, + const double* A, int lda, + long long int strideA, + const double* B, int ldb, + long long int strideB, + const double* beta, + double* C, int ldc, + long long int strideC, + int batchCount) { return hipblasDgemmStridedBatched(handle, transa, transb, m, n, k, alpha, A, lda, strideA, B, ldb, strideB, beta, C, ldc, strideC, batchCount); } inline hipblasStatus_t hipblasGemmStridedBatchedHelper(hipblasHandle_t handle, - hipblasOperation_t transa, - hipblasOperation_t transb, - int m, int n, int k, - const __half* alpha, - const __half* A, int lda, - long long int strideA, - const __half* B, int ldb, - long long int strideB, - const __half* beta, - __half* C, int ldc, - long long int strideC, - int batchCount) { + hipblasOperation_t transa, + hipblasOperation_t transb, + int m, int n, int k, + const __half* alpha, + const __half* A, int lda, + long long int strideA, + const __half* B, int ldb, + long long int strideB, + const __half* beta, + __half* C, int ldc, + long long int strideC, + int batchCount) { float h_a = onnxruntime::math::halfToFloat(*reinterpret_cast(alpha)); float h_b = onnxruntime::math::halfToFloat(*reinterpret_cast(beta)); return rocBLASStatusToHIPStatus(rocblas_gemm_strided_batched_ex((rocblas_handle)handle, - hipOperationToRocOperation(transa), - hipOperationToRocOperation(transb), - m, n, k, - &h_a, - A, rocblas_datatype_f16_r, lda, strideA, - B, rocblas_datatype_f16_r, ldb, strideB, - &h_b, - C, rocblas_datatype_f16_r, ldc, strideC, - C, rocblas_datatype_f16_r, ldc, strideC, - batchCount, - rocblas_datatype_f32_r, - rocblas_gemm_algo_standard, 0, get_flag())); + hipOperationToRocOperation(transa), + hipOperationToRocOperation(transb), + m, n, k, + &h_a, + A, rocblas_datatype_f16_r, lda, strideA, + B, rocblas_datatype_f16_r, ldb, strideB, + &h_b, + C, rocblas_datatype_f16_r, ldc, strideC, + C, rocblas_datatype_f16_r, ldc, strideC, + batchCount, + rocblas_datatype_f32_r, + rocblas_gemm_algo_standard, 0, get_flag())); } inline hipblasStatus_t hipblasGemmStridedBatchedHelper(hipblasHandle_t handle, - hipblasOperation_t transa, - hipblasOperation_t transb, - int m, int n, int k, - const float* alpha, - const __half* A, int lda, - intmax_t strideA, - const __half* B, int ldb, - intmax_t strideB, - const float* beta, - __half* C, int ldc, - intmax_t strideC, - int batchCount) { + hipblasOperation_t transa, + hipblasOperation_t transb, + int m, int n, int k, + const float* alpha, + const __half* A, int lda, + intmax_t strideA, + const __half* B, int ldb, + intmax_t strideB, + const float* beta, + __half* C, int ldc, + intmax_t strideC, + int batchCount) { return rocBLASStatusToHIPStatus(rocblas_gemm_strided_batched_ex((rocblas_handle)handle, - hipOperationToRocOperation(transa), - hipOperationToRocOperation(transb), - m, n, k, - alpha, - A, rocblas_datatype_f16_r, lda, strideA, - B, rocblas_datatype_f16_r, ldb, strideB, - beta, - C, rocblas_datatype_f16_r, ldc, strideC, - C, rocblas_datatype_f16_r, ldc, strideC, - batchCount, - rocblas_datatype_f32_r, - rocblas_gemm_algo_standard, 0, get_flag())); + hipOperationToRocOperation(transa), + hipOperationToRocOperation(transb), + m, n, k, + alpha, + A, rocblas_datatype_f16_r, lda, strideA, + B, rocblas_datatype_f16_r, ldb, strideB, + beta, + C, rocblas_datatype_f16_r, ldc, strideC, + C, rocblas_datatype_f16_r, ldc, strideC, + batchCount, + rocblas_datatype_f32_r, + rocblas_gemm_algo_standard, 0, get_flag())); } inline hipblasStatus_t hipblasGemmStridedBatchedHelper(hipblasHandle_t handle, - hipblasOperation_t transa, - hipblasOperation_t transb, - int m, int n, int k, - const BFloat16* alpha, - const BFloat16* A, int lda, - intmax_t strideA, - const BFloat16* B, int ldb, - intmax_t strideB, - const BFloat16* beta, - BFloat16* C, int ldc, - intmax_t strideC, - int batch_count) { + hipblasOperation_t transa, + hipblasOperation_t transb, + int m, int n, int k, + const BFloat16* alpha, + const BFloat16* A, int lda, + intmax_t strideA, + const BFloat16* B, int ldb, + intmax_t strideB, + const BFloat16* beta, + BFloat16* C, int ldc, + intmax_t strideC, + int batch_count) { float h_a = alpha->ToFloat(); float h_b = beta->ToFloat(); // accumulating in FP32 @@ -444,20 +440,20 @@ inline hipblasStatus_t hipblasGemmStridedBatchedHelper(hipblasHandle_t handle, // Compatible for function call with with extra arguments (see cublasGemmStridedBatchedHelper) template hipblasStatus_t hipblasGemmStridedBatchedHelper(hipblasHandle_t handle, - hipblasOperation_t transa, - hipblasOperation_t transb, - int m, int n, int k, - const Scalar* alpha, - const Scalar* A, int lda, - intmax_t strideA, - const Scalar* B, int ldb, - intmax_t strideB, - const Scalar* beta, - Scalar* C, int ldc, - intmax_t strideC, - int batchCount, - const hipDeviceProp_t&, - bool /*use_tf32*/) { + hipblasOperation_t transa, + hipblasOperation_t transb, + int m, int n, int k, + const Scalar* alpha, + const Scalar* A, int lda, + intmax_t strideA, + const Scalar* B, int ldb, + intmax_t strideB, + const Scalar* beta, + Scalar* C, int ldc, + intmax_t strideC, + int batchCount, + const hipDeviceProp_t&, + bool /*use_tf32*/) { return hipblasGemmStridedBatchedHelper(handle, transa, transb, @@ -471,20 +467,20 @@ hipblasStatus_t hipblasGemmStridedBatchedHelper(hipblasHandle_t handle, } inline hipblasStatus_t hipblasGemmStridedBatchedHelper(hipblasHandle_t handle, - hipblasOperation_t transa, - hipblasOperation_t transb, - int m, int n, int k, - const float* alpha, - const __half* A, int lda, - intmax_t strideA, - const __half* B, int ldb, - intmax_t strideB, - const float* beta, - __half* C, int ldc, - intmax_t strideC, - int batchCount, - const hipDeviceProp_t&, - bool /*use_tf32*/) { + hipblasOperation_t transa, + hipblasOperation_t transb, + int m, int n, int k, + const float* alpha, + const __half* A, int lda, + intmax_t strideA, + const __half* B, int ldb, + intmax_t strideB, + const float* beta, + __half* C, int ldc, + intmax_t strideC, + int batchCount, + const hipDeviceProp_t&, + bool /*use_tf32*/) { return hipblasGemmStridedBatchedHelper(handle, transa, transb, @@ -959,4 +955,3 @@ inline rocblas_status rocblasGemmStridedBatchedHelper(rocblas_handle handle, C, ldc, strideC, batchCount); } - diff --git a/tools/ci_build/amd_hipify.py b/tools/ci_build/amd_hipify.py index 86bcb5540c192..07167b0a61732 100644 --- a/tools/ci_build/amd_hipify.py +++ b/tools/ci_build/amd_hipify.py @@ -84,11 +84,11 @@ def hipify(hipify_perl_path, src_file_path, dst_file_path): s = s.replace("typedef half MappedType", "typedef __half MappedType") # CUBLAS -> HIPBLAS - s = s.replace('CUBLAS', 'HIPBLAS') - s = s.replace('Cublas', 'Hipblas') - s = s.replace('cublas', 'hipblas') + s = s.replace("CUBLAS", "HIPBLAS") + s = s.replace("Cublas", "Hipblas") + s = s.replace("cublas", "hipblas") # deprecated cublas symbol doesn't exist in hipblas, map to new symbol - s = s.replace('HIPBLAS_GEMM_DEFAULT_TENSOR_OP', 'HIPBLAS_GEMM_DEFAULT') + s = s.replace("HIPBLAS_GEMM_DEFAULT_TENSOR_OP", "HIPBLAS_GEMM_DEFAULT") # Undefined ROCMRT constants -> std::numeric_limits s = s.replace("ROCMRT_INF_F", "std::numeric_limits::infinity()")