From 4887fa579315d6b6bca202932f7c1edc2aa8ff40 Mon Sep 17 00:00:00 2001
From: Aman Gupta <amangupta052@gmail.com>
Date: Tue, 24 Jun 2025 16:49:50 +0800
Subject: [PATCH 1/4] CUDA: add bf16 and f32 support to cublas_mul_mat_batched
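
The batched cuBLAS mul_mat path previously assumed src0 was F16: src1 was converted to F16, the GEMM ran on CUDA_R_16F operands, and the F16 result was converted back to F32. This patch lets src0 also be BF16 or F32. A ggml_get_to_bf16_nc_cuda converter is added next to the existing F16 one, k_compute_batched_ptrs is templated on the element type, and the cuBLAS compute/data types are chosen per input type: F32 uses CUBLAS_COMPUTE_32F with CUDA_R_32F operands and writes dst directly, BF16 uses CUBLAS_COMPUTE_32F with CUDA_R_16BF operands (cuBLAS has no 16-bit BF16 compute type) and converts the result to F32 afterwards, and F16 keeps the existing behaviour. test-backend-ops now runs the permuted/batched mul_mat cases for F16, BF16 and F32 src0.

As an illustrative sketch (not part of the patch), a graph that should reach this path after the change, assuming the usual ggml C API; the sizes are arbitrary and the backend-scheduling boilerplate needed to actually run it on CUDA is omitted:

    // BF16 weights times F32 activations with a batch dimension > 1, so the
    // CUDA backend should pick ggml_cuda_mul_mat_batched_cublas for this node.
    struct ggml_init_params ip = {
        /*.mem_size   =*/ 64*1024*1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(ip);

    struct ggml_tensor * a = ggml_new_tensor_3d(ctx, GGML_TYPE_BF16, 128, 1056, 8); // K x M x batch
    struct ggml_tensor * b = ggml_new_tensor_3d(ctx, GGML_TYPE_F32,  128,   16, 8); // K x N x batch
    struct ggml_tensor * c = ggml_mul_mat(ctx, a, b); // F32 result, ne = {1056, 16, 8}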

---
 ggml/src/ggml-cuda/convert.cu   |  11 ++
 ggml/src/ggml-cuda/convert.cuh  |   2 +
 ggml/src/ggml-cuda/ggml-cuda.cu | 234 +++++++++++++++++++++++++-------
 tests/test-backend-ops.cpp      |   6 +-
 4 files changed, 202 insertions(+), 51 deletions(-)

diff --git a/ggml/src/ggml-cuda/convert.cu b/ggml/src/ggml-cuda/convert.cu
index c6dec4276b36d..a5a53f6239a4a 100644
--- a/ggml/src/ggml-cuda/convert.cu
+++ b/ggml/src/ggml-cuda/convert.cu
@@ -728,3 +728,14 @@ to_fp16_nc_cuda_t ggml_get_to_fp16_nc_cuda(ggml_type type) {
             return nullptr;
     }
 }
+
+to_bf16_nc_cuda_t ggml_get_to_bf16_nc_cuda(ggml_type type) {
+    switch (type) {
+        case GGML_TYPE_F32:
+            return convert_unary_cuda<float, nv_bfloat16>;
+        case GGML_TYPE_F16:
+            return convert_unary_cuda<half, nv_bfloat16>;
+        default:
+            return nullptr;
+    }
+}
diff --git a/ggml/src/ggml-cuda/convert.cuh b/ggml/src/ggml-cuda/convert.cuh
index b65b98e08e7e2..debda3b03dc26 100644
--- a/ggml/src/ggml-cuda/convert.cuh
+++ b/ggml/src/ggml-cuda/convert.cuh
@@ -23,4 +23,6 @@ using to_t_nc_cuda_t = void (*)(const void * x, T * y,
     int64_t s01, int64_t s02, int64_t s03, cudaStream_t stream);
 
 typedef to_t_nc_cuda_t<half> to_fp16_nc_cuda_t;
+typedef to_t_nc_cuda_t<nv_bfloat16> to_bf16_nc_cuda_t;
 to_fp16_nc_cuda_t ggml_get_to_fp16_nc_cuda(ggml_type type);
+to_bf16_nc_cuda_t ggml_get_to_bf16_nc_cuda(ggml_type type);
diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
index b30c13c62f25c..bfac2363c900a 100644
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -1748,8 +1748,9 @@ static void ggml_cuda_op_mul_mat(
     }
 }
 
+template<typename T>
 static __global__ void k_compute_batched_ptrs(
-        const half * src0_as_f16, const half * src1_as_f16, char * dst,
+        const T * src0_as_f16, const T * src1_as_f16, char * dst,
         const void ** ptrs_src, void ** ptrs_dst,
         int64_t ne12, int64_t ne13,
         int64_t ne23,
@@ -1777,7 +1778,7 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
     GGML_ASSERT(!ggml_is_transposed(src1));
 
     GGML_ASSERT(!ggml_backend_buft_is_cuda_split(src0->buffer->buft));
-    GGML_ASSERT(src0->type == GGML_TYPE_F16);
+    GGML_ASSERT(src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16 || src0->type == GGML_TYPE_F32);
 
     // Byte offsets and tensor dimensions are currently used in an inconsistent way for dst.
     // As long as dst is contiguous this does not matter though.
@@ -1791,64 +1792,153 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
 
     CUBLAS_CHECK(cublasSetStream(ctx.cublas_handle(), main_stream));
 
-    const half * src0_f16 = (const half *) src0->data;
-    float * dst_ddf = (float *) dst->data;
+    const ggml_type src0_type = src0->type;
+    const bool use_f32_path = src0_type == GGML_TYPE_F32;
+    const bool use_bf16_path = src0_type == GGML_TYPE_BF16;
 
-    const half * src1_f16 = (const half *) src1->data;
+    float * dst_ddf = (float *) dst->data;
     const size_t ts_src1 = ggml_type_size(src1->type);
     GGML_ASSERT(nb10 == ts_src1);
     int64_t s11 = nb11 / ts_src1;
     int64_t s12 = nb12 / ts_src1;
     int64_t s13 = nb13 / ts_src1;
+
+    const half * src0_f16 = nullptr;
+    const half * src1_f16 = nullptr;
+    const nv_bfloat16 * src0_bf16 = nullptr;
+    const nv_bfloat16 * src1_bf16 = nullptr;
+    const float * src0_f32 = nullptr;
+    const float * src1_f32 = nullptr;
+
+    ggml_cuda_pool_alloc<half> src0_f16_alloc(ctx.pool());
     ggml_cuda_pool_alloc<half> src1_f16_alloc(ctx.pool());
+    ggml_cuda_pool_alloc<nv_bfloat16> src0_bf16_alloc(ctx.pool());
+    ggml_cuda_pool_alloc<nv_bfloat16> src1_bf16_alloc(ctx.pool());
+    ggml_cuda_pool_alloc<float> src0_f32_alloc(ctx.pool());
+    ggml_cuda_pool_alloc<float> src1_f32_alloc(ctx.pool());
+
+    if (use_f32_path) {
+        // F32 path
+        src0_f32 = (const float *) src0->data;
+        if (src1->type == GGML_TYPE_F32) {
+            src1_f32 = (const float *) src1->data;
+        } else {
+            // Convert src1 to F32
+            const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(src1->type);
+            const int64_t ne_src1 = ggml_nelements(src1);
+            src1_f32_alloc.alloc(ne_src1);
+            GGML_ASSERT(to_fp32_cuda != nullptr);
 
-    // convert src1 to fp16
-    if (src1->type != GGML_TYPE_F16) {
-        const to_fp16_nc_cuda_t to_fp16_cuda = ggml_get_to_fp16_nc_cuda(src1->type);
-        const int64_t ne_src1 = ggml_nelements(src1);
-        src1_f16_alloc.alloc(ne_src1);
-        GGML_ASSERT(to_fp16_cuda != nullptr);
+            to_fp32_cuda((const void*)((const char*)src1->data), src1_f32_alloc.get(), ne_src1, main_stream);
+            src1_f32 = src1_f32_alloc.get();
+            s11 = ne10;
+            s12 = ne11*s11;
+            s13 = ne12*s12;
+        }
+    } else if (use_bf16_path) {
+        // BF16 path
+        src0_bf16 = (const nv_bfloat16 *) src0->data;
+        if (src1->type == GGML_TYPE_BF16) {
+            src1_bf16 = (const nv_bfloat16 *) src1->data;
+        } else {
+            // Convert src1 to BF16
+            const to_bf16_nc_cuda_t to_bf16_cuda = ggml_get_to_bf16_nc_cuda(src1->type);
+            const int64_t ne_src1 = ggml_nelements(src1);
+            src1_bf16_alloc.alloc(ne_src1);
+            GGML_ASSERT(to_bf16_cuda != nullptr);
 
-        to_fp16_cuda(src1_f16, src1_f16_alloc.get(), ne10, ne11, ne12, ne13, s11, s12, s13, main_stream);
+            to_bf16_cuda((const void*)((const char*)src1->data), src1_bf16_alloc.get(), ne10, ne11, ne12, ne13, s11, s12, s13, main_stream);
+            src1_bf16 = src1_bf16_alloc.get();
+            s11 = ne10;
+            s12 = ne11*s11;
+            s13 = ne12*s12;
+        }
+    } else {
+        // F16 path (default)
+        src0_f16 = (const half *) src0->data;
+        if (src1->type == GGML_TYPE_F16) {
+            src1_f16 = (const half *) src1->data;
+        } else {
+            // Convert src1 to F16
+            const to_fp16_nc_cuda_t to_fp16_cuda = ggml_get_to_fp16_nc_cuda(src1->type);
+            const int64_t ne_src1 = ggml_nelements(src1);
+            src1_f16_alloc.alloc(ne_src1);
+            GGML_ASSERT(to_fp16_cuda != nullptr);
 
-        src1_f16 = src1_f16_alloc.get();
-        s11 = ne10;
-        s12 = ne11*s11;
-        s13 = ne12*s12;
+            to_fp16_cuda((const void*)((const char*)src1->data), src1_f16_alloc.get(), ne10, ne11, ne12, ne13, s11, s12, s13, main_stream);
+            src1_f16 = src1_f16_alloc.get();
+            s11 = ne10;
+            s12 = ne11*s11;
+            s13 = ne12*s12;
+        }
     }
 
     ggml_cuda_pool_alloc<half> dst_f16(ctx.pool());
+    ggml_cuda_pool_alloc<nv_bfloat16> dst_bf16(ctx.pool());
     char * dst_t;
 
-    cublasComputeType_t cu_compute_type = CUBLAS_COMPUTE_16F;
-    cudaDataType_t      cu_data_type    = CUDA_R_16F;
+    cublasComputeType_t cu_compute_type;
+    cudaDataType_t cu_data_type;
+    cudaDataType_t cu_data_type_a;
+    cudaDataType_t cu_data_type_b;
+
+    if (use_f32_path) {
+        cu_compute_type = CUBLAS_COMPUTE_32F;
+        cu_data_type = CUDA_R_32F;
+        cu_data_type_a = CUDA_R_32F;
+        cu_data_type_b = CUDA_R_32F;
+    } else if (use_bf16_path) {
+        cu_compute_type = CUBLAS_COMPUTE_32F;
+        cu_data_type = CUDA_R_16BF;
+        cu_data_type_a = CUDA_R_16BF;
+        cu_data_type_b = CUDA_R_16BF;
+    } else {
+        cu_compute_type = CUBLAS_COMPUTE_16F;
+        cu_data_type = CUDA_R_16F;
+        cu_data_type_a = CUDA_R_16F;
+        cu_data_type_b = CUDA_R_16F;
+    }
 
-    // dst strides
     size_t nbd2 = dst->nb[2];
     size_t nbd3 = dst->nb[3];
 
     const half  alpha_f16 = 1.0f;
     const half  beta_f16  = 0.0f;
-
     const float alpha_f32 = 1.0f;
     const float beta_f32  = 0.0f;
 
-    const void * alpha = &alpha_f16;
-    const void * beta  = &beta_f16;
+    const void * alpha;
+    const void * beta;
 
-    if (dst->op_params[0] == GGML_PREC_DEFAULT) {
-        dst_t = (char *) dst_f16.alloc(ne_dst);
+    if (use_f32_path || cu_compute_type == CUBLAS_COMPUTE_32F) {
+        alpha = &alpha_f32;
+        beta = &beta_f32;
+    } else if (use_bf16_path) {
+        alpha = &alpha_f32;
+        beta = &beta_f32;
+    } else {
+        alpha = &alpha_f16;
+        beta = &beta_f16;
+    }
 
-        nbd2 /= sizeof(float) / sizeof(half);
-        nbd3 /= sizeof(float) / sizeof(half);
+    if (dst->op_params[0] == GGML_PREC_DEFAULT) {
+        if (use_f32_path) {
+            dst_t = (char *) dst_ddf;  // Direct F32 output
+        } else if (use_bf16_path) {
+            dst_t = (char *) dst_bf16.alloc(ne_dst);
+            nbd2 /= sizeof(float) / sizeof(nv_bfloat16);
+            nbd3 /= sizeof(float) / sizeof(nv_bfloat16);
+        } else {
+            dst_t = (char *) dst_f16.alloc(ne_dst);
+            nbd2 /= sizeof(float) / sizeof(half);
+            nbd3 /= sizeof(float) / sizeof(half);
+        }
     } else {
         dst_t = (char *) dst_ddf;
-
         cu_compute_type = CUBLAS_COMPUTE_32F;
-        cu_data_type    = CUDA_R_32F;
-
+        cu_data_type = CUDA_R_32F;
         alpha = &alpha_f32;
-        beta  = &beta_f32;
+        beta = &beta_f32;
     }
 
     int id = ggml_cuda_get_device();
@@ -1889,11 +1979,16 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
     if (r2 == 1 && r3 == 1 && ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) {
         // there is no broadcast and src0, src1 are contiguous across dims 2, 3
         // use cublasGemmStridedBatchedEx
+        const void * src0_ptr = use_f32_path ? (const void*)src0_f32 :
+                               use_bf16_path ? (const void*)src0_bf16 : (const void*)src0_f16;
+        const void * src1_ptr = use_f32_path ? (const void*)src1_f32 :
+                               use_bf16_path ? (const void*)src1_bf16 : (const void*)src1_f16;
+
         CUBLAS_CHECK(
         cublasGemmStridedBatchedEx(ctx.cublas_handle(), CUBLAS_OP_T, CUBLAS_OP_N,
                 ne01, ne11, ne10,
-                alpha, src0_f16, CUDA_R_16F,   nb01/nb00, nb02/nb00, // strideA
-                       src1_f16, CUDA_R_16F,   s11,       s12,       // strideB
+                alpha, src0_ptr, cu_data_type_a, nb01/nb00, nb02/nb00, // strideA
+                       src1_ptr, cu_data_type_b, s11,       s12,       // strideB
                 beta,     dst_t, cu_data_type, ne0,       ne1*ne0,   // strideC
                 ne12*ne13,
                 cu_compute_type,
@@ -1905,24 +2000,57 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
         ggml_cuda_pool_alloc<const void *> ptrs_src(ctx.pool(), 2*ne23);
         ggml_cuda_pool_alloc<      void *> ptrs_dst(ctx.pool(), 1*ne23);
 
+        const void * src0_ptr = use_f32_path ? (const void*)src0_f32 :
+                               use_bf16_path ? (const void*)src0_bf16 : (const void*)src0_f16;
+        const void * src1_ptr = use_f32_path ? (const void*)src1_f32 :
+                               use_bf16_path ? (const void*)src1_bf16 : (const void*)src1_f16;
+
+        size_t src1_stride_size = use_f32_path ? sizeof(float) :
+                                 use_bf16_path ? sizeof(nv_bfloat16) : sizeof(half);
+
         dim3 block_dims(ne13, ne12);
-        k_compute_batched_ptrs<<<1, block_dims, 0, main_stream>>>(
-                src0_f16, src1_f16, dst_t,
-                ptrs_src.get(), ptrs_dst.get(),
-                ne12, ne13,
-                ne23,
-                nb02, nb03,
-                src1->type == GGML_TYPE_F16 ? nb12 : s12*sizeof(half),
-                src1->type == GGML_TYPE_F16 ? nb13 : s13*sizeof(half),
-                nbd2, nbd3,
-                r2, r3);
+        if( use_f32_path ) {
+            k_compute_batched_ptrs<<<1, block_dims, 0, main_stream>>>(
+                    (const float*)src0_ptr, (const float*)src1_ptr, dst_t,
+                    ptrs_src.get(), ptrs_dst.get(),
+                    ne12, ne13,
+                    ne23,
+                    nb02, nb03,
+                    (src1->type == src0_type) ? nb12 : s12*src1_stride_size,
+                    (src1->type == src0_type) ? nb13 : s13*src1_stride_size,
+                    nbd2, nbd3,
+                    r2, r3);
+        } else if (use_bf16_path) {
+            k_compute_batched_ptrs<<<1, block_dims, 0, main_stream>>>(
+                    (const nv_bfloat16*)src0_ptr, (const nv_bfloat16*)src1_ptr, dst_t,
+                    ptrs_src.get(), ptrs_dst.get(),
+                    ne12, ne13,
+                    ne23,
+                    nb02, nb03,
+                    (src1->type == src0_type) ? nb12 : s12*src1_stride_size,
+                    (src1->type == src0_type) ? nb13 : s13*src1_stride_size,
+                    nbd2, nbd3,
+                    r2, r3);
+        } else {
+            k_compute_batched_ptrs<<<1, block_dims, 0, main_stream>>>(
+                    (const half*)src0_ptr, (const half*)src1_ptr, dst_t,
+                    ptrs_src.get(), ptrs_dst.get(),
+                    ne12, ne13,
+                    ne23,
+                    nb02, nb03,
+                    (src1->type == src0_type) ? nb12 : s12*src1_stride_size,
+                    (src1->type == src0_type) ? nb13 : s13*src1_stride_size,
+                    nbd2, nbd3,
+                    r2, r3);
+        }
+
         CUDA_CHECK(cudaGetLastError());
 
         CUBLAS_CHECK(
         cublasGemmBatchedEx(ctx.cublas_handle(), CUBLAS_OP_T, CUBLAS_OP_N,
                 ne01, ne11, ne10,
-                alpha, (const void **) (ptrs_src.get() + 0*ne23), CUDA_R_16F,   nb01/nb00,
-                       (const void **) (ptrs_src.get() + 1*ne23), CUDA_R_16F,   s11,
+                alpha, (const void **) (ptrs_src.get() + 0*ne23), cu_data_type_a, nb01/nb00,
+                       (const void **) (ptrs_src.get() + 1*ne23), cu_data_type_b, s11,
                 beta,  (      void **) (ptrs_dst.get() + 0*ne23), cu_data_type, ne0,
                 ne23,
                 cu_compute_type,
@@ -1930,9 +2058,16 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
     }
 #endif
 
-    if (dst->op_params[0] == GGML_PREC_DEFAULT && cu_data_type == CUDA_R_16F) {
-        const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
-        to_fp32_cuda(dst_f16.get(), dst_ddf, ne_dst, main_stream);
+    if (dst->op_params[0] == GGML_PREC_DEFAULT) {
+        if (use_f32_path) {
+            //already in f32
+        } else if (use_bf16_path && cu_data_type == CUDA_R_16BF) {
+            const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_BF16);
+            to_fp32_cuda(dst_bf16.get(), dst_ddf, ne_dst, main_stream);
+        } else if (cu_data_type == CUDA_R_16F) {
+            const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
+            to_fp32_cuda(dst_f16.get(), dst_ddf, ne_dst, main_stream);
+        }
     }
 }
 
@@ -1992,8 +2127,9 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
         ggml_cuda_mul_mat_vec_q(ctx, src0, src1, nullptr, dst);
     } else if (!split && use_mul_mat_q) {
         ggml_cuda_mul_mat_q(ctx, src0, src1, nullptr, dst);
-    } else if (!split && src0->type == GGML_TYPE_F16 && (src1->type == GGML_TYPE_F16 || !any_gpus_with_slow_fp16) &&
-            !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
+    } else if (!split && (src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16 || src0->type == GGML_TYPE_F32)
+        && (src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_BF16 || src1->type == GGML_TYPE_F32)
+        && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
         // general KQ + KQV multi-batch without FlashAttention
         ggml_cuda_mul_mat_batched_cublas(ctx, src0, src1, dst);
     } else if (use_mul_mat_vec) {
diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp
index a233f1f2fd97a..a596948b44bb6 100644
--- a/tests/test-backend-ops.cpp
+++ b/tests/test-backend-ops.cpp
@@ -4425,8 +4425,10 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
         for (auto nr : {1,4}) {
             for (uint32_t m = 0; m < 2; ++m) {
                 for (uint32_t k = 0; k < 2; ++k) {
-                    test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 1056 + m, 1, 128 + k,  {bs,  1}, {nr, 1}, {0, 2, 1, 3}));
-                    test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 128 + m,  1, 1056 + k, {bs,  1}, {nr, 1}, {0, 1, 2, 3}, true));
+                    for(ggml_type type: {GGML_TYPE_F16, GGML_TYPE_BF16, GGML_TYPE_F32}){
+                        test_cases.emplace_back(new test_mul_mat(type, GGML_TYPE_F32, 1056 + m, 1, 128 + k,  {bs,  1}, {nr, 1}, {0, 2, 1, 3}));
+                        test_cases.emplace_back(new test_mul_mat(type, GGML_TYPE_F32, 128 + m,  1, 1056 + k, {bs,  1}, {nr, 1}, {0, 1, 2, 3}, true));
+                    }
                 }
             }
         }

From 526fc4e098653f2954a2fad60a352758349a75b9 Mon Sep 17 00:00:00 2001
From: Aman Gupta <amangupta052@gmail.com>
Date: Wed, 25 Jun 2025 15:47:59 +0800
Subject: [PATCH 2/4] Review: add type traits and make function more generic
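
Address review feedback: replace the use_f32_path/use_bf16_path flags and the three near-identical branches with a batched_mul_mat_traits<ggml_type> specialization per supported type. Each specialization supplies the device element type, the cublasComputeType_t and cudaDataType_t, the alpha/beta constants, and the matching non-contiguous converter, so the body becomes a single template instantiated for F32, BF16 and F16 and selected by a switch in ggml_cuda_mul_mat_batched_cublas. A to_fp32_nc_cuda_t converter is added so the F32 path converts non-contiguous src1 the same way as the other types, and the dead #if 0 cublasGemmEx branch is removed.

Schematic restatement of the dispatch (simplified, with details elided and names shortened; not the literal code):

    template <ggml_type T> struct batched_mul_mat_traits;  // specialized for F32, F16, BF16

    template <ggml_type src0_type>
    static void mul_mat_batched_impl(/* ctx, src0, src1, dst */) {
        using traits = batched_mul_mat_traits<src0_type>;
        // convert src1 to traits::cuda_type if needed, then call cublasGemm*Ex with
        // traits::compute_type, traits::data_type and traits::get_alpha()/get_beta()
    }

    static void mul_mat_batched(ggml_type src0_type /* , ... */) {
        switch (src0_type) {
            case GGML_TYPE_F32:  mul_mat_batched_impl<GGML_TYPE_F32>();  break;
            case GGML_TYPE_BF16: mul_mat_batched_impl<GGML_TYPE_BF16>(); break;
            case GGML_TYPE_F16:  mul_mat_batched_impl<GGML_TYPE_F16>();  break;
            default: GGML_ABORT("unsupported type");
        }
    }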

---
 ggml/src/ggml-cuda/convert.cu   |  11 ++
 ggml/src/ggml-cuda/convert.cuh  |   3 +
 ggml/src/ggml-cuda/ggml-cuda.cu | 327 ++++++++++++--------------------
 3 files changed, 134 insertions(+), 207 deletions(-)

diff --git a/ggml/src/ggml-cuda/convert.cu b/ggml/src/ggml-cuda/convert.cu
index a5a53f6239a4a..eeaa14bf57950 100644
--- a/ggml/src/ggml-cuda/convert.cu
+++ b/ggml/src/ggml-cuda/convert.cu
@@ -739,3 +739,14 @@ to_bf16_nc_cuda_t ggml_get_to_bf16_nc_cuda(ggml_type type) {
             return nullptr;
     }
 }
+
+to_fp32_nc_cuda_t ggml_get_to_fp32_nc_cuda(ggml_type type) {
+    switch (type) {
+        case GGML_TYPE_F16:
+            return convert_unary_cuda<half, float>;
+        case GGML_TYPE_BF16:
+            return convert_unary_cuda<nv_bfloat16, float>;
+        default:
+            return nullptr;
+    }
+}
diff --git a/ggml/src/ggml-cuda/convert.cuh b/ggml/src/ggml-cuda/convert.cuh
index debda3b03dc26..f04214be175ba 100644
--- a/ggml/src/ggml-cuda/convert.cuh
+++ b/ggml/src/ggml-cuda/convert.cuh
@@ -22,7 +22,10 @@ using to_t_nc_cuda_t = void (*)(const void * x, T * y,
     int64_t ne00, int64_t ne01, int64_t ne02, int64_t ne03,
     int64_t s01, int64_t s02, int64_t s03, cudaStream_t stream);
 
+typedef to_t_nc_cuda_t<float> to_fp32_nc_cuda_t;
 typedef to_t_nc_cuda_t<half> to_fp16_nc_cuda_t;
 typedef to_t_nc_cuda_t<nv_bfloat16> to_bf16_nc_cuda_t;
+
+to_fp32_nc_cuda_t ggml_get_to_fp32_nc_cuda(ggml_type type);
 to_fp16_nc_cuda_t ggml_get_to_fp16_nc_cuda(ggml_type type);
 to_bf16_nc_cuda_t ggml_get_to_bf16_nc_cuda(ggml_type type);
diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
index bfac2363c900a..cc8b5bf291b15 100644
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -1748,9 +1748,8 @@ static void ggml_cuda_op_mul_mat(
     }
 }
 
-template<typename T>
 static __global__ void k_compute_batched_ptrs(
-        const T * src0_as_f16, const T * src1_as_f16, char * dst,
+        const void * src0_as_f16, const void * src1_as_f16, char * dst,
         const void ** ptrs_src, void ** ptrs_dst,
         int64_t ne12, int64_t ne13,
         int64_t ne23,
@@ -1773,29 +1772,66 @@ static __global__ void k_compute_batched_ptrs(
     ptrs_dst[0*ne23 + i12 + i13*ne12] = (      char *)         dst + i12*nbd2 + i13*nbd3;
 }
 
-static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+// Type traits for CUDA types
+template<ggml_type T>
+struct batched_mul_mat_traits;
+
+template<>
+struct batched_mul_mat_traits<GGML_TYPE_F32> {
+    using cuda_type = float;
+    static inline const cublasComputeType_t compute_type = CUBLAS_COMPUTE_32F;
+    static inline const cudaDataType_t data_type = CUDA_R_32F;
+    static inline const ggml_type ggml_type_val = GGML_TYPE_F32;
+    static inline const float alpha = 1.0f;
+    static inline const float beta = 0.0f;
+    static inline const void* get_alpha() { static const float val = alpha; return &val; }
+    static inline const void* get_beta() { static const float val = beta; return &val; }
+    static inline auto get_nc_converter(ggml_type src_type) { return ggml_get_to_fp32_nc_cuda(src_type); }
+};
+
+template<>
+struct batched_mul_mat_traits<GGML_TYPE_BF16> {
+    using cuda_type = nv_bfloat16;
+    static inline const cublasComputeType_t compute_type = CUBLAS_COMPUTE_32F;
+    static inline const cudaDataType_t data_type = CUDA_R_16BF;
+    static inline const ggml_type ggml_type_val = GGML_TYPE_BF16;
+    static inline const float alpha = 1.0f;
+    static inline const float beta = 0.0f;
+    static inline const void* get_alpha() { static const float val = alpha; return &val; }
+    static inline const void* get_beta() { static const float val = beta; return &val; }
+    static inline auto get_nc_converter(ggml_type src_type) { return ggml_get_to_bf16_nc_cuda(src_type); }
+};
+
+template<>
+struct batched_mul_mat_traits<GGML_TYPE_F16> {
+    using cuda_type = half;
+    static inline const cublasComputeType_t compute_type = CUBLAS_COMPUTE_16F;
+    static inline const cudaDataType_t data_type = CUDA_R_16F;
+    static inline const ggml_type ggml_type_val = GGML_TYPE_F16;
+    static inline const half alpha = 1.0;
+    static inline const half beta = 0.0;
+    static inline const void* get_alpha() { static const half val = alpha; return &val; }
+    static inline const void* get_beta() { static const half val = beta; return &val; }
+    static inline auto get_nc_converter(ggml_type src_type) { return ggml_get_to_fp16_nc_cuda(src_type); }
+};
+
+template<ggml_type src0_type>
+static void ggml_cuda_mul_mat_batched_cublas_impl(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    using traits = batched_mul_mat_traits<src0_type>;
+    using cuda_t = typename traits::cuda_type;
+
     GGML_ASSERT(!ggml_is_transposed(src0));
     GGML_ASSERT(!ggml_is_transposed(src1));
-
     GGML_ASSERT(!ggml_backend_buft_is_cuda_split(src0->buffer->buft));
-    GGML_ASSERT(src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16 || src0->type == GGML_TYPE_F32);
-
-    // Byte offsets and tensor dimensions are currently used in an inconsistent way for dst.
-    // As long as dst is contiguous this does not matter though.
+    GGML_ASSERT(src0->type == src0_type);
     GGML_ASSERT(ggml_is_contiguous(dst));
 
     GGML_TENSOR_BINARY_OP_LOCALS
 
     const int64_t ne_dst = ggml_nelements(dst);
-
     cudaStream_t main_stream = ctx.stream();
-
     CUBLAS_CHECK(cublasSetStream(ctx.cublas_handle(), main_stream));
 
-    const ggml_type src0_type = src0->type;
-    const bool use_f32_path = src0_type == GGML_TYPE_F32;
-    const bool use_bf16_path = src0_type == GGML_TYPE_BF16;
-
     float * dst_ddf = (float *) dst->data;
     const size_t ts_src1 = ggml_type_size(src1->type);
     GGML_ASSERT(nb10 == ts_src1);
@@ -1803,140 +1839,59 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
     int64_t s12 = nb12 / ts_src1;
     int64_t s13 = nb13 / ts_src1;
 
-    const half * src0_f16 = nullptr;
-    const half * src1_f16 = nullptr;
-    const nv_bfloat16 * src0_bf16 = nullptr;
-    const nv_bfloat16 * src1_bf16 = nullptr;
-    const float * src0_f32 = nullptr;
-    const float * src1_f32 = nullptr;
-
-    ggml_cuda_pool_alloc<half> src0_f16_alloc(ctx.pool());
-    ggml_cuda_pool_alloc<half> src1_f16_alloc(ctx.pool());
-    ggml_cuda_pool_alloc<nv_bfloat16> src0_bf16_alloc(ctx.pool());
-    ggml_cuda_pool_alloc<nv_bfloat16> src1_bf16_alloc(ctx.pool());
-    ggml_cuda_pool_alloc<float> src0_f32_alloc(ctx.pool());
-    ggml_cuda_pool_alloc<float> src1_f32_alloc(ctx.pool());
-
-    if (use_f32_path) {
-        // F32 path
-        src0_f32 = (const float *) src0->data;
-        if (src1->type == GGML_TYPE_F32) {
-            src1_f32 = (const float *) src1->data;
-        } else {
-            // Convert src1 to F32
-            const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(src1->type);
-            const int64_t ne_src1 = ggml_nelements(src1);
-            src1_f32_alloc.alloc(ne_src1);
-            GGML_ASSERT(to_fp32_cuda != nullptr);
+    const cuda_t * src0_ptr = nullptr;
+    const cuda_t * src1_ptr = nullptr;
 
-            to_fp32_cuda((const void*)((const char*)src1->data), src1_f32_alloc.get(), ne_src1, main_stream);
-            src1_f32 = src1_f32_alloc.get();
-            s11 = ne10;
-            s12 = ne11*s11;
-            s13 = ne12*s12;
-        }
-    } else if (use_bf16_path) {
-        // BF16 path
-        src0_bf16 = (const nv_bfloat16 *) src0->data;
-        if (src1->type == GGML_TYPE_BF16) {
-            src1_bf16 = (const nv_bfloat16 *) src1->data;
-        } else {
-            // Convert src1 to BF16
-            const to_bf16_nc_cuda_t to_bf16_cuda = ggml_get_to_bf16_nc_cuda(src1->type);
-            const int64_t ne_src1 = ggml_nelements(src1);
-            src1_bf16_alloc.alloc(ne_src1);
-            GGML_ASSERT(to_bf16_cuda != nullptr);
+    ggml_cuda_pool_alloc<cuda_t> src0_alloc(ctx.pool());
+    ggml_cuda_pool_alloc<cuda_t> src1_alloc(ctx.pool());
 
-            to_bf16_cuda((const void*)((const char*)src1->data), src1_bf16_alloc.get(), ne10, ne11, ne12, ne13, s11, s12, s13, main_stream);
-            src1_bf16 = src1_bf16_alloc.get();
-            s11 = ne10;
-            s12 = ne11*s11;
-            s13 = ne12*s12;
-        }
+    // Handle src0
+    src0_ptr = (const cuda_t *) src0->data;
+
+    // Handle src1 - convert if necessary
+    if (src1->type == src0_type) {
+        src1_ptr = (const cuda_t *) src1->data;
     } else {
-        // F16 path (default)
-        src0_f16 = (const half *) src0->data;
-        if (src1->type == GGML_TYPE_F16) {
-            src1_f16 = (const half *) src1->data;
-        } else {
-            // Convert src1 to F16
-            const to_fp16_nc_cuda_t to_fp16_cuda = ggml_get_to_fp16_nc_cuda(src1->type);
-            const int64_t ne_src1 = ggml_nelements(src1);
-            src1_f16_alloc.alloc(ne_src1);
-            GGML_ASSERT(to_fp16_cuda != nullptr);
+        // Convert src1 to target type using traits conversion functions
+        const int64_t ne_src1 = ggml_nelements(src1);
+        src1_alloc.alloc(ne_src1);
 
-            to_fp16_cuda((const void*)((const char*)src1->data), src1_f16_alloc.get(), ne10, ne11, ne12, ne13, s11, s12, s13, main_stream);
-            src1_f16 = src1_f16_alloc.get();
-            s11 = ne10;
-            s12 = ne11*s11;
-            s13 = ne12*s12;
-        }
+        const auto convert_func = traits::get_nc_converter(src1->type);
+        GGML_ASSERT(convert_func != nullptr);
+        convert_func((const void*)((const char*)src1->data), src1_alloc.get(), ne10, ne11, ne12, ne13, s11, s12, s13, main_stream);
+        src1_ptr = src1_alloc.get();
+        s11 = ne10;
+        s12 = ne11*s11;
+        s13 = ne12*s12;
     }
 
-    ggml_cuda_pool_alloc<half> dst_f16(ctx.pool());
-    ggml_cuda_pool_alloc<nv_bfloat16> dst_bf16(ctx.pool());
+    // Setup destination buffer
+    ggml_cuda_pool_alloc<cuda_t> dst_temp(ctx.pool());
     char * dst_t;
-
-    cublasComputeType_t cu_compute_type;
-    cudaDataType_t cu_data_type;
-    cudaDataType_t cu_data_type_a;
-    cudaDataType_t cu_data_type_b;
-
-    if (use_f32_path) {
-        cu_compute_type = CUBLAS_COMPUTE_32F;
-        cu_data_type = CUDA_R_32F;
-        cu_data_type_a = CUDA_R_32F;
-        cu_data_type_b = CUDA_R_32F;
-    } else if (use_bf16_path) {
-        cu_compute_type = CUBLAS_COMPUTE_32F;
-        cu_data_type = CUDA_R_16BF;
-        cu_data_type_a = CUDA_R_16BF;
-        cu_data_type_b = CUDA_R_16BF;
-    } else {
-        cu_compute_type = CUBLAS_COMPUTE_16F;
-        cu_data_type = CUDA_R_16F;
-        cu_data_type_a = CUDA_R_16F;
-        cu_data_type_b = CUDA_R_16F;
-    }
-
     size_t nbd2 = dst->nb[2];
     size_t nbd3 = dst->nb[3];
 
-    const half  alpha_f16 = 1.0f;
-    const half  beta_f16  = 0.0f;
-    const float alpha_f32 = 1.0f;
-    const float beta_f32  = 0.0f;
-
-    const void * alpha;
-    const void * beta;
-
-    if (use_f32_path || cu_compute_type == CUBLAS_COMPUTE_32F) {
-        alpha = &alpha_f32;
-        beta = &beta_f32;
-    } else if (use_bf16_path) {
-        alpha = &alpha_f32;
-        beta = &beta_f32;
-    } else {
-        alpha = &alpha_f16;
-        beta = &beta_f16;
-    }
+    cublasComputeType_t cu_compute_type = traits::compute_type;
+    cudaDataType_t cu_data_type = traits::data_type;
+    cudaDataType_t cu_data_type_a = traits::data_type;
+    cudaDataType_t cu_data_type_b = traits::data_type;
+    const void * alpha = traits::get_alpha();
+    const void * beta = traits::get_beta();
 
     if (dst->op_params[0] == GGML_PREC_DEFAULT) {
-        if (use_f32_path) {
+        if constexpr (src0_type == GGML_TYPE_F32) {
             dst_t = (char *) dst_ddf;  // Direct F32 output
-        } else if (use_bf16_path) {
-            dst_t = (char *) dst_bf16.alloc(ne_dst);
-            nbd2 /= sizeof(float) / sizeof(nv_bfloat16);
-            nbd3 /= sizeof(float) / sizeof(nv_bfloat16);
         } else {
-            dst_t = (char *) dst_f16.alloc(ne_dst);
-            nbd2 /= sizeof(float) / sizeof(half);
-            nbd3 /= sizeof(float) / sizeof(half);
+            dst_t = (char *) dst_temp.alloc(ne_dst);
+            nbd2 /= sizeof(float) / sizeof(cuda_t);
+            nbd3 /= sizeof(float) / sizeof(cuda_t);
         }
     } else {
         dst_t = (char *) dst_ddf;
         cu_compute_type = CUBLAS_COMPUTE_32F;
         cu_data_type = CUDA_R_32F;
+        const float alpha_f32 = 1.0f;
+        const float beta_f32 = 0.0f;
         alpha = &alpha_f32;
         beta = &beta_f32;
     }
@@ -1945,8 +1900,10 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
     const int cc = ggml_cuda_info().devices[id].cc;
     if (GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA4(cc)) {
         cu_compute_type = CUBLAS_COMPUTE_32F;
+        const float alpha_f32 = 1.0f;
+        const float beta_f32 = 0.0f;
         alpha = &alpha_f32;
-        beta  = &beta_f32;
+        beta = &beta_f32;
     }
 
     GGML_ASSERT(ne12 % ne02 == 0);
@@ -1956,34 +1913,9 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
     const int64_t r2 = ne12/ne02;
     const int64_t r3 = ne13/ne03;
 
-#if 0
-    // use cublasGemmEx
-    {
-        for (int i13 = 0; i13 < ne13; ++i13) {
-            for (int i12 = 0; i12 < ne12; ++i12) {
-                int i03 = i13 / r3;
-                int i02 = i12 / r2;
-
-                CUBLAS_CHECK(
-                cublasGemmEx(ctx.cublas_handle(), CUBLAS_OP_T, CUBLAS_OP_N,
-                    ne01, ne11, ne10,
-                    alpha, (const char *) src0_f16 + i03*nb03 + i02*nb02, CUDA_R_16F,   nb01/sizeof(half),
-                                          src1_f16 + i13*s13  + i12*s12,  CUDA_R_16F,   s11,
-                    beta,  (      char *)    dst_t + i13*nbd3 + i12*nbd2, cu_data_type, ne0,
-                    cu_compute_type,
-                    CUBLAS_GEMM_DEFAULT_TENSOR_OP));
-            }
-        }
-    }
-#else
     if (r2 == 1 && r3 == 1 && ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) {
         // there is no broadcast and src0, src1 are contiguous across dims 2, 3
         // use cublasGemmStridedBatchedEx
-        const void * src0_ptr = use_f32_path ? (const void*)src0_f32 :
-                               use_bf16_path ? (const void*)src0_bf16 : (const void*)src0_f16;
-        const void * src1_ptr = use_f32_path ? (const void*)src1_f32 :
-                               use_bf16_path ? (const void*)src1_bf16 : (const void*)src1_f16;
-
         CUBLAS_CHECK(
         cublasGemmStridedBatchedEx(ctx.cublas_handle(), CUBLAS_OP_T, CUBLAS_OP_N,
                 ne01, ne11, ne10,
@@ -2000,49 +1932,19 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
         ggml_cuda_pool_alloc<const void *> ptrs_src(ctx.pool(), 2*ne23);
         ggml_cuda_pool_alloc<      void *> ptrs_dst(ctx.pool(), 1*ne23);
 
-        const void * src0_ptr = use_f32_path ? (const void*)src0_f32 :
-                               use_bf16_path ? (const void*)src0_bf16 : (const void*)src0_f16;
-        const void * src1_ptr = use_f32_path ? (const void*)src1_f32 :
-                               use_bf16_path ? (const void*)src1_bf16 : (const void*)src1_f16;
-
-        size_t src1_stride_size = use_f32_path ? sizeof(float) :
-                                 use_bf16_path ? sizeof(nv_bfloat16) : sizeof(half);
+        size_t src1_stride_size = sizeof(cuda_t);
 
         dim3 block_dims(ne13, ne12);
-        if( use_f32_path ) {
-            k_compute_batched_ptrs<<<1, block_dims, 0, main_stream>>>(
-                    (const float*)src0_ptr, (const float*)src1_ptr, dst_t,
-                    ptrs_src.get(), ptrs_dst.get(),
-                    ne12, ne13,
-                    ne23,
-                    nb02, nb03,
-                    (src1->type == src0_type) ? nb12 : s12*src1_stride_size,
-                    (src1->type == src0_type) ? nb13 : s13*src1_stride_size,
-                    nbd2, nbd3,
-                    r2, r3);
-        } else if (use_bf16_path) {
-            k_compute_batched_ptrs<<<1, block_dims, 0, main_stream>>>(
-                    (const nv_bfloat16*)src0_ptr, (const nv_bfloat16*)src1_ptr, dst_t,
-                    ptrs_src.get(), ptrs_dst.get(),
-                    ne12, ne13,
-                    ne23,
-                    nb02, nb03,
-                    (src1->type == src0_type) ? nb12 : s12*src1_stride_size,
-                    (src1->type == src0_type) ? nb13 : s13*src1_stride_size,
-                    nbd2, nbd3,
-                    r2, r3);
-        } else {
-            k_compute_batched_ptrs<<<1, block_dims, 0, main_stream>>>(
-                    (const half*)src0_ptr, (const half*)src1_ptr, dst_t,
-                    ptrs_src.get(), ptrs_dst.get(),
-                    ne12, ne13,
-                    ne23,
-                    nb02, nb03,
-                    (src1->type == src0_type) ? nb12 : s12*src1_stride_size,
-                    (src1->type == src0_type) ? nb13 : s13*src1_stride_size,
-                    nbd2, nbd3,
-                    r2, r3);
-        }
+        k_compute_batched_ptrs<<<1, block_dims, 0, main_stream>>>(
+                src0_ptr, src1_ptr, dst_t,
+                ptrs_src.get(), ptrs_dst.get(),
+                ne12, ne13,
+                ne23,
+                nb02, nb03,
+                (src1->type == src0_type) ? nb12 : s12*src1_stride_size,
+                (src1->type == src0_type) ? nb13 : s13*src1_stride_size,
+                nbd2, nbd3,
+                r2, r3);
 
         CUDA_CHECK(cudaGetLastError());
 
@@ -2056,18 +1958,29 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
                 cu_compute_type,
                 CUBLAS_GEMM_DEFAULT_TENSOR_OP));
     }
-#endif
 
-    if (dst->op_params[0] == GGML_PREC_DEFAULT) {
-        if (use_f32_path) {
-            //already in f32
-        } else if (use_bf16_path && cu_data_type == CUDA_R_16BF) {
-            const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_BF16);
-            to_fp32_cuda(dst_bf16.get(), dst_ddf, ne_dst, main_stream);
-        } else if (cu_data_type == CUDA_R_16F) {
-            const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
-            to_fp32_cuda(dst_f16.get(), dst_ddf, ne_dst, main_stream);
-        }
+    // Convert output back to F32 if needed
+    if (dst->op_params[0] == GGML_PREC_DEFAULT && cu_data_type != CUDA_R_32F) {
+        const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(traits::ggml_type_val);
+        to_fp32_cuda(dst_temp.get(), dst_ddf, ne_dst, main_stream);
+    }
+}
+
+static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    GGML_ASSERT(src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16 || src0->type == GGML_TYPE_F32);
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            ggml_cuda_mul_mat_batched_cublas_impl<GGML_TYPE_F32>(ctx, src0, src1, dst);
+            break;
+        case GGML_TYPE_BF16:
+            ggml_cuda_mul_mat_batched_cublas_impl<GGML_TYPE_BF16>(ctx, src0, src1, dst);
+            break;
+        case GGML_TYPE_F16:
+            ggml_cuda_mul_mat_batched_cublas_impl<GGML_TYPE_F16>(ctx, src0, src1, dst);
+            break;
+        default:
+            GGML_ABORT("Unsupported type");
     }
 }
 

From c02cd2fccd73847532c2c2f20eb2f100bf5729dc Mon Sep 17 00:00:00 2001
From: Aman Gupta <amangupta052@gmail.com>
Date: Wed, 25 Jun 2025 18:13:07 +0800
Subject: [PATCH 3/4] Review: make check more explicit, add back comments, and
 fix formatting
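
More review feedback: restore the comment about dst byte offsets, hoist the alpha_f32/beta_f32 constants so both the GGML_PREC_F32 branch and the CDNA/RDNA4 override can point at them, make the batched-cuBLAS eligibility check explicit per type (F16 as before with the slow-fp16 check, BF16 only when bf16_mma_hardware_available(cc), F32 unconditionally), and fix the for-loop formatting in test-backend-ops.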

---
 ggml/src/ggml-cuda/ggml-cuda.cu | 19 ++++++++++++-------
 tests/test-backend-ops.cpp      |  2 +-
 2 files changed, 13 insertions(+), 8 deletions(-)

diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
index cc8b5bf291b15..ce0551434f35a 100644
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -1772,7 +1772,7 @@ static __global__ void k_compute_batched_ptrs(
     ptrs_dst[0*ne23 + i12 + i13*ne12] = (      char *)         dst + i12*nbd2 + i13*nbd3;
 }
 
-// Type traits for CUDA types
+// Type traits for mapping ggml types to CUDA/cuBLAS types
 template<ggml_type T>
 struct batched_mul_mat_traits;
 
@@ -1826,6 +1826,9 @@ static void ggml_cuda_mul_mat_batched_cublas_impl(ggml_backend_cuda_context & ct
     GGML_ASSERT(src0->type == src0_type);
     GGML_ASSERT(ggml_is_contiguous(dst));
 
+    // Byte offsets and tensor dimensions are currently used in an inconsistent way for dst.
+    // As long as dst is contiguous this does not matter though.
+
     GGML_TENSOR_BINARY_OP_LOCALS
 
     const int64_t ne_dst = ggml_nelements(dst);
@@ -1877,6 +1880,8 @@ static void ggml_cuda_mul_mat_batched_cublas_impl(ggml_backend_cuda_context & ct
     cudaDataType_t cu_data_type_b = traits::data_type;
     const void * alpha = traits::get_alpha();
     const void * beta = traits::get_beta();
+    const float alpha_f32 = 1.0f;
+    const float beta_f32 = 0.0f;
 
     if (dst->op_params[0] == GGML_PREC_DEFAULT) {
         if constexpr (src0_type == GGML_TYPE_F32) {
@@ -1890,8 +1895,6 @@ static void ggml_cuda_mul_mat_batched_cublas_impl(ggml_backend_cuda_context & ct
         dst_t = (char *) dst_ddf;
         cu_compute_type = CUBLAS_COMPUTE_32F;
         cu_data_type = CUDA_R_32F;
-        const float alpha_f32 = 1.0f;
-        const float beta_f32 = 0.0f;
         alpha = &alpha_f32;
         beta = &beta_f32;
     }
@@ -1900,8 +1903,6 @@ static void ggml_cuda_mul_mat_batched_cublas_impl(ggml_backend_cuda_context & ct
     const int cc = ggml_cuda_info().devices[id].cc;
     if (GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA4(cc)) {
         cu_compute_type = CUBLAS_COMPUTE_32F;
-        const float alpha_f32 = 1.0f;
-        const float beta_f32 = 0.0f;
         alpha = &alpha_f32;
         beta = &beta_f32;
     }
@@ -2032,6 +2033,11 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
     //printf("src0 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src0), ggml_is_transposed(src0), ggml_type_name(src0->type), src0->name);
     //printf("src1 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src1), ggml_is_transposed(src1), ggml_type_name(src1->type), src1->name);
 
+    const int cc                     = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
+    bool can_use_batched_cublas_f16  = src0->type == GGML_TYPE_F16 && (src1->type == GGML_TYPE_F16 || !any_gpus_with_slow_fp16);
+    bool can_use_batched_cublas_bf16 = src0->type == GGML_TYPE_BF16 && bf16_mma_hardware_available(cc);
+    bool can_use_batched_cublas_f32  = src0->type == GGML_TYPE_F32;
+
     if (!split && use_mul_mat_vec) {
         // the custom F16 vector kernel can be used over batched cuBLAS GEMM
         // but this is only faster for GPUs without tensor cores or with a thin src0 matrix (particularly KQV in attention)
@@ -2040,8 +2046,7 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
         ggml_cuda_mul_mat_vec_q(ctx, src0, src1, nullptr, dst);
     } else if (!split && use_mul_mat_q) {
         ggml_cuda_mul_mat_q(ctx, src0, src1, nullptr, dst);
-    } else if (!split && (src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16 || src0->type == GGML_TYPE_F32)
-        && (src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_BF16 || src1->type == GGML_TYPE_F32)
+    } else if (!split && (can_use_batched_cublas_f16 || can_use_batched_cublas_bf16 || can_use_batched_cublas_f32)
         && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
         // general KQ + KQV multi-batch without FlashAttention
         ggml_cuda_mul_mat_batched_cublas(ctx, src0, src1, dst);
diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp
index a596948b44bb6..128d63988f4e6 100644
--- a/tests/test-backend-ops.cpp
+++ b/tests/test-backend-ops.cpp
@@ -4425,7 +4425,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
         for (auto nr : {1,4}) {
             for (uint32_t m = 0; m < 2; ++m) {
                 for (uint32_t k = 0; k < 2; ++k) {
-                    for(ggml_type type: {GGML_TYPE_F16, GGML_TYPE_BF16, GGML_TYPE_F32}){
+                    for (ggml_type type: {GGML_TYPE_F16, GGML_TYPE_BF16, GGML_TYPE_F32}) {
                         test_cases.emplace_back(new test_mul_mat(type, GGML_TYPE_F32, 1056 + m, 1, 128 + k,  {bs,  1}, {nr, 1}, {0, 2, 1, 3}));
                         test_cases.emplace_back(new test_mul_mat(type, GGML_TYPE_F32, 128 + m,  1, 1056 + k, {bs,  1}, {nr, 1}, {0, 1, 2, 3}, true));
                     }

From 2c4e42edff637bcd2771c9dea4badf923f6b4d44 Mon Sep 17 00:00:00 2001
From: Aman Gupta <amangupta052@gmail.com>
Date: Wed, 25 Jun 2025 20:00:29 +0800
Subject: [PATCH 4/4] Review: fix formatting, remove useless type conversion,
 fix naming for bools
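
Final review pass: drop the redundant casts when calling the converter (the function already takes a const void *), align the cu_data_type arguments in the two cuBLAS calls, rename the can_use_batched_cublas_* flags to use_batched_cublas_* to match the surrounding use_mul_mat_* naming, and add a TODO noting that the device-capability check still needs updating for generic tensor parallelism.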

---
 ggml/src/ggml-cuda/ggml-cuda.cu | 15 ++++++++-------
 1 file changed, 8 insertions(+), 7 deletions(-)

diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
index ce0551434f35a..811422f385073 100644
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -1861,7 +1861,7 @@ static void ggml_cuda_mul_mat_batched_cublas_impl(ggml_backend_cuda_context & ct
 
         const auto convert_func = traits::get_nc_converter(src1->type);
         GGML_ASSERT(convert_func != nullptr);
-        convert_func((const void*)((const char*)src1->data), src1_alloc.get(), ne10, ne11, ne12, ne13, s11, s12, s13, main_stream);
+        convert_func(src1->data, src1_alloc.get(), ne10, ne11, ne12, ne13, s11, s12, s13, main_stream);
         src1_ptr = src1_alloc.get();
         s11 = ne10;
         s12 = ne11*s11;
@@ -1922,7 +1922,7 @@ static void ggml_cuda_mul_mat_batched_cublas_impl(ggml_backend_cuda_context & ct
                 ne01, ne11, ne10,
                 alpha, src0_ptr, cu_data_type_a, nb01/nb00, nb02/nb00, // strideA
                        src1_ptr, cu_data_type_b, s11,       s12,       // strideB
-                beta,     dst_t, cu_data_type, ne0,       ne1*ne0,   // strideC
+                beta,     dst_t, cu_data_type,   ne0,       ne1*ne0,   // strideC
                 ne12*ne13,
                 cu_compute_type,
                 CUBLAS_GEMM_DEFAULT_TENSOR_OP));
@@ -1954,7 +1954,7 @@ static void ggml_cuda_mul_mat_batched_cublas_impl(ggml_backend_cuda_context & ct
                 ne01, ne11, ne10,
                 alpha, (const void **) (ptrs_src.get() + 0*ne23), cu_data_type_a, nb01/nb00,
                        (const void **) (ptrs_src.get() + 1*ne23), cu_data_type_b, s11,
-                beta,  (      void **) (ptrs_dst.get() + 0*ne23), cu_data_type, ne0,
+                beta,  (      void **) (ptrs_dst.get() + 0*ne23), cu_data_type,   ne0,
                 ne23,
                 cu_compute_type,
                 CUBLAS_GEMM_DEFAULT_TENSOR_OP));
@@ -2033,10 +2033,11 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
     //printf("src0 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src0), ggml_is_transposed(src0), ggml_type_name(src0->type), src0->name);
     //printf("src1 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src1), ggml_is_transposed(src1), ggml_type_name(src1->type), src1->name);
 
+    //TODO update for generic tensor parallelism
     const int cc                     = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
-    bool can_use_batched_cublas_f16  = src0->type == GGML_TYPE_F16 && (src1->type == GGML_TYPE_F16 || !any_gpus_with_slow_fp16);
-    bool can_use_batched_cublas_bf16 = src0->type == GGML_TYPE_BF16 && bf16_mma_hardware_available(cc);
-    bool can_use_batched_cublas_f32  = src0->type == GGML_TYPE_F32;
+    bool use_batched_cublas_f16  = src0->type == GGML_TYPE_F16 && (src1->type == GGML_TYPE_F16 || !any_gpus_with_slow_fp16);
+    bool use_batched_cublas_bf16 = src0->type == GGML_TYPE_BF16 && bf16_mma_hardware_available(cc);
+    bool use_batched_cublas_f32  = src0->type == GGML_TYPE_F32;
 
     if (!split && use_mul_mat_vec) {
         // the custom F16 vector kernel can be used over batched cuBLAS GEMM
@@ -2046,7 +2047,7 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
         ggml_cuda_mul_mat_vec_q(ctx, src0, src1, nullptr, dst);
     } else if (!split && use_mul_mat_q) {
         ggml_cuda_mul_mat_q(ctx, src0, src1, nullptr, dst);
-    } else if (!split && (can_use_batched_cublas_f16 || can_use_batched_cublas_bf16 || can_use_batched_cublas_f32)
+    } else if (!split && (use_batched_cublas_f16 || use_batched_cublas_bf16 || use_batched_cublas_f32)
         && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
         // general KQ + KQV multi-batch without FlashAttention
         ggml_cuda_mul_mat_batched_cublas(ctx, src0, src1, dst);