Skip to content

metal : add special-case mat-vec mul for ne00 == 4 #14385

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 26, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 22 additions & 3 deletions ggml/src/ggml-metal/ggml-metal.m
Original file line number Diff line number Diff line change
Expand Up @@ -211,11 +211,14 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
GGML_METAL_KERNEL_TYPE_RWKV_WKV6_F32,
GGML_METAL_KERNEL_TYPE_RWKV_WKV7_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32_C4,
GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_C4,
GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW,
GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4,
GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16,
GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_C4,
GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW,
GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4,
GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_BF16,
Expand Down Expand Up @@ -1175,11 +1178,14 @@ @implementation GGMLMetalClass
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RWKV_WKV6_F32, rwkv_wkv6_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RWKV_WKV7_F32, rwkv_wkv7_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32_C4, mul_mv_f32_f32_c4, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32, mul_mv_bf16_f32, has_simdgroup_reduction && use_bfloat);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_C4, mul_mv_bf16_f32_c4, use_bfloat);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW, mul_mv_bf16_f32_1row, has_simdgroup_reduction && use_bfloat);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4, mul_mv_bf16_f32_l4, has_simdgroup_reduction && use_bfloat);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_BF16, mul_mv_bf16_bf16, has_simdgroup_reduction && use_bfloat);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32, mul_mv_f16_f32, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_C4, mul_mv_f16_f32_c4, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW, mul_mv_f16_f32_1row, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4, mul_mv_f16_f32_l4, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16, mul_mv_f16_f16, has_simdgroup_reduction);
Expand Down Expand Up @@ -3110,14 +3116,23 @@ static bool ggml_metal_encode_node(
nsg = 1;
nr0 = 1;
nr1 = 4;
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32].pipeline;
if (ne00 == 4) {
nr0 = 32;
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32_C4].pipeline;
} else {
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32].pipeline;
}
} break;
case GGML_TYPE_F16:
{
nsg = 1;
nr0 = 1;
if (src1t == GGML_TYPE_F32) {
if (ne11 * ne12 < 4) {
if (ne00 == 4) {
nr0 = 32;
nr1 = 4;
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_C4].pipeline;
} else if (ne11 * ne12 < 4) {
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW].pipeline;
} else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4].pipeline;
Expand All @@ -3136,7 +3151,11 @@ static bool ggml_metal_encode_node(
nsg = 1;
nr0 = 1;
if (src1t == GGML_TYPE_F32) {
if (ne11 * ne12 < 4) {
if (ne00 == 4) {
nr0 = 32;
nr1 = 4;
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_C4].pipeline;
} else if (ne11 * ne12 < 4) {
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW].pipeline;
} else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4].pipeline;
Expand Down
64 changes: 64 additions & 0 deletions ggml/src/ggml-metal/ggml-metal.metal
Original file line number Diff line number Diff line change
Expand Up @@ -2532,6 +2532,70 @@ template [[host_name("kernel_mul_mv_bf16_f32")]] kernel mul_mv_t kernel_mul_mv<
template [[host_name("kernel_mul_mv_bf16_bf16")]] kernel mul_mv_t kernel_mul_mv<bfloat, bfloat4, bfloat, bfloat4>;
#endif

template<typename T04, typename T14, typename args_t>
void kernel_mul_mv_c4_impl(
args_t args,
device const char * src0,
device const char * src1,
device char * dst,
uint3 tgpig,
ushort tiisg) {
const int r0 = tgpig.x*32 + tiisg;
const int rb = tgpig.y*N_MV_T_T;
const int im = tgpig.z;

if (r0 >= args.ne01) {
return;
}

const uint i12 = im%args.ne12;
const uint i13 = im/args.ne12;

const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;

device const T04 * x = (device const T04 *) (src0 + offset0);

device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1;

for (int row = 0; row < N_MV_T_T; ++row) {
int r1 = rb + row;
if (r1 >= args.ne11) {
break;
}

const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;

device const T14 * y = (device const T14 *) (src1 + offset1);

dst_f32[(uint64_t)r1*args.ne0 + r0] = dot((float4) x[0], (float4) y[0]);
}
}

template<typename T04, typename T14>
kernel void kernel_mul_mv_c4(
constant ggml_metal_kargs_mul_mv & args,
device const char * src0,
device const char * src1,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
ushort tiisg[[thread_index_in_simdgroup]]) {
kernel_mul_mv_c4_impl<T04, T14, constant ggml_metal_kargs_mul_mv &>(
args,
src0,
src1,
dst,
tgpig,
tiisg);
}

typedef decltype(kernel_mul_mv_c4<half4, half4>) mul_mv_c4_t;

template [[host_name("kernel_mul_mv_f32_f32_c4")]] kernel mul_mv_c4_t kernel_mul_mv_c4<float4, float4>;
template [[host_name("kernel_mul_mv_f16_f32_c4")]] kernel mul_mv_c4_t kernel_mul_mv_c4<half4, float4>;
#if defined(GGML_METAL_USE_BF16)
template [[host_name("kernel_mul_mv_bf16_f32_c4")]] kernel mul_mv_c4_t kernel_mul_mv_c4<bfloat4, float4>;
#endif

template<typename T, typename T4>
kernel void kernel_mul_mv_1row(
constant ggml_metal_kargs_mul_mv & args,
Expand Down
72 changes: 39 additions & 33 deletions tests/test-backend-ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4252,39 +4252,45 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
#if 1
for (ggml_type type_a : base_types) {
for (ggml_type type_b : {GGML_TYPE_F32, GGML_TYPE_F16}) {
// test cases without permutation
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {1, 1}, {1, 1}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {1, 1}, {2, 1}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {1, 1}, {1, 2}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {3, 1}, {1, 1}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {3, 1}, {2, 1}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {3, 2}, {1, 1}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {3, 2}, {2, 1}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {3, 2}, {1, 2}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {3, 2}, {2, 2}));

test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {1, 1}, {1, 1}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {1, 1}, {2, 1}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {1, 1}, {1, 2}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {3, 1}, {1, 1}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {3, 1}, {2, 1}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {3, 2}, {1, 1}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {3, 2}, {2, 1}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {3, 2}, {1, 2}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {3, 2}, {2, 2}));

// test cases with permutation
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {2, 3}, {1, 1}, {0, 2, 1, 3}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {2, 3}, {1, 1}, {0, 1, 3, 2}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {2, 3}, {1, 1}, {0, 3, 2, 1}));

test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 8, 256, {2, 3}, {1, 1}, {0, 2, 1, 3}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 8, 256, {2, 3}, {1, 1}, {0, 1, 3, 2}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 8, 256, {2, 3}, {1, 1}, {0, 3, 2, 1}));

test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {2, 3}, {1, 1}, {0, 2, 1, 3}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {2, 3}, {1, 1}, {0, 1, 3, 2}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {2, 3}, {1, 1}, {0, 3, 2, 1}));
std::vector<int> ks = { 256 };
if (ggml_blck_size(type_a) == 1) {
ks.push_back(4);
}
for (auto k : ks) {
// test cases without permutation
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, k, {1, 1}, {1, 1}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, k, {1, 1}, {2, 1}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, k, {1, 1}, {1, 2}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, k, {3, 1}, {1, 1}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, k, {3, 1}, {2, 1}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, k, {3, 2}, {1, 1}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, k, {3, 2}, {2, 1}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, k, {3, 2}, {1, 2}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, k, {3, 2}, {2, 2}));

test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, k, {1, 1}, {1, 1}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, k, {1, 1}, {2, 1}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, k, {1, 1}, {1, 2}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, k, {3, 1}, {1, 1}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, k, {3, 1}, {2, 1}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, k, {3, 2}, {1, 1}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, k, {3, 2}, {2, 1}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, k, {3, 2}, {1, 2}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, k, {3, 2}, {2, 2}));

// test cases with permutation
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, k, {2, 3}, {1, 1}, {0, 2, 1, 3}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, k, {2, 3}, {1, 1}, {0, 1, 3, 2}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, k, {2, 3}, {1, 1}, {0, 3, 2, 1}));

test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 8, k, {2, 3}, {1, 1}, {0, 2, 1, 3}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 8, k, {2, 3}, {1, 1}, {0, 1, 3, 2}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 8, k, {2, 3}, {1, 1}, {0, 3, 2, 1}));

test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, k, {2, 3}, {1, 1}, {0, 2, 1, 3}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, k, {2, 3}, {1, 1}, {0, 1, 3, 2}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, k, {2, 3}, {1, 1}, {0, 3, 2, 1}));
}

// test cases with large ne00/ne10 to cover stream-k fixup
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 1024, {3, 2}, {1, 1}));
Expand Down
Loading