Skip to content

Commit 5783ae4

Browse files
authored
metal : batch rows copy in a single threadgroup (#14384)
* metal : batch rows copy in a single threadgroup ggml-ci * metal : handle some edge cases when threadgroup size is not a power of 2 ggml-ci
1 parent bf5bcd0 commit 5783ae4

File tree

2 files changed

+45
-9
lines changed

2 files changed

+45
-9
lines changed

ggml/src/ggml-metal/ggml-metal.m

Lines changed: 37 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2450,6 +2450,7 @@ static bool ggml_metal_encode_node(
24502450
nth *= 2;
24512451
}
24522452

2453+
nth = MIN(nth, (int) pipeline.maxTotalThreadsPerThreadgroup);
24532454
nth = MIN(nth, ne00);
24542455

24552456
ggml_metal_kargs_sum_rows args = {
@@ -3780,6 +3781,7 @@ static bool ggml_metal_encode_node(
37803781
nth *= 2;
37813782
}
37823783

3784+
nth = MIN(nth, (int) pipeline.maxTotalThreadsPerThreadgroup);
37833785
nth = MIN(nth, ne00/4);
37843786

37853787
ggml_metal_kargs_rms_norm args = {
@@ -3816,6 +3818,7 @@ static bool ggml_metal_encode_node(
38163818
nth *= 2;
38173819
}
38183820

3821+
nth = MIN(nth, (int) pipeline.maxTotalThreadsPerThreadgroup);
38193822
nth = MIN(nth, ne00/4);
38203823

38213824
ggml_metal_kargs_l2_norm args = {
@@ -3888,6 +3891,7 @@ static bool ggml_metal_encode_node(
38883891
nth *= 2;
38893892
}
38903893

3894+
nth = MIN(nth, (int) pipeline.maxTotalThreadsPerThreadgroup);
38913895
nth = MIN(nth, ne00/4);
38923896

38933897
ggml_metal_kargs_norm args = {
@@ -4974,8 +4978,39 @@ static bool ggml_metal_encode_node(
49744978
default: GGML_ABORT("not implemented");
49754979
}
49764980

4981+
GGML_ASSERT(ne00 % ggml_blck_size(src0->type) == 0);
4982+
4983+
// TODO: support
4984+
//const int32_t nk00 = ne00/ggml_blck_size(dst->type);
4985+
const int32_t nk00 = ne00;
4986+
4987+
int nth = 32; // SIMD width
4988+
4989+
while (nth < nk00 && nth < (int) pipeline.maxTotalThreadsPerThreadgroup) {
4990+
nth *= 2;
4991+
}
4992+
4993+
nth = MIN(nth, (int) pipeline.maxTotalThreadsPerThreadgroup);
4994+
4995+
// when rows are small, we can batch them together in a single threadgroup
4996+
int nrptg = 1;
4997+
4998+
// TODO: relax this constraint in the future
4999+
if (ggml_blck_size(src0->type) == 1 && ggml_blck_size(dst->type) == 1) {
5000+
if (nth > nk00) {
5001+
nrptg = (nth + nk00 - 1)/nk00;
5002+
nth = nk00;
5003+
5004+
if (nrptg*nth > (int) pipeline.maxTotalThreadsPerThreadgroup) {
5005+
nrptg--;
5006+
}
5007+
}
5008+
}
5009+
5010+
nth = MIN(nth, nk00);
5011+
49775012
ggml_metal_kargs_cpy args = {
4978-
/*.ne00 =*/ ne00,
5013+
/*.ne00 =*/ nk00,
49795014
/*.ne01 =*/ ne01,
49805015
/*.ne02 =*/ ne02,
49815016
/*.ne03 =*/ ne03,
@@ -4998,11 +5033,7 @@ static bool ggml_metal_encode_node(
49985033
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
49995034
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
50005035

5001-
GGML_ASSERT(ne00 % ggml_blck_size(src0->type) == 0);
5002-
int nth = MIN(1024, ne00/ggml_blck_size(src0->type));
5003-
5004-
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
5005-
5036+
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + nrptg - 1)/nrptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, nrptg, 1)];
50065037
} break;
50075038
case GGML_OP_SET:
50085039
{

ggml/src/ggml-metal/ggml-metal.metal

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4306,11 +4306,16 @@ kernel void kernel_cpy(
43064306
device const char * src0,
43074307
device char * dst,
43084308
uint3 tgpig[[threadgroup_position_in_grid]],
4309+
uint tiitg[[thread_index_in_threadgroup]],
43094310
ushort3 tpitg[[thread_position_in_threadgroup]],
4310-
ushort3 ntg[[threads_per_threadgroup]]) {
4311+
ushort3 tptg[[threads_per_threadgroup]]) {
43114312
const int i03 = tgpig[2];
43124313
const int i02 = tgpig[1];
4313-
const int i01 = tgpig[0];
4314+
const int i01 = tgpig[0]*tptg.y + tiitg/tptg.x;
4315+
4316+
if (i01 >= args.ne01) {
4317+
return;
4318+
}
43144319

43154320
const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00;
43164321

@@ -4321,7 +4326,7 @@ kernel void kernel_cpy(
43214326

43224327
device T1 * dst_data = (device T1 *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
43234328

4324-
for (int64_t i00 = tpitg.x; i00 < args.ne00; i00 += ntg.x) {
4329+
for (int64_t i00 = tiitg%tptg.x; i00 < args.ne00; i00 += tptg.x) {
43254330
device const T0 * src = (device T0 *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
43264331
dst_data[i00] = (T1) src[0];
43274332
}

0 commit comments

Comments
 (0)