@@ -78,6 +78,7 @@ pub(crate) const ARGUMENT_BUFFER_WRAPPER_STRUCT: &str = "NagaArgumentBufferWrapp
78
78
/// allowing them to be conveniently passed to user-defined or wrapper
79
79
/// functions. The struct is declared in [`Writer::write_type_defs`].
80
80
pub ( crate ) const EXTERNAL_TEXTURE_WRAPPER_STRUCT : & str = "NagaExternalTextureWrapper" ;
81
+ pub ( crate ) const COOPERATIVE_MULTIPLY_ADD_FUNCTION : & str = "NagaCooperativeMultiplyAdd" ;
81
82
82
83
/// Write the Metal name for a Naga numeric type: scalar, vector, or matrix.
83
84
///
@@ -483,6 +484,12 @@ enum WrappedFunction {
483
484
ImageQuerySize {
484
485
class : crate :: ImageClass ,
485
486
} ,
487
+ CooperativeMultiplyAdd {
488
+ columns : crate :: CooperativeSize ,
489
+ rows : crate :: CooperativeSize ,
490
+ intermediate : crate :: CooperativeSize ,
491
+ scalar : crate :: Scalar ,
492
+ } ,
486
493
}
487
494
488
495
pub struct Writer < W > {
@@ -543,14 +550,6 @@ impl crate::Scalar {
543
550
}
544
551
}
545
552
546
- impl crate :: CooperativeScalar {
547
- const fn to_msl_name ( self ) -> & ' static str {
548
- match self {
549
- Self :: F32 => "float" ,
550
- }
551
- }
552
- }
553
-
554
553
const fn separate ( need_separator : bool ) -> & ' static str {
555
554
if need_separator {
556
555
","
@@ -2842,12 +2841,14 @@ impl<W: Write> Writer<W> {
2842
2841
}
2843
2842
write ! ( self . out, "}}" ) ?;
2844
2843
}
2845
- crate :: Expression :: MulAdd { a, b, c } => {
2846
- self . put_expression ( a, context, false ) ?;
2847
- write ! ( self . out, " * " ) ?;
2848
- self . put_expression ( b, context, false ) ?;
2849
- write ! ( self . out, " + " ) ?;
2850
- self . put_expression ( c, context, false ) ?;
2844
+ crate :: Expression :: CooperativeMultiplyAdd { a, b, c } => {
2845
+ write ! ( self . out, "{COOPERATIVE_MULTIPLY_ADD_FUNCTION}(" ) ?;
2846
+ self . put_expression ( a, context, true ) ?;
2847
+ write ! ( self . out, ", " ) ?;
2848
+ self . put_expression ( b, context, true ) ?;
2849
+ write ! ( self . out, ", " ) ?;
2850
+ self . put_expression ( c, context, true ) ?;
2851
+ write ! ( self . out, ")" ) ?;
2851
2852
}
2852
2853
}
2853
2854
Ok ( ( ) )
@@ -4230,6 +4231,49 @@ impl<W: Write> Writer<W> {
4230
4231
}
4231
4232
writeln ! ( self . out, ");" ) ?;
4232
4233
}
4234
+ crate :: Statement :: CooperativeLoadStore {
4235
+ store,
4236
+ target,
4237
+ pointer,
4238
+ stride,
4239
+ row_major,
4240
+ } => {
4241
+ let op_str = if store { "store" } else { "load" } ;
4242
+ write ! ( self . out, "{level}{NAMESPACE}::simdgroup_{op_str}(" ) ?;
4243
+ self . put_expression ( target, & context. expression , true ) ?;
4244
+ write ! ( self . out, ", " ) ?;
4245
+ self . put_expression ( pointer, & context. expression , true ) ?;
4246
+ if stride. is_some ( ) || row_major {
4247
+ write ! ( self . out, ", " ) ?;
4248
+ match stride {
4249
+ Some ( expression) => {
4250
+ self . put_expression ( expression, & context. expression , true ) ?;
4251
+ }
4252
+ None => {
4253
+ let default_stride = match * context. expression . resolve_type ( target)
4254
+ {
4255
+ crate :: TypeInner :: CooperativeMatrix {
4256
+ columns, rows, ..
4257
+ } => {
4258
+ if row_major {
4259
+ columns as u32
4260
+ } else {
4261
+ rows as u32
4262
+ }
4263
+ }
4264
+ _ => 0 ,
4265
+ } ;
4266
+ write ! ( self . out, "{default_stride}" ) ?;
4267
+ }
4268
+ }
4269
+ }
4270
+ if row_major {
4271
+ let matrix_origin = "0" ;
4272
+ let transpose = true ;
4273
+ write ! ( self . out, ", {matrix_origin}, {transpose}" ) ?;
4274
+ }
4275
+ writeln ! ( self . out, ");" ) ?;
4276
+ }
4233
4277
}
4234
4278
}
4235
4279
@@ -6286,6 +6330,62 @@ template <typename A>
6286
6330
Ok ( ( ) )
6287
6331
}
6288
6332
6333
+ fn write_wrapped_cooperative_multiply_add (
6334
+ & mut self ,
6335
+ module : & crate :: Module ,
6336
+ func_ctx : & back:: FunctionCtx ,
6337
+ a : Handle < crate :: Expression > ,
6338
+ b : Handle < crate :: Expression > ,
6339
+ ) -> BackendResult {
6340
+ let ( a_c, a_r, scalar) = match * func_ctx. resolve_type ( a, & module. types ) {
6341
+ crate :: TypeInner :: CooperativeMatrix {
6342
+ columns,
6343
+ rows,
6344
+ scalar,
6345
+ ..
6346
+ } => ( columns, rows, scalar) ,
6347
+ _ => unreachable ! ( ) ,
6348
+ } ;
6349
+ let ( b_c, b_r) = match * func_ctx. resolve_type ( b, & module. types ) {
6350
+ crate :: TypeInner :: CooperativeMatrix { columns, rows, .. } => ( columns, rows) ,
6351
+ _ => unreachable ! ( ) ,
6352
+ } ;
6353
+ let wrapped = WrappedFunction :: CooperativeMultiplyAdd {
6354
+ columns : b_c,
6355
+ rows : a_r,
6356
+ intermediate : a_c,
6357
+ scalar,
6358
+ } ;
6359
+ if !self . wrapped_functions . insert ( wrapped) {
6360
+ return Ok ( ( ) ) ;
6361
+ }
6362
+ let scalar_name = match scalar. width {
6363
+ 2 => "half" ,
6364
+ 4 => "float" ,
6365
+ 8 => "double" ,
6366
+ _ => unreachable ! ( ) ,
6367
+ } ;
6368
+ writeln ! (
6369
+ self . out,
6370
+ "{NAMESPACE}::simdgroup_{scalar_name}{}x{} {COOPERATIVE_MULTIPLY_ADD_FUNCTION}(const {NAMESPACE}::simdgroup_{scalar_name}{}x{}& a, const {NAMESPACE}::simdgroup_{scalar_name}{}x{}& b, const {NAMESPACE}::simdgroup_{scalar_name}{}x{}& c) {{" ,
6371
+ b_c as u32 , a_r as u32 , a_c as u32 , a_r as u32 , b_c as u32 , b_r as u32 , b_c as u32 , a_r as u32 ,
6372
+ ) ?;
6373
+ let l1 = back:: Level ( 1 ) ;
6374
+ writeln ! (
6375
+ self . out,
6376
+ "{l1}{NAMESPACE}::simdgroup_{scalar_name}{}x{} d;" ,
6377
+ b_c as u32 , a_r as u32
6378
+ ) ?;
6379
+ writeln ! (
6380
+ self . out,
6381
+ "{l1}{NAMESPACE}::simdgroup_multiply_accumulate(d,a,b,c);"
6382
+ ) ?;
6383
+ writeln ! ( self . out, "{l1}return d;" ) ?;
6384
+ writeln ! ( self . out, "}}" ) ?;
6385
+ writeln ! ( self . out) ?;
6386
+ Ok ( ( ) )
6387
+ }
6388
+
6289
6389
pub ( super ) fn write_wrapped_functions (
6290
6390
& mut self ,
6291
6391
module : & crate :: Module ,
@@ -6360,6 +6460,9 @@ template <typename A>
6360
6460
crate :: Expression :: ImageQuery { image, query } => {
6361
6461
self . write_wrapped_image_query ( module, func_ctx, image, query) ?;
6362
6462
}
6463
+ crate :: Expression :: CooperativeMultiplyAdd { a, b, c : _ } => {
6464
+ self . write_wrapped_cooperative_multiply_add ( module, func_ctx, a, b) ?;
6465
+ }
6363
6466
_ => { }
6364
6467
}
6365
6468
}
0 commit comments