@@ -78,6 +78,7 @@ pub(crate) const ARGUMENT_BUFFER_WRAPPER_STRUCT: &str = "NagaArgumentBufferWrapp
78
78
/// allowing them to be conveniently passed to user-defined or wrapper
79
79
/// functions. The struct is declared in [`Writer::write_type_defs`].
80
80
pub ( crate ) const EXTERNAL_TEXTURE_WRAPPER_STRUCT : & str = "NagaExternalTextureWrapper" ;
81
+ pub ( crate ) const COOPERATIVE_LOAD_FUNCTION : & str = "NagaCooperativeLoad" ;
81
82
pub ( crate ) const COOPERATIVE_MULTIPLY_ADD_FUNCTION : & str = "NagaCooperativeMultiplyAdd" ;
82
83
83
84
/// Write the Metal name for a Naga numeric type: scalar, vector, or matrix.
@@ -484,6 +485,12 @@ enum WrappedFunction {
484
485
ImageQuerySize {
485
486
class : crate :: ImageClass ,
486
487
} ,
488
+ CooperativeLoad {
489
+ space : crate :: AddressSpace ,
490
+ columns : crate :: CooperativeSize ,
491
+ rows : crate :: CooperativeSize ,
492
+ scalar : crate :: Scalar ,
493
+ } ,
487
494
CooperativeMultiplyAdd {
488
495
space : crate :: AddressSpace ,
489
496
columns : crate :: CooperativeSize ,
@@ -2842,6 +2849,17 @@ impl<W: Write> Writer<W> {
2842
2849
}
2843
2850
write ! ( self . out, "}}" ) ?;
2844
2851
}
2852
+ crate :: Expression :: CooperativeLoad { ref data, .. } => {
2853
+ if context. lang_version < ( 2 , 3 ) {
2854
+ return Err ( Error :: UnsupportedCooperativeMatrix ) ;
2855
+ }
2856
+ write ! ( self . out, "{COOPERATIVE_LOAD_FUNCTION}(" ) ?;
2857
+ write ! ( self . out, "&" ) ?;
2858
+ self . put_access_chain ( data. pointer , context. policies . index , context) ?;
2859
+ write ! ( self . out, ", " ) ?;
2860
+ self . put_expression ( data. stride , context, true ) ?;
2861
+ write ! ( self . out, ", {})" , data. row_major) ?;
2862
+ }
2845
2863
crate :: Expression :: CooperativeMultiplyAdd { a, b, c } => {
2846
2864
if context. lang_version < ( 2 , 3 ) {
2847
2865
return Err ( Error :: UnsupportedCooperativeMatrix ) ;
@@ -4235,25 +4253,18 @@ impl<W: Write> Writer<W> {
4235
4253
}
4236
4254
writeln ! ( self . out, ");" ) ?;
4237
4255
}
4238
- crate :: Statement :: CooperativeLoadStore {
4239
- store,
4240
- target,
4241
- pointer,
4242
- stride,
4243
- row_major,
4244
- } => {
4245
- let op_str = if store { "store" } else { "load" } ;
4246
- write ! ( self . out, "{level}simdgroup_{op_str}(" ) ?;
4256
+ crate :: Statement :: CooperativeStore { target, ref data } => {
4257
+ write ! ( self . out, "{level}simdgroup_store(" ) ?;
4247
4258
self . put_expression ( target, & context. expression , true ) ?;
4248
4259
write ! ( self . out, ", &" ) ?;
4249
4260
self . put_access_chain (
4250
- pointer,
4261
+ data . pointer ,
4251
4262
context. expression . policies . index ,
4252
4263
& context. expression ,
4253
4264
) ?;
4254
4265
write ! ( self . out, ", " ) ?;
4255
- self . put_expression ( stride, & context. expression , true ) ?;
4256
- if row_major {
4266
+ self . put_expression ( data . stride , & context. expression , true ) ?;
4267
+ if data . row_major {
4257
4268
let matrix_origin = "0" ;
4258
4269
let transpose = true ;
4259
4270
write ! ( self . out, ", {matrix_origin}, {transpose}" ) ?;
@@ -6316,6 +6327,55 @@ template <typename A>
6316
6327
Ok ( ( ) )
6317
6328
}
6318
6329
6330
+ fn write_wrapped_cooperative_load (
6331
+ & mut self ,
6332
+ module : & crate :: Module ,
6333
+ func_ctx : & back:: FunctionCtx ,
6334
+ columns : crate :: CooperativeSize ,
6335
+ rows : crate :: CooperativeSize ,
6336
+ pointer : Handle < crate :: Expression > ,
6337
+ ) -> BackendResult {
6338
+ let ptr_ty = func_ctx. resolve_type ( pointer, & module. types ) ;
6339
+ let space = ptr_ty. pointer_space ( ) . unwrap ( ) ;
6340
+ let scalar = ptr_ty
6341
+ . pointer_base_type ( )
6342
+ . unwrap ( )
6343
+ . inner_with ( & module. types )
6344
+ . scalar ( )
6345
+ . unwrap ( ) ;
6346
+ let wrapped = WrappedFunction :: CooperativeLoad {
6347
+ space,
6348
+ columns,
6349
+ rows,
6350
+ scalar,
6351
+ } ;
6352
+ if !self . wrapped_functions . insert ( wrapped) {
6353
+ return Ok ( ( ) ) ;
6354
+ }
6355
+ let space_name = space. to_msl_name ( ) . unwrap_or_default ( ) ;
6356
+ let scalar_name = scalar. to_msl_name ( ) ;
6357
+ writeln ! (
6358
+ self . out,
6359
+ "{NAMESPACE}::simdgroup_{scalar_name}{}x{} {COOPERATIVE_LOAD_FUNCTION}(const {space_name} {scalar_name}* ptr, int stride, bool is_row_major) {{" ,
6360
+ columns as u32 , rows as u32 ,
6361
+ ) ?;
6362
+ let l1 = back:: Level ( 1 ) ;
6363
+ writeln ! (
6364
+ self . out,
6365
+ "{l1}{NAMESPACE}::simdgroup_{scalar_name}{}x{} m;" ,
6366
+ columns as u32 , rows as u32
6367
+ ) ?;
6368
+ let matrix_origin = "0" ;
6369
+ writeln ! (
6370
+ self . out,
6371
+ "{l1}simdgroup_load(m, ptr, stride, {matrix_origin}, is_row_major);"
6372
+ ) ?;
6373
+ writeln ! ( self . out, "{l1}return m;" ) ?;
6374
+ writeln ! ( self . out, "}}" ) ?;
6375
+ writeln ! ( self . out) ?;
6376
+ Ok ( ( ) )
6377
+ }
6378
+
6319
6379
fn write_wrapped_cooperative_multiply_add (
6320
6380
& mut self ,
6321
6381
module : & crate :: Module ,
@@ -6441,6 +6501,20 @@ template <typename A>
6441
6501
crate :: Expression :: ImageQuery { image, query } => {
6442
6502
self . write_wrapped_image_query ( module, func_ctx, image, query) ?;
6443
6503
}
6504
+ crate :: Expression :: CooperativeLoad {
6505
+ columns,
6506
+ rows,
6507
+ role : _,
6508
+ ref data,
6509
+ } => {
6510
+ self . write_wrapped_cooperative_load (
6511
+ module,
6512
+ func_ctx,
6513
+ columns,
6514
+ rows,
6515
+ data. pointer ,
6516
+ ) ?;
6517
+ }
6444
6518
crate :: Expression :: CooperativeMultiplyAdd { a, b, c : _ } => {
6445
6519
let space = crate :: AddressSpace :: Private ;
6446
6520
self . write_wrapped_cooperative_multiply_add ( module, func_ctx, space, a, b) ?;
0 commit comments