Skip to content

Commit 430d104

Browse files
committed
coop: Implement Load/Store statement
1 parent 39cab57 commit 430d104

35 files changed

+798
-234
lines changed

naga/src/back/dot/mod.rs

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -403,6 +403,24 @@ impl StatementGraph {
403403
},
404404
}
405405
}
406+
S::CooperativeLoadStore {
407+
store,
408+
target,
409+
pointer,
410+
stride,
411+
row_major: _,
412+
} => {
413+
self.dependencies.push((id, target, "target"));
414+
self.dependencies.push((id, pointer, "pointer"));
415+
if let Some(stride) = stride {
416+
self.dependencies.push((id, stride, "stride"));
417+
}
418+
if store {
419+
"Store"
420+
} else {
421+
"Load"
422+
}
423+
}
406424
};
407425
// Set the last node to the merge node
408426
last_node = merge_id;
@@ -742,11 +760,11 @@ fn write_function_expressions(
742760
let ty = if committed { "Committed" } else { "Candidate" };
743761
(format!("get{ty}HitVertexPositions").into(), 4)
744762
}
745-
E::MulAdd { a, b, c } => {
763+
E::CooperativeMultiplyAdd { a, b, c } => {
746764
edges.insert("a", a);
747765
edges.insert("b", b);
748766
edges.insert("c", c);
749-
("MulAdd".into(), 6)
767+
("cooperativeMultiplyAdd".into(), 4)
750768
}
751769
};
752770

naga/src/back/glsl/mod.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2805,6 +2805,7 @@ impl<'a, W: Write> Writer<'a, W> {
28052805
}
28062806
writeln!(self.out, ");")?;
28072807
}
2808+
Statement::CooperativeLoadStore { .. } => unimplemented!(),
28082809
}
28092810

28102811
Ok(())
@@ -4342,7 +4343,7 @@ impl<'a, W: Write> Writer<'a, W> {
43424343
// not supported yet
43434344
Expression::RayQueryGetIntersection { .. }
43444345
| Expression::RayQueryVertexPositions { .. }
4345-
| Expression::MulAdd { .. } => unreachable!(),
4346+
| Expression::CooperativeMultiplyAdd { .. } => unreachable!(),
43464347
}
43474348

43484349
Ok(())

naga/src/back/hlsl/writer.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2747,6 +2747,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
27472747
}
27482748
writeln!(self.out, ");")?;
27492749
}
2750+
Statement::CooperativeLoadStore { .. } => unimplemented!(),
27502751
}
27512752

27522753
Ok(())
@@ -4275,7 +4276,8 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
42754276
}
42764277
}
42774278
// Not supported yet
4278-
Expression::RayQueryVertexPositions { .. } | Expression::MulAdd { .. } => {
4279+
Expression::RayQueryVertexPositions { .. }
4280+
| Expression::CooperativeMultiplyAdd { .. } => {
42794281
unreachable!()
42804282
}
42814283
// Nothing to do here, since call expression already cached

naga/src/back/mod.rs

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -311,18 +311,6 @@ pub const fn binary_operation_str(op: crate::BinaryOperator) -> &'static str {
311311
}
312312
}
313313

314-
impl crate::TypeInner {
315-
/// Returns true if this is a handle to a type rather than the type directly.
316-
pub const fn is_handle(&self) -> bool {
317-
match *self {
318-
crate::TypeInner::Image { .. }
319-
| crate::TypeInner::Sampler { .. }
320-
| crate::TypeInner::AccelerationStructure { .. } => true,
321-
_ => false,
322-
}
323-
}
324-
}
325-
326314
impl crate::Statement {
327315
/// Returns true if the statement directly terminates the current block.
328316
///

naga/src/back/msl/writer.rs

Lines changed: 117 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@ pub(crate) const ARGUMENT_BUFFER_WRAPPER_STRUCT: &str = "NagaArgumentBufferWrapp
7878
/// allowing them to be conveniently passed to user-defined or wrapper
7979
/// functions. The struct is declared in [`Writer::write_type_defs`].
8080
pub(crate) const EXTERNAL_TEXTURE_WRAPPER_STRUCT: &str = "NagaExternalTextureWrapper";
81+
pub(crate) const COOPERATIVE_MULTIPLY_ADD_FUNCTION: &str = "NagaCooperativeMultiplyAdd";
8182

8283
/// Write the Metal name for a Naga numeric type: scalar, vector, or matrix.
8384
///
@@ -483,6 +484,12 @@ enum WrappedFunction {
483484
ImageQuerySize {
484485
class: crate::ImageClass,
485486
},
487+
CooperativeMultiplyAdd {
488+
columns: crate::CooperativeSize,
489+
rows: crate::CooperativeSize,
490+
intermediate: crate::CooperativeSize,
491+
scalar: crate::Scalar,
492+
},
486493
}
487494

488495
pub struct Writer<W> {
@@ -543,14 +550,6 @@ impl crate::Scalar {
543550
}
544551
}
545552

546-
impl crate::CooperativeScalar {
547-
const fn to_msl_name(self) -> &'static str {
548-
match self {
549-
Self::F32 => "float",
550-
}
551-
}
552-
}
553-
554553
const fn separate(need_separator: bool) -> &'static str {
555554
if need_separator {
556555
","
@@ -2842,12 +2841,14 @@ impl<W: Write> Writer<W> {
28422841
}
28432842
write!(self.out, "}}")?;
28442843
}
2845-
crate::Expression::MulAdd { a, b, c } => {
2846-
self.put_expression(a, context, false)?;
2847-
write!(self.out, " * ")?;
2848-
self.put_expression(b, context, false)?;
2849-
write!(self.out, " + ")?;
2850-
self.put_expression(c, context, false)?;
2844+
crate::Expression::CooperativeMultiplyAdd { a, b, c } => {
2845+
write!(self.out, "{COOPERATIVE_MULTIPLY_ADD_FUNCTION}(")?;
2846+
self.put_expression(a, context, true)?;
2847+
write!(self.out, ", ")?;
2848+
self.put_expression(b, context, true)?;
2849+
write!(self.out, ", ")?;
2850+
self.put_expression(c, context, true)?;
2851+
write!(self.out, ")")?;
28512852
}
28522853
}
28532854
Ok(())
@@ -4230,6 +4231,49 @@ impl<W: Write> Writer<W> {
42304231
}
42314232
writeln!(self.out, ");")?;
42324233
}
4234+
crate::Statement::CooperativeLoadStore {
4235+
store,
4236+
target,
4237+
pointer,
4238+
stride,
4239+
row_major,
4240+
} => {
4241+
let op_str = if store { "store" } else { "load" };
4242+
write!(self.out, "{level}{NAMESPACE}::simdgroup_{op_str}(")?;
4243+
self.put_expression(target, &context.expression, true)?;
4244+
write!(self.out, ", ")?;
4245+
self.put_expression(pointer, &context.expression, true)?;
4246+
if stride.is_some() || row_major {
4247+
write!(self.out, ", ")?;
4248+
match stride {
4249+
Some(expression) => {
4250+
self.put_expression(expression, &context.expression, true)?;
4251+
}
4252+
None => {
4253+
let default_stride = match *context.expression.resolve_type(target)
4254+
{
4255+
crate::TypeInner::CooperativeMatrix {
4256+
columns, rows, ..
4257+
} => {
4258+
if row_major {
4259+
columns as u32
4260+
} else {
4261+
rows as u32
4262+
}
4263+
}
4264+
_ => 0,
4265+
};
4266+
write!(self.out, "{default_stride}")?;
4267+
}
4268+
}
4269+
}
4270+
if row_major {
4271+
let matrix_origin = "0";
4272+
let transpose = true;
4273+
write!(self.out, ", {matrix_origin}, {transpose}")?;
4274+
}
4275+
writeln!(self.out, ");")?;
4276+
}
42334277
}
42344278
}
42354279

@@ -6286,6 +6330,62 @@ template <typename A>
62866330
Ok(())
62876331
}
62886332

6333+
fn write_wrapped_cooperative_multiply_add(
6334+
&mut self,
6335+
module: &crate::Module,
6336+
func_ctx: &back::FunctionCtx,
6337+
a: Handle<crate::Expression>,
6338+
b: Handle<crate::Expression>,
6339+
) -> BackendResult {
6340+
let (a_c, a_r, scalar) = match *func_ctx.resolve_type(a, &module.types) {
6341+
crate::TypeInner::CooperativeMatrix {
6342+
columns,
6343+
rows,
6344+
scalar,
6345+
..
6346+
} => (columns, rows, scalar),
6347+
_ => unreachable!(),
6348+
};
6349+
let (b_c, b_r) = match *func_ctx.resolve_type(b, &module.types) {
6350+
crate::TypeInner::CooperativeMatrix { columns, rows, .. } => (columns, rows),
6351+
_ => unreachable!(),
6352+
};
6353+
let wrapped = WrappedFunction::CooperativeMultiplyAdd {
6354+
columns: b_c,
6355+
rows: a_r,
6356+
intermediate: a_c,
6357+
scalar,
6358+
};
6359+
if !self.wrapped_functions.insert(wrapped) {
6360+
return Ok(());
6361+
}
6362+
let scalar_name = match scalar.width {
6363+
2 => "half",
6364+
4 => "float",
6365+
8 => "double",
6366+
_ => unreachable!(),
6367+
};
6368+
writeln!(
6369+
self.out,
6370+
"{NAMESPACE}::simdgroup_{scalar_name}{}x{} {COOPERATIVE_MULTIPLY_ADD_FUNCTION}(const {NAMESPACE}::simdgroup_{scalar_name}{}x{}& a, const {NAMESPACE}::simdgroup_{scalar_name}{}x{}& b, const {NAMESPACE}::simdgroup_{scalar_name}{}x{}& c) {{",
6371+
b_c as u32, a_r as u32, a_c as u32, a_r as u32, b_c as u32, b_r as u32, b_c as u32, a_r as u32,
6372+
)?;
6373+
let l1 = back::Level(1);
6374+
writeln!(
6375+
self.out,
6376+
"{l1}{NAMESPACE}::simdgroup_{scalar_name}{}x{} d;",
6377+
b_c as u32, a_r as u32
6378+
)?;
6379+
writeln!(
6380+
self.out,
6381+
"{l1}{NAMESPACE}::simdgroup_multiply_accumulate(d,a,b,c);"
6382+
)?;
6383+
writeln!(self.out, "{l1}return d;")?;
6384+
writeln!(self.out, "}}")?;
6385+
writeln!(self.out)?;
6386+
Ok(())
6387+
}
6388+
62896389
pub(super) fn write_wrapped_functions(
62906390
&mut self,
62916391
module: &crate::Module,
@@ -6360,6 +6460,9 @@ template <typename A>
63606460
crate::Expression::ImageQuery { image, query } => {
63616461
self.write_wrapped_image_query(module, func_ctx, image, query)?;
63626462
}
6463+
crate::Expression::CooperativeMultiplyAdd { a, b, c: _ } => {
6464+
self.write_wrapped_cooperative_multiply_add(module, func_ctx, a, b)?;
6465+
}
63636466
_ => {}
63646467
}
63656468
}

naga/src/back/pipeline_constants.rs

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -633,7 +633,7 @@ fn adjust_expr(new_pos: &HandleVec<Expression, Handle<Expression>>, expr: &mut E
633633
} => {
634634
adjust(query);
635635
}
636-
Expression::MulAdd {
636+
Expression::CooperativeMultiplyAdd {
637637
ref mut a,
638638
ref mut b,
639639
ref mut c,
@@ -844,6 +844,19 @@ fn adjust_stmt(new_pos: &HandleVec<Expression, Handle<Expression>>, stmt: &mut S
844844
crate::RayQueryFunction::Terminate => {}
845845
}
846846
}
847+
Statement::CooperativeLoadStore {
848+
store: _,
849+
ref mut target,
850+
ref mut pointer,
851+
ref mut stride,
852+
row_major: _,
853+
} => {
854+
adjust(target);
855+
adjust(pointer);
856+
if let Some(ref mut stride) = *stride {
857+
adjust(stride);
858+
}
859+
}
847860
Statement::Break
848861
| Statement::Continue
849862
| Statement::Kill

naga/src/back/spv/block.rs

Lines changed: 47 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1805,14 +1805,21 @@ impl BlockContext<'_> {
18051805
)?;
18061806
self.write_ray_query_return_vertex_position(query, block, committed)
18071807
}
1808-
crate::Expression::MulAdd { a, b, c } => {
1808+
crate::Expression::CooperativeMultiplyAdd { a, b, c } => {
1809+
self.writer.require_any(
1810+
"CooperativeMatrix",
1811+
&[spirv::Capability::CooperativeMatrixKHR],
1812+
)?;
1813+
let a_id = self.cached[a];
1814+
let b_id = self.cached[b];
1815+
let c_id = self.cached[c];
18091816
let id = self.gen_id();
18101817
block.body.push(Instruction::coop_mul_add(
18111818
result_type_id,
18121819
id,
1813-
self.cached[a],
1814-
self.cached[b],
1815-
self.cached[c],
1820+
a_id,
1821+
b_id,
1822+
c_id,
18161823
));
18171824
id
18181825
}
@@ -3677,6 +3684,42 @@ impl BlockContext<'_> {
36773684
} => {
36783685
self.write_subgroup_gather(mode, argument, result, &mut block)?;
36793686
}
3687+
Statement::CooperativeLoadStore {
3688+
store,
3689+
target,
3690+
pointer,
3691+
stride,
3692+
row_major,
3693+
} => {
3694+
let layout = if row_major {
3695+
spirv::CooperativeMatrixLayout::RowMajorKHR
3696+
} else {
3697+
spirv::CooperativeMatrixLayout::ColumnMajorKHR
3698+
};
3699+
let layout_id = self.get_index_constant(layout as u32);
3700+
let stride_id = stride.map(|exp| self.cached[exp]);
3701+
if store {
3702+
block.body.push(Instruction::coop_store(
3703+
self.cached[target],
3704+
self.cached[pointer],
3705+
layout_id,
3706+
stride_id,
3707+
));
3708+
} else {
3709+
let result_type_id = self.get_expression_type_id(&self.fun_info[target].ty);
3710+
let id = self.gen_id();
3711+
block.body.push(Instruction::coop_load(
3712+
result_type_id,
3713+
id,
3714+
self.cached[pointer],
3715+
layout_id,
3716+
stride_id,
3717+
));
3718+
block
3719+
.body
3720+
.push(Instruction::store(self.cached[target], id, None));
3721+
}
3722+
}
36803723
}
36813724
}
36823725

0 commit comments

Comments
 (0)