@@ -2450,6 +2450,7 @@ static bool ggml_metal_encode_node(
2450
2450
nth *= 2 ;
2451
2451
}
2452
2452
2453
+ nth = MIN (nth, (int ) pipeline.maxTotalThreadsPerThreadgroup );
2453
2454
nth = MIN (nth, ne00);
2454
2455
2455
2456
ggml_metal_kargs_sum_rows args = {
@@ -3780,6 +3781,7 @@ static bool ggml_metal_encode_node(
3780
3781
nth *= 2 ;
3781
3782
}
3782
3783
3784
+ nth = MIN (nth, (int ) pipeline.maxTotalThreadsPerThreadgroup );
3783
3785
nth = MIN (nth, ne00/4 );
3784
3786
3785
3787
ggml_metal_kargs_rms_norm args = {
@@ -3816,6 +3818,7 @@ static bool ggml_metal_encode_node(
3816
3818
nth *= 2 ;
3817
3819
}
3818
3820
3821
+ nth = MIN (nth, (int ) pipeline.maxTotalThreadsPerThreadgroup );
3819
3822
nth = MIN (nth, ne00/4 );
3820
3823
3821
3824
ggml_metal_kargs_l2_norm args = {
@@ -3888,6 +3891,7 @@ static bool ggml_metal_encode_node(
3888
3891
nth *= 2 ;
3889
3892
}
3890
3893
3894
+ nth = MIN (nth, (int ) pipeline.maxTotalThreadsPerThreadgroup );
3891
3895
nth = MIN (nth, ne00/4 );
3892
3896
3893
3897
ggml_metal_kargs_norm args = {
@@ -4974,8 +4978,39 @@ static bool ggml_metal_encode_node(
4974
4978
default : GGML_ABORT (" not implemented" );
4975
4979
}
4976
4980
4981
+ GGML_ASSERT (ne00 % ggml_blck_size (src0->type ) == 0 );
4982
+
4983
+ // TODO: support
4984
+ // const int32_t nk00 = ne00/ggml_blck_size(dst->type);
4985
+ const int32_t nk00 = ne00;
4986
+
4987
+ int nth = 32 ; // SIMD width
4988
+
4989
+ while (nth < nk00 && nth < (int ) pipeline.maxTotalThreadsPerThreadgroup ) {
4990
+ nth *= 2 ;
4991
+ }
4992
+
4993
+ nth = MIN (nth, (int ) pipeline.maxTotalThreadsPerThreadgroup );
4994
+
4995
+ // when rows are small, we can batch them together in a single threadgroup
4996
+ int nrptg = 1 ;
4997
+
4998
+ // TODO: relax this constraint in the future
4999
+ if (ggml_blck_size (src0->type ) == 1 && ggml_blck_size (dst->type ) == 1 ) {
5000
+ if (nth > nk00) {
5001
+ nrptg = (nth + nk00 - 1 )/nk00;
5002
+ nth = nk00;
5003
+
5004
+ if (nrptg*nth > (int ) pipeline.maxTotalThreadsPerThreadgroup ) {
5005
+ nrptg--;
5006
+ }
5007
+ }
5008
+ }
5009
+
5010
+ nth = MIN (nth, nk00);
5011
+
4977
5012
ggml_metal_kargs_cpy args = {
4978
- /* .ne00 =*/ ne00 ,
5013
+ /* .ne00 =*/ nk00 ,
4979
5014
/* .ne01 =*/ ne01,
4980
5015
/* .ne02 =*/ ne02,
4981
5016
/* .ne03 =*/ ne03,
@@ -4998,11 +5033,7 @@ static bool ggml_metal_encode_node(
4998
5033
[encoder setBuffer: id_src0 offset: offs_src0 atIndex: 1 ];
4999
5034
[encoder setBuffer: id_dst offset: offs_dst atIndex: 2 ];
5000
5035
5001
- GGML_ASSERT (ne00 % ggml_blck_size (src0->type ) == 0 );
5002
- int nth = MIN (1024 , ne00/ggml_blck_size (src0->type ));
5003
-
5004
- [encoder dispatchThreadgroups: MTLSizeMake (ne01, ne02, ne03) threadsPerThreadgroup: MTLSizeMake (nth, 1 , 1 )];
5005
-
5036
+ [encoder dispatchThreadgroups: MTLSizeMake ((ne01 + nrptg - 1 )/nrptg, ne02, ne03) threadsPerThreadgroup: MTLSizeMake (nth, nrptg, 1 )];
5006
5037
} break ;
5007
5038
case GGML_OP_SET:
5008
5039
{
0 commit comments