@@ -1769,7 +1769,7 @@ static __global__ void k_compute_batched_ptrs(
1769
1769
ptrs_dst[0 *ne23 + i12 + i13*ne12] = ( char *) dst + i12*nbd2 + i13*nbd3;
1770
1770
}
1771
1771
1772
- // Type traits for CUDA types
1772
+ // Type traits for mapping ggml types to CUDA/cuBLAS types
1773
1773
template <ggml_type T>
1774
1774
struct batched_mul_mat_traits ;
1775
1775
@@ -1823,6 +1823,9 @@ static void ggml_cuda_mul_mat_batched_cublas_impl(ggml_backend_cuda_context & ct
1823
1823
GGML_ASSERT (src0->type == src0_type);
1824
1824
GGML_ASSERT (ggml_is_contiguous (dst));
1825
1825
1826
+ // Byte offsets and tensor dimensions are currently used in an inconsistent way for dst.
1827
+ // As long as dst is contiguous this does not matter though.
1828
+
1826
1829
GGML_TENSOR_BINARY_OP_LOCALS
1827
1830
1828
1831
const int64_t ne_dst = ggml_nelements (dst);
@@ -1874,6 +1877,8 @@ static void ggml_cuda_mul_mat_batched_cublas_impl(ggml_backend_cuda_context & ct
1874
1877
cudaDataType_t cu_data_type_b = traits::data_type;
1875
1878
const void * alpha = traits::get_alpha ();
1876
1879
const void * beta = traits::get_beta ();
1880
+ const float alpha_f32 = 1 .0f ;
1881
+ const float beta_f32 = 0 .0f ;
1877
1882
1878
1883
if (dst->op_params [0 ] == GGML_PREC_DEFAULT) {
1879
1884
if constexpr (src0_type == GGML_TYPE_F32) {
@@ -1887,8 +1892,6 @@ static void ggml_cuda_mul_mat_batched_cublas_impl(ggml_backend_cuda_context & ct
1887
1892
dst_t = (char *) dst_ddf;
1888
1893
cu_compute_type = CUBLAS_COMPUTE_32F;
1889
1894
cu_data_type = CUDA_R_32F;
1890
- const float alpha_f32 = 1 .0f ;
1891
- const float beta_f32 = 0 .0f ;
1892
1895
alpha = &alpha_f32;
1893
1896
beta = &beta_f32;
1894
1897
}
@@ -1897,8 +1900,6 @@ static void ggml_cuda_mul_mat_batched_cublas_impl(ggml_backend_cuda_context & ct
1897
1900
const int cc = ggml_cuda_info ().devices [id].cc ;
1898
1901
if (GGML_CUDA_CC_IS_CDNA (cc) || GGML_CUDA_CC_IS_RDNA4 (cc)) {
1899
1902
cu_compute_type = CUBLAS_COMPUTE_32F;
1900
- const float alpha_f32 = 1 .0f ;
1901
- const float beta_f32 = 0 .0f ;
1902
1903
alpha = &alpha_f32;
1903
1904
beta = &beta_f32;
1904
1905
}
@@ -2029,6 +2030,11 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
2029
2030
// printf("src0 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src0), ggml_is_transposed(src0), ggml_type_name(src0->type), src0->name);
2030
2031
// printf("src1 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src1), ggml_is_transposed(src1), ggml_type_name(src1->type), src1->name);
2031
2032
2033
+ const int cc = ggml_cuda_info ().devices [ggml_cuda_get_device ()].cc ;
2034
+ bool can_use_batched_cublas_f16 = src0->type == GGML_TYPE_F16 && (src1->type == GGML_TYPE_F16 || !any_gpus_with_slow_fp16);
2035
+ bool can_use_batched_cublas_bf16 = src0->type == GGML_TYPE_BF16 && bf16_mma_hardware_available (cc);
2036
+ bool can_use_batched_cublas_f32 = src0->type == GGML_TYPE_F32;
2037
+
2032
2038
if (!split && use_mul_mat_vec) {
2033
2039
// the custom F16 vector kernel can be used over batched cuBLAS GEMM
2034
2040
// but this is only faster for GPUs without tensor cores or with a thin src0 matrix (particularly KQV in attention)
@@ -2037,8 +2043,7 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
2037
2043
ggml_cuda_mul_mat_vec_q (ctx, src0, src1, nullptr , dst);
2038
2044
} else if (!split && use_mul_mat_q) {
2039
2045
ggml_cuda_mul_mat_q (ctx, src0, src1, nullptr , dst);
2040
- } else if (!split && (src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16 || src0->type == GGML_TYPE_F32)
2041
- && (src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_BF16 || src1->type == GGML_TYPE_F32)
2046
+ } else if (!split && (can_use_batched_cublas_f16 || can_use_batched_cublas_bf16 || can_use_batched_cublas_f32)
2042
2047
&& !ggml_is_transposed (src0) && !ggml_is_transposed (src1) && src1->ne [2 ]*src1->ne [3 ] > 1 ) {
2043
2048
// general KQ + KQV multi-batch without FlashAttention
2044
2049
ggml_cuda_mul_mat_batched_cublas (ctx, src0, src1, dst);
0 commit comments