Skip to content

Commit 07be9e9

Browse files
committed
coop: make cooperativeLoad to be an expression
1 parent 87360f2 commit 07be9e9

25 files changed

+462
-324
lines changed

naga/src/back/dot/mod.rs

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -403,20 +403,14 @@ impl StatementGraph {
403403
},
404404
}
405405
}
406-
S::CooperativeLoadStore {
407-
store,
408-
target,
409-
pointer,
410-
stride,
411-
row_major: _,
412-
} => {
406+
S::CooperativeStore { target, data } => {
413407
self.dependencies.push((id, target, "target"));
414-
self.dependencies.push((id, pointer, "pointer"));
415-
self.dependencies.push((id, stride, "stride"));
416-
if store {
417-
"Store"
408+
self.dependencies.push((id, data.pointer, "pointer"));
409+
self.dependencies.push((id, data.stride, "stride"));
410+
if data.row_major {
411+
"CoopStoreT"
418412
} else {
419-
"Load"
413+
"CoopStore"
420414
}
421415
}
422416
};
@@ -758,6 +752,12 @@ fn write_function_expressions(
758752
let ty = if committed { "Committed" } else { "Candidate" };
759753
(format!("get{ty}HitVertexPositions").into(), 4)
760754
}
755+
E::CooperativeLoad { ref data, .. } => {
756+
edges.insert("pointer", data.pointer);
757+
edges.insert("stride", data.stride);
758+
let suffix = if data.row_major { "T " } else { "" };
759+
(format!("coopLoad{suffix}").into(), 4)
760+
}
761761
E::CooperativeMultiplyAdd { a, b, c } => {
762762
edges.insert("a", a);
763763
edges.insert("b", b);

naga/src/back/glsl/mod.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2805,7 +2805,7 @@ impl<'a, W: Write> Writer<'a, W> {
28052805
}
28062806
writeln!(self.out, ");")?;
28072807
}
2808-
Statement::CooperativeLoadStore { .. } => unimplemented!(),
2808+
Statement::CooperativeStore { .. } => unimplemented!(),
28092809
}
28102810

28112811
Ok(())
@@ -4343,6 +4343,7 @@ impl<'a, W: Write> Writer<'a, W> {
43434343
// not supported yet
43444344
Expression::RayQueryGetIntersection { .. }
43454345
| Expression::RayQueryVertexPositions { .. }
4346+
| Expression::CooperativeLoad { .. }
43464347
| Expression::CooperativeMultiplyAdd { .. } => unreachable!(),
43474348
}
43484349

naga/src/back/hlsl/writer.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2747,7 +2747,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
27472747
}
27482748
writeln!(self.out, ");")?;
27492749
}
2750-
Statement::CooperativeLoadStore { .. } => unimplemented!(),
2750+
Statement::CooperativeStore { .. } => unimplemented!(),
27512751
}
27522752

27532753
Ok(())
@@ -4277,6 +4277,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
42774277
}
42784278
// Not supported yet
42794279
Expression::RayQueryVertexPositions { .. }
4280+
| Expression::CooperativeLoad { .. }
42804281
| Expression::CooperativeMultiplyAdd { .. } => {
42814282
unreachable!()
42824283
}

naga/src/back/msl/writer.rs

Lines changed: 86 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@ pub(crate) const ARGUMENT_BUFFER_WRAPPER_STRUCT: &str = "NagaArgumentBufferWrapp
7878
/// allowing them to be conveniently passed to user-defined or wrapper
7979
/// functions. The struct is declared in [`Writer::write_type_defs`].
8080
pub(crate) const EXTERNAL_TEXTURE_WRAPPER_STRUCT: &str = "NagaExternalTextureWrapper";
81+
pub(crate) const COOPERATIVE_LOAD_FUNCTION: &str = "NagaCooperativeLoad";
8182
pub(crate) const COOPERATIVE_MULTIPLY_ADD_FUNCTION: &str = "NagaCooperativeMultiplyAdd";
8283

8384
/// Write the Metal name for a Naga numeric type: scalar, vector, or matrix.
@@ -484,6 +485,12 @@ enum WrappedFunction {
484485
ImageQuerySize {
485486
class: crate::ImageClass,
486487
},
488+
CooperativeLoad {
489+
space: crate::AddressSpace,
490+
columns: crate::CooperativeSize,
491+
rows: crate::CooperativeSize,
492+
scalar: crate::Scalar,
493+
},
487494
CooperativeMultiplyAdd {
488495
space: crate::AddressSpace,
489496
columns: crate::CooperativeSize,
@@ -2842,6 +2849,17 @@ impl<W: Write> Writer<W> {
28422849
}
28432850
write!(self.out, "}}")?;
28442851
}
2852+
crate::Expression::CooperativeLoad { ref data, .. } => {
2853+
if context.lang_version < (2, 3) {
2854+
return Err(Error::UnsupportedCooperativeMatrix);
2855+
}
2856+
write!(self.out, "{COOPERATIVE_LOAD_FUNCTION}(")?;
2857+
write!(self.out, "&")?;
2858+
self.put_access_chain(data.pointer, context.policies.index, context)?;
2859+
write!(self.out, ", ")?;
2860+
self.put_expression(data.stride, context, true)?;
2861+
write!(self.out, ", {})", data.row_major)?;
2862+
}
28452863
crate::Expression::CooperativeMultiplyAdd { a, b, c } => {
28462864
if context.lang_version < (2, 3) {
28472865
return Err(Error::UnsupportedCooperativeMatrix);
@@ -4235,25 +4253,18 @@ impl<W: Write> Writer<W> {
42354253
}
42364254
writeln!(self.out, ");")?;
42374255
}
4238-
crate::Statement::CooperativeLoadStore {
4239-
store,
4240-
target,
4241-
pointer,
4242-
stride,
4243-
row_major,
4244-
} => {
4245-
let op_str = if store { "store" } else { "load" };
4246-
write!(self.out, "{level}simdgroup_{op_str}(")?;
4256+
crate::Statement::CooperativeStore { target, ref data } => {
4257+
write!(self.out, "{level}simdgroup_store(")?;
42474258
self.put_expression(target, &context.expression, true)?;
42484259
write!(self.out, ", &")?;
42494260
self.put_access_chain(
4250-
pointer,
4261+
data.pointer,
42514262
context.expression.policies.index,
42524263
&context.expression,
42534264
)?;
42544265
write!(self.out, ", ")?;
4255-
self.put_expression(stride, &context.expression, true)?;
4256-
if row_major {
4266+
self.put_expression(data.stride, &context.expression, true)?;
4267+
if data.row_major {
42574268
let matrix_origin = "0";
42584269
let transpose = true;
42594270
write!(self.out, ", {matrix_origin}, {transpose}")?;
@@ -6316,6 +6327,55 @@ template <typename A>
63166327
Ok(())
63176328
}
63186329

6330+
fn write_wrapped_cooperative_load(
6331+
&mut self,
6332+
module: &crate::Module,
6333+
func_ctx: &back::FunctionCtx,
6334+
columns: crate::CooperativeSize,
6335+
rows: crate::CooperativeSize,
6336+
pointer: Handle<crate::Expression>,
6337+
) -> BackendResult {
6338+
let ptr_ty = func_ctx.resolve_type(pointer, &module.types);
6339+
let space = ptr_ty.pointer_space().unwrap();
6340+
let scalar = ptr_ty
6341+
.pointer_base_type()
6342+
.unwrap()
6343+
.inner_with(&module.types)
6344+
.scalar()
6345+
.unwrap();
6346+
let wrapped = WrappedFunction::CooperativeLoad {
6347+
space,
6348+
columns,
6349+
rows,
6350+
scalar,
6351+
};
6352+
if !self.wrapped_functions.insert(wrapped) {
6353+
return Ok(());
6354+
}
6355+
let space_name = space.to_msl_name().unwrap_or_default();
6356+
let scalar_name = scalar.to_msl_name();
6357+
writeln!(
6358+
self.out,
6359+
"{NAMESPACE}::simdgroup_{scalar_name}{}x{} {COOPERATIVE_LOAD_FUNCTION}(const {space_name} {scalar_name}* ptr, int stride, bool is_row_major) {{",
6360+
columns as u32, rows as u32,
6361+
)?;
6362+
let l1 = back::Level(1);
6363+
writeln!(
6364+
self.out,
6365+
"{l1}{NAMESPACE}::simdgroup_{scalar_name}{}x{} m;",
6366+
columns as u32, rows as u32
6367+
)?;
6368+
let matrix_origin = "0";
6369+
writeln!(
6370+
self.out,
6371+
"{l1}simdgroup_load(m, ptr, stride, {matrix_origin}, is_row_major);"
6372+
)?;
6373+
writeln!(self.out, "{l1}return m;")?;
6374+
writeln!(self.out, "}}")?;
6375+
writeln!(self.out)?;
6376+
Ok(())
6377+
}
6378+
63196379
fn write_wrapped_cooperative_multiply_add(
63206380
&mut self,
63216381
module: &crate::Module,
@@ -6441,6 +6501,20 @@ template <typename A>
64416501
crate::Expression::ImageQuery { image, query } => {
64426502
self.write_wrapped_image_query(module, func_ctx, image, query)?;
64436503
}
6504+
crate::Expression::CooperativeLoad {
6505+
columns,
6506+
rows,
6507+
role: _,
6508+
ref data,
6509+
} => {
6510+
self.write_wrapped_cooperative_load(
6511+
module,
6512+
func_ctx,
6513+
columns,
6514+
rows,
6515+
data.pointer,
6516+
)?;
6517+
}
64446518
crate::Expression::CooperativeMultiplyAdd { a, b, c: _ } => {
64456519
let space = crate::AddressSpace::Private;
64466520
self.write_wrapped_cooperative_multiply_add(module, func_ctx, space, a, b)?;

naga/src/back/pipeline_constants.rs

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -633,6 +633,10 @@ fn adjust_expr(new_pos: &HandleVec<Expression, Handle<Expression>>, expr: &mut E
633633
} => {
634634
adjust(query);
635635
}
636+
Expression::CooperativeLoad { ref mut data, .. } => {
637+
adjust(&mut data.pointer);
638+
adjust(&mut data.stride);
639+
}
636640
Expression::CooperativeMultiplyAdd {
637641
ref mut a,
638642
ref mut b,
@@ -844,16 +848,13 @@ fn adjust_stmt(new_pos: &HandleVec<Expression, Handle<Expression>>, stmt: &mut S
844848
crate::RayQueryFunction::Terminate => {}
845849
}
846850
}
847-
Statement::CooperativeLoadStore {
848-
store: _,
851+
Statement::CooperativeStore {
849852
ref mut target,
850-
ref mut pointer,
851-
ref mut stride,
852-
row_major: _,
853+
ref mut data,
853854
} => {
854855
adjust(target);
855-
adjust(pointer);
856-
adjust(stride);
856+
adjust(&mut data.pointer);
857+
adjust(&mut data.stride);
857858
}
858859
Statement::Break
859860
| Statement::Continue

naga/src/back/spv/block.rs

Lines changed: 42 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1805,6 +1805,39 @@ impl BlockContext<'_> {
18051805
)?;
18061806
self.write_ray_query_return_vertex_position(query, block, committed)
18071807
}
1808+
crate::Expression::CooperativeLoad { ref data, .. } => {
1809+
self.writer.require_any(
1810+
"CooperativeMatrix",
1811+
&[spirv::Capability::CooperativeMatrixKHR],
1812+
)?;
1813+
let pointer_id = match self.write_access_chain(
1814+
data.pointer,
1815+
block,
1816+
AccessTypeAdjustment::None,
1817+
)? {
1818+
ExpressionPointer::Ready { pointer_id } => pointer_id,
1819+
ExpressionPointer::Conditional { .. } => {
1820+
return Err(Error::FeatureNotImplemented(
1821+
"Copperative load/store out-of-bounds handling",
1822+
));
1823+
}
1824+
};
1825+
let layout = if data.row_major {
1826+
spirv::CooperativeMatrixLayout::RowMajorKHR
1827+
} else {
1828+
spirv::CooperativeMatrixLayout::ColumnMajorKHR
1829+
};
1830+
let layout_id = self.get_index_constant(layout as u32);
1831+
let id = self.gen_id();
1832+
block.body.push(Instruction::coop_load(
1833+
result_type_id,
1834+
id,
1835+
pointer_id,
1836+
layout_id,
1837+
self.cached[data.stride],
1838+
));
1839+
id
1840+
}
18081841
crate::Expression::CooperativeMultiplyAdd { a, b, c } => {
18091842
self.writer.require_any(
18101843
"CooperativeMatrix",
@@ -3684,15 +3717,9 @@ impl BlockContext<'_> {
36843717
} => {
36853718
self.write_subgroup_gather(mode, argument, result, &mut block)?;
36863719
}
3687-
Statement::CooperativeLoadStore {
3688-
store,
3689-
target,
3690-
pointer,
3691-
stride,
3692-
row_major,
3693-
} => {
3720+
Statement::CooperativeStore { target, ref data } => {
36943721
let pointer_id = match self.write_access_chain(
3695-
pointer,
3722+
data.pointer,
36963723
&mut block,
36973724
AccessTypeAdjustment::None,
36983725
)? {
@@ -3703,44 +3730,18 @@ impl BlockContext<'_> {
37033730
));
37043731
}
37053732
};
3706-
let layout = if row_major {
3733+
let layout = if data.row_major {
37073734
spirv::CooperativeMatrixLayout::RowMajorKHR
37083735
} else {
37093736
spirv::CooperativeMatrixLayout::ColumnMajorKHR
37103737
};
37113738
let layout_id = self.get_index_constant(layout as u32);
3712-
if store {
3713-
block.body.push(Instruction::coop_store(
3714-
self.cached[target],
3715-
pointer_id,
3716-
layout_id,
3717-
self.cached[stride],
3718-
));
3719-
} else {
3720-
let result_type_id = self.get_expression_type_id(&self.fun_info[target].ty);
3721-
let id = self.gen_id();
3722-
block.body.push(Instruction::coop_load(
3723-
result_type_id,
3724-
id,
3725-
pointer_id,
3726-
layout_id,
3727-
self.cached[stride],
3728-
));
3729-
match self.write_access_chain(
3730-
target,
3731-
&mut block,
3732-
AccessTypeAdjustment::None,
3733-
)? {
3734-
ExpressionPointer::Ready {
3735-
pointer_id: target_id,
3736-
} => {
3737-
block.body.push(Instruction::store(target_id, id, None));
3738-
}
3739-
ExpressionPointer::Conditional { .. } => {
3740-
unimplemented!()
3741-
}
3742-
};
3743-
}
3739+
block.body.push(Instruction::coop_store(
3740+
self.cached[target],
3741+
pointer_id,
3742+
layout_id,
3743+
self.cached[data.stride],
3744+
));
37443745
}
37453746
}
37463747
}

0 commit comments

Comments
 (0)