Skip to content

Commit e8215db

Browse files
authored
metal : add special-case mat-vec mul for ne00 == 4 (#14385)
ggml-ci
1 parent 5783ae4 commit e8215db

File tree

3 files changed

+125
-36
lines changed

3 files changed

+125
-36
lines changed

ggml/src/ggml-metal/ggml-metal.m

Lines changed: 22 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -211,11 +211,14 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
211211
GGML_METAL_KERNEL_TYPE_RWKV_WKV6_F32,
212212
GGML_METAL_KERNEL_TYPE_RWKV_WKV7_F32,
213213
GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32,
214+
GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32_C4,
214215
GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32,
216+
GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_C4,
215217
GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW,
216218
GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4,
217219
GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16,
218220
GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32,
221+
GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_C4,
219222
GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW,
220223
GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4,
221224
GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_BF16,
@@ -1175,11 +1178,14 @@ @implementation GGMLMetalClass
11751178
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RWKV_WKV6_F32, rwkv_wkv6_f32, true);
11761179
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RWKV_WKV7_F32, rwkv_wkv7_f32, true);
11771180
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, has_simdgroup_reduction);
1181+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32_C4, mul_mv_f32_f32_c4, true);
11781182
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32, mul_mv_bf16_f32, has_simdgroup_reduction && use_bfloat);
1183+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_C4, mul_mv_bf16_f32_c4, use_bfloat);
11791184
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW, mul_mv_bf16_f32_1row, has_simdgroup_reduction && use_bfloat);
11801185
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4, mul_mv_bf16_f32_l4, has_simdgroup_reduction && use_bfloat);
11811186
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_BF16, mul_mv_bf16_bf16, has_simdgroup_reduction && use_bfloat);
11821187
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32, mul_mv_f16_f32, has_simdgroup_reduction);
1188+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_C4, mul_mv_f16_f32_c4, true);
11831189
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW, mul_mv_f16_f32_1row, has_simdgroup_reduction);
11841190
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4, mul_mv_f16_f32_l4, has_simdgroup_reduction);
11851191
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16, mul_mv_f16_f16, has_simdgroup_reduction);
@@ -3111,14 +3117,23 @@ static bool ggml_metal_encode_node(
31113117
nsg = 1;
31123118
nr0 = 1;
31133119
nr1 = 4;
3114-
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32].pipeline;
3120+
if (ne00 == 4) {
3121+
nr0 = 32;
3122+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32_C4].pipeline;
3123+
} else {
3124+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32].pipeline;
3125+
}
31153126
} break;
31163127
case GGML_TYPE_F16:
31173128
{
31183129
nsg = 1;
31193130
nr0 = 1;
31203131
if (src1t == GGML_TYPE_F32) {
3121-
if (ne11 * ne12 < 4) {
3132+
if (ne00 == 4) {
3133+
nr0 = 32;
3134+
nr1 = 4;
3135+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_C4].pipeline;
3136+
} else if (ne11 * ne12 < 4) {
31223137
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW].pipeline;
31233138
} else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
31243139
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4].pipeline;
@@ -3137,7 +3152,11 @@ static bool ggml_metal_encode_node(
31373152
nsg = 1;
31383153
nr0 = 1;
31393154
if (src1t == GGML_TYPE_F32) {
3140-
if (ne11 * ne12 < 4) {
3155+
if (ne00 == 4) {
3156+
nr0 = 32;
3157+
nr1 = 4;
3158+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_C4].pipeline;
3159+
} else if (ne11 * ne12 < 4) {
31413160
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW].pipeline;
31423161
} else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
31433162
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4].pipeline;

ggml/src/ggml-metal/ggml-metal.metal

Lines changed: 64 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -2532,6 +2532,70 @@ template [[host_name("kernel_mul_mv_bf16_f32")]] kernel mul_mv_t kernel_mul_mv<
25322532
template [[host_name("kernel_mul_mv_bf16_bf16")]] kernel mul_mv_t kernel_mul_mv<bfloat, bfloat4, bfloat, bfloat4>;
25332533
#endif
25342534

2535+
// Matrix-vector product for rows that fit in a single 4-wide vector.
// Each thread loads one T04 from one src0 row and, for up to N_MV_T_T src1
// rows, one T14, then writes dot(x, y) as a float into dst.
// Only x[0] and y[0] are ever read, so each row is assumed to hold exactly
// 4 elements (NOTE(review): the host dispatcher is expected to pick this
// kernel only when ne00 == 4 -- confirm against the caller).
//
// T04 / T14 : 4-wide vector types used to load src0 / src1; both are cast to float4.
// args      : kernel arguments; the ne01/ne11/ne12/ne0/ne1, nb0x/nb1x and r2/r3 fields are read.
// src0/src1 : input buffers as raw bytes, addressed with the nb* byte offsets.
// dst       : output buffer, written through a float pointer.
// tgpig     : threadgroup position; x picks a group of 32 src0 rows, y a group of
//             N_MV_T_T src1 rows, z the flattened batch index.
// tiisg     : per-thread index added to tgpig.x*32 to pick the src0 row.
template<typename T04, typename T14, typename args_t>
2536+
void kernel_mul_mv_c4_impl(
2537+
args_t args,
2538+
device const char * src0,
2539+
device const char * src1,
2540+
device char * dst,
2541+
uint3 tgpig,
2542+
ushort tiisg) {
// r0: src0 row (and dst column) handled by this thread
// rb: first src1 row handled by this threadgroup
// im: flattened index over the two outer (batch) dimensions
2543+
const int r0 = tgpig.x*32 + tiisg;
2544+
const int rb = tgpig.y*N_MV_T_T;
2545+
const int im = tgpig.z;
2546+
// threads whose row falls past the last src0 row have nothing to do
2547+
if (r0 >= args.ne01) {
2548+
return;
2549+
}
2550+
// split the flattened batch index into its two components
2551+
const uint i12 = im%args.ne12;
2552+
const uint i13 = im/args.ne12;
2553+
// byte offset of the src0 row; the batch indices are divided by r2/r3
// (presumably the src1-to-src0 broadcast ratios -- verify against the host code)
2554+
const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
2555+
2556+
device const T04 * x = (device const T04 *) (src0 + offset0);
2557+
// start of this batch's output, im*ne0*ne1 floats into dst
2558+
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1;
2559+
// walk the (up to) N_MV_T_T src1 rows assigned to this threadgroup
2560+
for (int row = 0; row < N_MV_T_T; ++row) {
2561+
int r1 = rb + row;
// stop at the last src1 row
2562+
if (r1 >= args.ne11) {
2563+
break;
2564+
}
2565+
// byte offset of the src1 row (batch indices used as-is, no r2/r3 division)
2566+
const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
2567+
2568+
device const T14 * y = (device const T14 *) (src1 + offset1);
2569+
// one 4-element dot product per output element, computed in float
2570+
dst_f32[(uint64_t)r1*args.ne0 + r0] = dot((float4) x[0], (float4) y[0]);
2571+
}
2572+
}
2573+
2574+
// Kernel entry point for the 4-element-row mat-vec product.
// Receives the threadgroup position in the grid and the thread's index within
// its simdgroup through the attributes below and forwards everything, unchanged,
// to kernel_mul_mv_c4_impl with the argument struct passed by constant reference.
//
// T04 / T14 : 4-wide vector types used by the implementation to load src0 / src1.
template<typename T04, typename T14>
2575+
kernel void kernel_mul_mv_c4(
2576+
constant ggml_metal_kargs_mul_mv & args,
2577+
device const char * src0,
2578+
device const char * src1,
2579+
device char * dst,
2580+
uint3 tgpig[[threadgroup_position_in_grid]],
2581+
ushort tiisg[[thread_index_in_simdgroup]]) {
2582+
kernel_mul_mv_c4_impl<T04, T14, constant ggml_metal_kargs_mul_mv &>(
2583+
args,
2584+
src0,
2585+
src1,
2586+
dst,
2587+
tgpig,
2588+
tiisg);
2589+
}
2590+
2591+
// Function type shared by every instantiation below; the <half4, half4> arguments
// are only used to name the type, since T04/T14 do not appear in the signature.
typedef decltype(kernel_mul_mv_c4<half4, half4>) mul_mv_c4_t;
2592+
2593+
// Explicit instantiations, each exported under the name given in host_name:
// src0 as float4 / half4 / bfloat4, src1 always as float4.
template [[host_name("kernel_mul_mv_f32_f32_c4")]] kernel mul_mv_c4_t kernel_mul_mv_c4<float4, float4>;
2594+
template [[host_name("kernel_mul_mv_f16_f32_c4")]] kernel mul_mv_c4_t kernel_mul_mv_c4<half4, float4>;
2595+
// the bfloat variant is compiled only when GGML_METAL_USE_BF16 is defined
#if defined(GGML_METAL_USE_BF16)
2596+
template [[host_name("kernel_mul_mv_bf16_f32_c4")]] kernel mul_mv_c4_t kernel_mul_mv_c4<bfloat4, float4>;
2597+
#endif
2598+
25352599
template<typename T, typename T4>
25362600
kernel void kernel_mul_mv_1row(
25372601
constant ggml_metal_kargs_mul_mv & args,

tests/test-backend-ops.cpp

Lines changed: 39 additions & 33 deletions
Original file line number | Diff line number | Diff line change
@@ -4252,39 +4252,45 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
42524252
#if 1
42534253
for (ggml_type type_a : base_types) {
42544254
for (ggml_type type_b : {GGML_TYPE_F32, GGML_TYPE_F16}) {
4255-
// test cases without permutation
4256-
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {1, 1}, {1, 1}));
4257-
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {1, 1}, {2, 1}));
4258-
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {1, 1}, {1, 2}));
4259-
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {3, 1}, {1, 1}));
4260-
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {3, 1}, {2, 1}));
4261-
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {3, 2}, {1, 1}));
4262-
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {3, 2}, {2, 1}));
4263-
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {3, 2}, {1, 2}));
4264-
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {3, 2}, {2, 2}));
4265-
4266-
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {1, 1}, {1, 1}));
4267-
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {1, 1}, {2, 1}));
4268-
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {1, 1}, {1, 2}));
4269-
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {3, 1}, {1, 1}));
4270-
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {3, 1}, {2, 1}));
4271-
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {3, 2}, {1, 1}));
4272-
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {3, 2}, {2, 1}));
4273-
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {3, 2}, {1, 2}));
4274-
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {3, 2}, {2, 2}));
4275-
4276-
// test cases with permutation
4277-
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {2, 3}, {1, 1}, {0, 2, 1, 3}));
4278-
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {2, 3}, {1, 1}, {0, 1, 3, 2}));
4279-
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {2, 3}, {1, 1}, {0, 3, 2, 1}));
4280-
4281-
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 8, 256, {2, 3}, {1, 1}, {0, 2, 1, 3}));
4282-
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 8, 256, {2, 3}, {1, 1}, {0, 1, 3, 2}));
4283-
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 8, 256, {2, 3}, {1, 1}, {0, 3, 2, 1}));
4284-
4285-
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {2, 3}, {1, 1}, {0, 2, 1, 3}));
4286-
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {2, 3}, {1, 1}, {0, 1, 3, 2}));
4287-
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {2, 3}, {1, 1}, {0, 3, 2, 1}));
4255+
std::vector<int> ks = { 256 };
4256+
if (ggml_blck_size(type_a) == 1) {
4257+
ks.push_back(4);
4258+
}
4259+
for (auto k : ks) {
4260+
// test cases without permutation
4261+
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, k, {1, 1}, {1, 1}));
4262+
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, k, {1, 1}, {2, 1}));
4263+
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, k, {1, 1}, {1, 2}));
4264+
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, k, {3, 1}, {1, 1}));
4265+
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, k, {3, 1}, {2, 1}));
4266+
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, k, {3, 2}, {1, 1}));
4267+
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, k, {3, 2}, {2, 1}));
4268+
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, k, {3, 2}, {1, 2}));
4269+
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, k, {3, 2}, {2, 2}));
4270+
4271+
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, k, {1, 1}, {1, 1}));
4272+
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, k, {1, 1}, {2, 1}));
4273+
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, k, {1, 1}, {1, 2}));
4274+
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, k, {3, 1}, {1, 1}));
4275+
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, k, {3, 1}, {2, 1}));
4276+
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, k, {3, 2}, {1, 1}));
4277+
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, k, {3, 2}, {2, 1}));
4278+
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, k, {3, 2}, {1, 2}));
4279+
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, k, {3, 2}, {2, 2}));
4280+
4281+
// test cases with permutation
4282+
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, k, {2, 3}, {1, 1}, {0, 2, 1, 3}));
4283+
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, k, {2, 3}, {1, 1}, {0, 1, 3, 2}));
4284+
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, k, {2, 3}, {1, 1}, {0, 3, 2, 1}));
4285+
4286+
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 8, k, {2, 3}, {1, 1}, {0, 2, 1, 3}));
4287+
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 8, k, {2, 3}, {1, 1}, {0, 1, 3, 2}));
4288+
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 8, k, {2, 3}, {1, 1}, {0, 3, 2, 1}));
4289+
4290+
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, k, {2, 3}, {1, 1}, {0, 2, 1, 3}));
4291+
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, k, {2, 3}, {1, 1}, {0, 1, 3, 2}));
4292+
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, k, {2, 3}, {1, 1}, {0, 3, 2, 1}));
4293+
}
42884294

42894295
// test cases with large ne00/ne10 to cover stream-k fixup
42904296
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 1024, {3, 2}, {1, 1}));

0 commit comments

Comments
 (0)