Skip to content

Commit 7da965e

Browse files
committed
coop: rewire WGSL support using references
1 parent 321ddf4 commit 7da965e

22 files changed

+190
-249
lines changed

CHANGELOG.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,7 @@ By @cwfitzgerald in [#8162](https://github.com/gfx-rs/wgpu/pull/8162).
166166

167167
- Added support for external textures based on WebGPU's [`GPUExternalTexture`](https://www.w3.org/TR/webgpu/#gpuexternaltexture). These allow shaders to transparently operate on potentially multiplanar source texture data in either RGB or YCbCr formats via WGSL's `texture_external` type. This is gated behind the `Features::EXTERNAL_TEXTURE` feature, which is currently only supported on DX12. By @jamienicol in [#4386](https://github.com/gfx-rs/wgpu/issues/4386).
168168

169-
- Added support for cooperative load/store operations in shaders. Currently only WGSL on the input and SPIR-V with METAL on the output are supported. By @kvark in [#8251](https://github.com/gfx-rs/wgpu/issues/8251).
169+
- Added support for cooperative load/store operations in shaders. Currently only WGSL on the input and SPIR-V,METAL, and WGSL on the output are supported. By @kvark in [#8251](https://github.com/gfx-rs/wgpu/issues/8251).
170170

171171
### Changes
172172

naga/src/back/mod.rs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -311,6 +311,16 @@ pub const fn binary_operation_str(op: crate::BinaryOperator) -> &'static str {
311311
}
312312
}
313313

314+
impl crate::TypeInner {
315+
/// Returns true if a variable of this type is a handle.
316+
pub const fn is_handle(&self) -> bool {
317+
match *self {
318+
Self::Image { .. } | Self::Sampler { .. } | Self::AccelerationStructure { .. } => true,
319+
_ => false,
320+
}
321+
}
322+
}
323+
314324
impl crate::Statement {
315325
/// Returns true if the statement directly terminates the current block.
316326
///

naga/src/back/msl/mod.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -228,8 +228,10 @@ pub enum Error {
228228
UnsupportedArrayOf(String),
229229
#[error("array of type '{0:?}' is not supported")]
230230
UnsupportedArrayOfType(Handle<crate::Type>),
231-
#[error("ray tracing is not supported prior to MSL 2.3")]
231+
#[error("ray tracing is not supported prior to MSL 2.4")]
232232
UnsupportedRayTracing,
233+
#[error("cooperative matrix is not supported prior to MSL 2.3")]
234+
UnsupportedCooperativeMatrix,
233235
#[error("overrides should not be present at this stage")]
234236
Override,
235237
#[error("bitcasting to {0:?} is not supported")]

naga/src/back/msl/writer.rs

Lines changed: 21 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -236,6 +236,7 @@ impl Display for TypeContext<'_> {
236236
rows,
237237
scalar,
238238
} => put_numeric_type(out, scalar, &[rows, columns]),
239+
// Requires Metal-2.3
239240
crate::TypeInner::CooperativeMatrix {
240241
columns,
241242
rows,
@@ -244,8 +245,7 @@ impl Display for TypeContext<'_> {
244245
} => {
245246
write!(
246247
out,
247-
"{}::simdgroup_{}{}x{}",
248-
NAMESPACE,
248+
"{NAMESPACE}::simdgroup_{}{}x{}",
249249
scalar.to_msl_name(),
250250
columns as u32,
251251
rows as u32,
@@ -485,6 +485,7 @@ enum WrappedFunction {
485485
class: crate::ImageClass,
486486
},
487487
CooperativeMultiplyAdd {
488+
space: crate::AddressSpace,
488489
columns: crate::CooperativeSize,
489490
rows: crate::CooperativeSize,
490491
intermediate: crate::CooperativeSize,
@@ -2842,6 +2843,9 @@ impl<W: Write> Writer<W> {
28422843
write!(self.out, "}}")?;
28432844
}
28442845
crate::Expression::CooperativeMultiplyAdd { a, b, c } => {
2846+
if context.lang_version < (2, 3) {
2847+
return Err(Error::UnsupportedCooperativeMatrix);
2848+
}
28452849
write!(self.out, "{COOPERATIVE_MULTIPLY_ADD_FUNCTION}(")?;
28462850
self.put_expression(a, context, true)?;
28472851
write!(self.out, ", ")?;
@@ -4239,10 +4243,14 @@ impl<W: Write> Writer<W> {
42394243
row_major,
42404244
} => {
42414245
let op_str = if store { "store" } else { "load" };
4242-
write!(self.out, "{level}{NAMESPACE}::simdgroup_{op_str}(")?;
4246+
write!(self.out, "{level}simdgroup_{op_str}(")?;
42434247
self.put_expression(target, &context.expression, true)?;
4244-
write!(self.out, ", ")?;
4245-
self.put_expression(pointer, &context.expression, true)?;
4248+
write!(self.out, ", &")?;
4249+
self.put_access_chain(
4250+
pointer,
4251+
context.expression.policies.index,
4252+
&context.expression,
4253+
)?;
42464254
write!(self.out, ", ")?;
42474255
self.put_expression(stride, &context.expression, true)?;
42484256
if row_major {
@@ -6312,6 +6320,7 @@ template <typename A>
63126320
&mut self,
63136321
module: &crate::Module,
63146322
func_ctx: &back::FunctionCtx,
6323+
space: crate::AddressSpace,
63156324
a: Handle<crate::Expression>,
63166325
b: Handle<crate::Expression>,
63176326
) -> BackendResult {
@@ -6329,6 +6338,7 @@ template <typename A>
63296338
_ => unreachable!(),
63306339
};
63316340
let wrapped = WrappedFunction::CooperativeMultiplyAdd {
6341+
space,
63326342
columns: b_c,
63336343
rows: a_r,
63346344
intermediate: a_c,
@@ -6337,15 +6347,11 @@ template <typename A>
63376347
if !self.wrapped_functions.insert(wrapped) {
63386348
return Ok(());
63396349
}
6340-
let scalar_name = match scalar.width {
6341-
2 => "half",
6342-
4 => "float",
6343-
8 => "double",
6344-
_ => unreachable!(),
6345-
};
6350+
let space_name = space.to_msl_name().unwrap_or_default();
6351+
let scalar_name = scalar.to_msl_name();
63466352
writeln!(
63476353
self.out,
6348-
"{NAMESPACE}::simdgroup_{scalar_name}{}x{} {COOPERATIVE_MULTIPLY_ADD_FUNCTION}(const {NAMESPACE}::simdgroup_{scalar_name}{}x{}& a, const {NAMESPACE}::simdgroup_{scalar_name}{}x{}& b, const {NAMESPACE}::simdgroup_{scalar_name}{}x{}& c) {{",
6354+
"{NAMESPACE}::simdgroup_{scalar_name}{}x{} {COOPERATIVE_MULTIPLY_ADD_FUNCTION}(const {space_name} {NAMESPACE}::simdgroup_{scalar_name}{}x{}& a, const {space_name} {NAMESPACE}::simdgroup_{scalar_name}{}x{}& b, const {space_name} {NAMESPACE}::simdgroup_{scalar_name}{}x{}& c) {{",
63496355
b_c as u32, a_r as u32, a_c as u32, a_r as u32, b_c as u32, b_r as u32, b_c as u32, a_r as u32,
63506356
)?;
63516357
let l1 = back::Level(1);
@@ -6354,10 +6360,7 @@ template <typename A>
63546360
"{l1}{NAMESPACE}::simdgroup_{scalar_name}{}x{} d;",
63556361
b_c as u32, a_r as u32
63566362
)?;
6357-
writeln!(
6358-
self.out,
6359-
"{l1}{NAMESPACE}::simdgroup_multiply_accumulate(d,a,b,c);"
6360-
)?;
6363+
writeln!(self.out, "{l1}simdgroup_multiply_accumulate(d,a,b,c);")?;
63616364
writeln!(self.out, "{l1}return d;")?;
63626365
writeln!(self.out, "}}")?;
63636366
writeln!(self.out)?;
@@ -6439,7 +6442,8 @@ template <typename A>
64396442
self.write_wrapped_image_query(module, func_ctx, image, query)?;
64406443
}
64416444
crate::Expression::CooperativeMultiplyAdd { a, b, c: _ } => {
6442-
self.write_wrapped_cooperative_multiply_add(module, func_ctx, a, b)?;
6445+
let space = crate::AddressSpace::Private;
6446+
self.write_wrapped_cooperative_multiply_add(module, func_ctx, space, a, b)?;
64436447
}
64446448
_ => {}
64456449
}
@@ -6632,7 +6636,6 @@ template <typename A>
66326636
names: &self.names,
66336637
handle,
66346638
usage: fun_info[handle],
6635-
66366639
reference: true,
66376640
};
66386641
let separator =

naga/src/back/spv/block.rs

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3717,7 +3717,13 @@ impl BlockContext<'_> {
37173717
self.cached[stride],
37183718
));
37193719
} else {
3720-
let result_type_id = self.get_expression_type_id(&self.fun_info[target].ty);
3720+
let result_type_id =
3721+
match *self.fun_info[target].ty.inner_with(&self.ir_module.types) {
3722+
crate::TypeInner::Pointer { base, space: _ } => {
3723+
self.get_handle_type_id(base)
3724+
}
3725+
_ => unreachable!(),
3726+
};
37213727
let id = self.gen_id();
37223728
block.body.push(Instruction::coop_load(
37233729
result_type_id,

naga/src/back/spv/writer.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -971,14 +971,13 @@ impl Writer {
971971
}
972972
}
973973

974-
// Handle globals are pre-emitted and should be loaded automatically.
975-
//
976-
// Any that are binding arrays we skip as we cannot load the array, we must load the result after indexing.
977974
match ir_module.types[var.ty].inner {
975+
// Any that are binding arrays we skip as we cannot load the array, we must load the result after indexing.
978976
crate::TypeInner::BindingArray { .. } => {
979977
gv.access_id = gv.var_id;
980978
}
981979
_ => {
980+
// Handle globals are pre-emitted and should be loaded automatically.
982981
if var.space == crate::AddressSpace::Handle {
983982
let var_type_id = self.get_handle_type_id(var.ty);
984983
let id = self.id_gen.next();
@@ -1064,6 +1063,7 @@ impl Writer {
10641063
}
10651064
}),
10661065
);
1066+
10671067
context
10681068
.function
10691069
.variables

naga/src/back/wgsl/writer.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -993,13 +993,13 @@ impl<W: Write> Writer<W> {
993993
} => {
994994
let op_str = if store { "Store" } else { "Load" };
995995
let suffix = if row_major { "T" } else { "" };
996-
write!(self.out, "coop{op_str}{suffix}(")?;
996+
write!(self.out, "{level}coop{op_str}{suffix}(")?;
997997
self.write_expr(module, target, func_ctx)?;
998998
write!(self.out, ", ")?;
999999
self.write_expr(module, pointer, func_ctx)?;
10001000
write!(self.out, ", ")?;
10011001
self.write_expr(module, stride, func_ctx)?;
1002-
write!(self.out, ")")?
1002+
writeln!(self.out, ");")?
10031003
}
10041004
}
10051005

naga/src/front/wgsl/error.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -412,6 +412,7 @@ pub(crate) enum Error<'a> {
412412
TypeTooLarge {
413413
span: Span,
414414
},
415+
InvalidCooperativeMatrix,
415416
UnderspecifiedCooperativeMatrix,
416417
UnsupportedCooperativeScalar(Span),
417418
}
@@ -1388,6 +1389,11 @@ impl<'a> Error<'a> {
13881389
crate::valid::MAX_TYPE_SIZE
13891390
)],
13901391
},
1392+
Error::InvalidCooperativeMatrix => ParseError {
1393+
message: "given type is not a cooperative matrix".into(),
1394+
labels: vec![],
1395+
notes: vec![format!("must be coop_mat")],
1396+
},
13911397
Error::UnderspecifiedCooperativeMatrix => ParseError {
13921398
message: "cooperative matrix constructor is underspecified".into(),
13931399
labels: vec![],

naga/src/front/wgsl/lower/mod.rs

Lines changed: 23 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -524,6 +524,7 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> {
524524
span: Span,
525525
) -> Result<'source, Handle<ir::Expression>> {
526526
let mut eval = self.as_const_evaluator();
527+
log::debug!("appending {expr:?}");
527528
eval.try_eval_and_append(expr, span)
528529
.map_err(|e| Box::new(Error::ConstantEvaluatorError(e.into(), span)))
529530
}
@@ -846,6 +847,15 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> {
846847
fn ensure_type_exists(&mut self, inner: ir::TypeInner) -> Handle<ir::Type> {
847848
self.as_global().ensure_type_exists(None, inner)
848849
}
850+
851+
fn _get_runtime_expression(&self, expr: Handle<ir::Expression>) -> &ir::Expression {
852+
match self.expr_type {
853+
ExpressionContextType::Runtime(ref ctx) => &ctx.function.expressions[expr],
854+
ExpressionContextType::Constant(_) | ExpressionContextType::Override => {
855+
unreachable!()
856+
}
857+
}
858+
}
849859
}
850860

851861
struct ArgumentContext<'ctx, 'source> {
@@ -955,6 +965,13 @@ impl<T> Typed<T> {
955965
Self::Plain(expr) => Typed::Plain(f(expr)?),
956966
})
957967
}
968+
969+
fn ref_or<E>(self, error: E) -> core::result::Result<T, E> {
970+
match self {
971+
Self::Reference(v) => Ok(v),
972+
Self::Plain(_) => Err(error),
973+
}
974+
}
958975
}
959976

960977
/// A single vector component or swizzle.
@@ -1677,12 +1694,8 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
16771694
.as_expression(block, &mut emitter)
16781695
.interrupt_emitter(ir::Expression::LocalVariable(var), Span::UNDEFINED)?;
16791696
block.extend(emitter.finish(&ctx.function.expressions));
1680-
let typed = if ctx.module.types[ty].inner.is_handle() {
1681-
Typed::Plain(handle)
1682-
} else {
1683-
Typed::Reference(handle)
1684-
};
1685-
ctx.local_table.insert(v.handle, Declared::Runtime(typed));
1697+
ctx.local_table
1698+
.insert(v.handle, Declared::Runtime(Typed::Reference(handle)));
16861699

16871700
match initializer {
16881701
Some(initializer) => ir::Statement::Store {
@@ -1977,12 +1990,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
19771990
let value_span = ctx.ast_expressions.get_span(value);
19781991
let target = self
19791992
.expression_for_reference(value, &mut ctx.as_expression(block, &mut emitter))?;
1980-
let target_handle = match target {
1981-
Typed::Reference(handle) => handle,
1982-
Typed::Plain(_) => {
1983-
return Err(Box::new(Error::BadIncrDecrReferenceType(value_span)))
1984-
}
1985-
};
1993+
let target_handle = target.ref_or(Error::BadIncrDecrReferenceType(value_span))?;
19861994

19871995
let mut ectx = ctx.as_expression(block, &mut emitter);
19881996
let scalar = match *resolve_inner!(ectx, target_handle) {
@@ -2139,10 +2147,8 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
21392147
LoweredGlobalDecl::Var(handle) => {
21402148
let expr = ir::Expression::GlobalVariable(handle);
21412149
let v = &ctx.module.global_variables[handle];
2142-
let force_value = ctx.module.types[v.ty].inner.is_handle();
21432150
match v.space {
21442151
ir::AddressSpace::Handle => Typed::Plain(expr),
2145-
_ if force_value => Typed::Plain(expr),
21462152
_ => Typed::Reference(expr),
21472153
}
21482154
}
@@ -3140,7 +3146,9 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
31403146
let row_major = function.name.ends_with("T");
31413147

31423148
let mut args = ctx.prepare_args(arguments, 2, span);
3143-
let target = self.expression(args.next()?, ctx)?;
3149+
let target = self
3150+
.expression_for_reference(args.next()?, ctx)?
3151+
.ref_or(Error::InvalidCooperativeMatrix)?;
31443152
let pointer = self.expression(args.next()?, ctx)?;
31453153
let stride = if args.total_args > 2 {
31463154
self.expression(args.next()?, ctx)?

naga/src/proc/type_methods.rs

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -191,17 +191,6 @@ impl crate::TypeInner {
191191
}
192192
}
193193

194-
/// Returns true if a variable of this type is a handle.
195-
pub const fn is_handle(&self) -> bool {
196-
match *self {
197-
Self::Image { .. }
198-
| Self::Sampler { .. }
199-
| Self::AccelerationStructure { .. }
200-
| Self::CooperativeMatrix { .. } => true,
201-
_ => false,
202-
}
203-
}
204-
205194
/// Attempt to calculate the size of this type. Returns `None` if the size
206195
/// exceeds the limit of [`crate::valid::MAX_TYPE_SIZE`].
207196
pub fn try_size(&self, gctx: super::GlobalCtx) -> Option<u32> {

0 commit comments

Comments
 (0)