Skip to content

Commit 056a5c1

Browse files
committed
coop: rewire WGSL support using references
1 parent 0360b94 commit 056a5c1

18 files changed

+174
-293
lines changed

CHANGELOG.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,7 @@ By @cwfitzgerald in [#8162](https://github.com/gfx-rs/wgpu/pull/8162).
166166

167167
- Added support for external textures based on WebGPU's [`GPUExternalTexture`](https://www.w3.org/TR/webgpu/#gpuexternaltexture). These allow shaders to transparently operate on potentially multiplanar source texture data in either RGB or YCbCr formats via WGSL's `texture_external` type. This is gated behind the `Features::EXTERNAL_TEXTURE` feature, which is currently only supported on DX12. By @jamienicol in [#4386](https://github.com/gfx-rs/wgpu/issues/4386).
168168

169-
- Added support for cooperative load/store operations in shaders. Currently only WGSL on the input and SPIR-V with METAL on the output are supported. By @kvark in [#8251](https://github.com/gfx-rs/wgpu/issues/8251).
169+
- Added support for cooperative load/store operations in shaders. Currently only WGSL on the input and SPIR-V,METAL, and WGSL on the output are supported. By @kvark in [#8251](https://github.com/gfx-rs/wgpu/issues/8251).
170170

171171
### Changes
172172

naga/src/back/mod.rs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -311,6 +311,16 @@ pub const fn binary_operation_str(op: crate::BinaryOperator) -> &'static str {
311311
}
312312
}
313313

314+
impl crate::TypeInner {
315+
/// Returns true if a variable of this type is a handle.
316+
pub const fn is_handle(&self) -> bool {
317+
match *self {
318+
Self::Image { .. } | Self::Sampler { .. } | Self::AccelerationStructure { .. } => true,
319+
_ => false,
320+
}
321+
}
322+
}
323+
314324
impl crate::Statement {
315325
/// Returns true if the statement directly terminates the current block.
316326
///

naga/src/back/msl/writer.rs

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6316,16 +6316,22 @@ template <typename A>
63166316
b: Handle<crate::Expression>,
63176317
) -> BackendResult {
63186318
let (a_c, a_r, scalar) = match *func_ctx.resolve_type(a, &module.types) {
6319-
crate::TypeInner::CooperativeMatrix {
6320-
columns,
6321-
rows,
6322-
scalar,
6323-
..
6324-
} => (columns, rows, scalar),
6319+
crate::TypeInner::Pointer { base, space: _ } => match module.types[base].inner {
6320+
crate::TypeInner::CooperativeMatrix {
6321+
columns,
6322+
rows,
6323+
scalar,
6324+
..
6325+
} => (columns, rows, scalar),
6326+
_ => unreachable!(),
6327+
},
63256328
_ => unreachable!(),
63266329
};
63276330
let (b_c, b_r) = match *func_ctx.resolve_type(b, &module.types) {
6328-
crate::TypeInner::CooperativeMatrix { columns, rows, .. } => (columns, rows),
6331+
crate::TypeInner::Pointer { base, space: _ } => match module.types[base].inner {
6332+
crate::TypeInner::CooperativeMatrix { columns, rows, .. } => (columns, rows),
6333+
_ => unreachable!(),
6334+
},
63296335
_ => unreachable!(),
63306336
};
63316337
let wrapped = WrappedFunction::CooperativeMultiplyAdd {

naga/src/back/spv/block.rs

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3717,7 +3717,13 @@ impl BlockContext<'_> {
37173717
self.cached[stride],
37183718
));
37193719
} else {
3720-
let result_type_id = self.get_expression_type_id(&self.fun_info[target].ty);
3720+
let result_type_id =
3721+
match *self.fun_info[target].ty.inner_with(&self.ir_module.types) {
3722+
crate::TypeInner::Pointer { base, space: _ } => {
3723+
self.get_handle_type_id(base)
3724+
}
3725+
_ => unreachable!(),
3726+
};
37213727
let id = self.gen_id();
37223728
block.body.push(Instruction::coop_load(
37233729
result_type_id,

naga/src/back/spv/writer.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -971,14 +971,13 @@ impl Writer {
971971
}
972972
}
973973

974-
// Handle globals are pre-emitted and should be loaded automatically.
975-
//
976-
// Any that are binding arrays we skip as we cannot load the array, we must load the result after indexing.
977974
match ir_module.types[var.ty].inner {
975+
// Any that are binding arrays we skip as we cannot load the array, we must load the result after indexing.
978976
crate::TypeInner::BindingArray { .. } => {
979977
gv.access_id = gv.var_id;
980978
}
981979
_ => {
980+
// Handle globals are pre-emitted and should be loaded automatically.
982981
if var.space == crate::AddressSpace::Handle {
983982
let var_type_id = self.get_handle_type_id(var.ty);
984983
let id = self.id_gen.next();
@@ -1064,6 +1063,7 @@ impl Writer {
10641063
}
10651064
}),
10661065
);
1066+
10671067
context
10681068
.function
10691069
.variables

naga/src/back/wgsl/writer.rs

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -993,13 +993,13 @@ impl<W: Write> Writer<W> {
993993
} => {
994994
let op_str = if store { "Store" } else { "Load" };
995995
let suffix = if row_major { "T" } else { "" };
996-
write!(self.out, "coop{op_str}{suffix}(")?;
997-
self.write_expr(module, target, func_ctx)?;
996+
write!(self.out, "{level}coop{op_str}{suffix}(")?;
997+
self.write_expr_with_indirection(module, target, func_ctx, Indirection::Reference)?;
998998
write!(self.out, ", ")?;
999999
self.write_expr(module, pointer, func_ctx)?;
10001000
write!(self.out, ", ")?;
10011001
self.write_expr(module, stride, func_ctx)?;
1002-
write!(self.out, ")")?
1002+
writeln!(self.out, ");")?
10031003
}
10041004
}
10051005

@@ -1714,11 +1714,11 @@ impl<W: Write> Writer<W> {
17141714
| Expression::WorkGroupUniformLoadResult { .. } => {}
17151715
Expression::CooperativeMultiplyAdd { a, b, c } => {
17161716
write!(self.out, "coopMultiplyAdd(")?;
1717-
self.write_expr(module, a, func_ctx)?;
1717+
self.write_expr_with_indirection(module, a, func_ctx, Indirection::Reference)?;
17181718
write!(self.out, ", ")?;
1719-
self.write_expr(module, b, func_ctx)?;
1719+
self.write_expr_with_indirection(module, b, func_ctx, Indirection::Reference)?;
17201720
write!(self.out, ", ")?;
1721-
self.write_expr(module, c, func_ctx)?;
1721+
self.write_expr_with_indirection(module, c, func_ctx, Indirection::Reference)?;
17221722
write!(self.out, ")")?;
17231723
}
17241724
}

naga/src/front/wgsl/error.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -412,6 +412,7 @@ pub(crate) enum Error<'a> {
412412
TypeTooLarge {
413413
span: Span,
414414
},
415+
InvalidCooperativeMatrix,
415416
UnderspecifiedCooperativeMatrix,
416417
UnsupportedCooperativeScalar(Span),
417418
}
@@ -1388,6 +1389,11 @@ impl<'a> Error<'a> {
13881389
crate::valid::MAX_TYPE_SIZE
13891390
)],
13901391
},
1392+
Error::InvalidCooperativeMatrix => ParseError {
1393+
message: "given type is not a cooperative matrix".into(),
1394+
labels: vec![],
1395+
notes: vec![format!("must be coop_mat")],
1396+
},
13911397
Error::UnderspecifiedCooperativeMatrix => ParseError {
13921398
message: "cooperative matrix constructor is underspecified".into(),
13931399
labels: vec![],

naga/src/front/wgsl/lower/mod.rs

Lines changed: 31 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -846,6 +846,15 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> {
846846
fn ensure_type_exists(&mut self, inner: ir::TypeInner) -> Handle<ir::Type> {
847847
self.as_global().ensure_type_exists(None, inner)
848848
}
849+
850+
fn _get_runtime_expression(&self, expr: Handle<ir::Expression>) -> &ir::Expression {
851+
match self.expr_type {
852+
ExpressionContextType::Runtime(ref ctx) => &ctx.function.expressions[expr],
853+
ExpressionContextType::Constant(_) | ExpressionContextType::Override => {
854+
unreachable!()
855+
}
856+
}
857+
}
849858
}
850859

851860
struct ArgumentContext<'ctx, 'source> {
@@ -955,6 +964,13 @@ impl<T> Typed<T> {
955964
Self::Plain(expr) => Typed::Plain(f(expr)?),
956965
})
957966
}
967+
968+
fn ref_or<E>(self, error: E) -> core::result::Result<T, E> {
969+
match self {
970+
Self::Reference(v) => Ok(v),
971+
Self::Plain(_) => Err(error),
972+
}
973+
}
958974
}
959975

960976
/// A single vector component or swizzle.
@@ -1677,12 +1693,8 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
16771693
.as_expression(block, &mut emitter)
16781694
.interrupt_emitter(ir::Expression::LocalVariable(var), Span::UNDEFINED)?;
16791695
block.extend(emitter.finish(&ctx.function.expressions));
1680-
let typed = if ctx.module.types[ty].inner.is_handle() {
1681-
Typed::Plain(handle)
1682-
} else {
1683-
Typed::Reference(handle)
1684-
};
1685-
ctx.local_table.insert(v.handle, Declared::Runtime(typed));
1696+
ctx.local_table
1697+
.insert(v.handle, Declared::Runtime(Typed::Reference(handle)));
16861698

16871699
match initializer {
16881700
Some(initializer) => ir::Statement::Store {
@@ -1977,12 +1989,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
19771989
let value_span = ctx.ast_expressions.get_span(value);
19781990
let target = self
19791991
.expression_for_reference(value, &mut ctx.as_expression(block, &mut emitter))?;
1980-
let target_handle = match target {
1981-
Typed::Reference(handle) => handle,
1982-
Typed::Plain(_) => {
1983-
return Err(Box::new(Error::BadIncrDecrReferenceType(value_span)))
1984-
}
1985-
};
1992+
let target_handle = target.ref_or(Error::BadIncrDecrReferenceType(value_span))?;
19861993

19871994
let mut ectx = ctx.as_expression(block, &mut emitter);
19881995
let scalar = match *resolve_inner!(ectx, target_handle) {
@@ -2139,10 +2146,8 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
21392146
LoweredGlobalDecl::Var(handle) => {
21402147
let expr = ir::Expression::GlobalVariable(handle);
21412148
let v = &ctx.module.global_variables[handle];
2142-
let force_value = ctx.module.types[v.ty].inner.is_handle();
21432149
match v.space {
21442150
ir::AddressSpace::Handle => Typed::Plain(expr),
2145-
_ if force_value => Typed::Plain(expr),
21462151
_ => Typed::Reference(expr),
21472152
}
21482153
}
@@ -3140,7 +3145,9 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
31403145
let row_major = function.name.ends_with("T");
31413146

31423147
let mut args = ctx.prepare_args(arguments, 2, span);
3143-
let target = self.expression(args.next()?, ctx)?;
3148+
let target = self
3149+
.expression_for_reference(args.next()?, ctx)?
3150+
.ref_or(Error::InvalidCooperativeMatrix)?;
31443151
let pointer = self.expression(args.next()?, ctx)?;
31453152
let stride = if args.total_args > 2 {
31463153
self.expression(args.next()?, ctx)?
@@ -3178,9 +3185,15 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
31783185
}
31793186
"coopMultiplyAdd" => {
31803187
let mut args = ctx.prepare_args(arguments, 3, span);
3181-
let a = self.expression(args.next()?, ctx)?;
3182-
let b = self.expression(args.next()?, ctx)?;
3183-
let c = self.expression(args.next()?, ctx)?;
3188+
let a = self
3189+
.expression_for_reference(args.next()?, ctx)?
3190+
.ref_or(Error::InvalidCooperativeMatrix)?;
3191+
let b = self
3192+
.expression_for_reference(args.next()?, ctx)?
3193+
.ref_or(Error::InvalidCooperativeMatrix)?;
3194+
let c = self
3195+
.expression_for_reference(args.next()?, ctx)?
3196+
.ref_or(Error::InvalidCooperativeMatrix)?;
31843197
args.finish()?;
31853198

31863199
ir::Expression::CooperativeMultiplyAdd { a, b, c }

naga/src/proc/type_methods.rs

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -191,17 +191,6 @@ impl crate::TypeInner {
191191
}
192192
}
193193

194-
/// Returns true if a variable of this type is a handle.
195-
pub const fn is_handle(&self) -> bool {
196-
match *self {
197-
Self::Image { .. }
198-
| Self::Sampler { .. }
199-
| Self::AccelerationStructure { .. }
200-
| Self::CooperativeMatrix { .. } => true,
201-
_ => false,
202-
}
203-
}
204-
205194
/// Attempt to calculate the size of this type. Returns `None` if the size
206195
/// exceeds the limit of [`crate::valid::MAX_TYPE_SIZE`].
207196
pub fn try_size(&self, gctx: super::GlobalCtx) -> Option<u32> {

naga/src/proc/typifier.rs

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -454,8 +454,7 @@ impl<'a> ResolveContext<'a> {
454454
}
455455
crate::Expression::GlobalVariable(h) => {
456456
let var = &self.global_vars[h];
457-
let ty = &types[var.ty].inner;
458-
if var.space == crate::AddressSpace::Handle || ty.is_handle() {
457+
if var.space == crate::AddressSpace::Handle {
459458
TypeResolution::Handle(var.ty)
460459
} else {
461460
TypeResolution::Value(Ti::Pointer {
@@ -466,15 +465,10 @@ impl<'a> ResolveContext<'a> {
466465
}
467466
crate::Expression::LocalVariable(h) => {
468467
let var = &self.local_vars[h];
469-
let ty = &types[var.ty].inner;
470-
if ty.is_handle() {
471-
TypeResolution::Handle(var.ty)
472-
} else {
473-
TypeResolution::Value(Ti::Pointer {
474-
base: var.ty,
475-
space: crate::AddressSpace::Function,
476-
})
477-
}
468+
TypeResolution::Value(Ti::Pointer {
469+
base: var.ty,
470+
space: crate::AddressSpace::Function,
471+
})
478472
}
479473
crate::Expression::Load { pointer } => match *past(pointer)?.inner_with(types) {
480474
Ti::Pointer { base, space: _ } => {
@@ -807,7 +801,15 @@ impl<'a> ResolveContext<'a> {
807801
scalar: crate::Scalar::U32,
808802
size: crate::VectorSize::Quad,
809803
}),
810-
crate::Expression::CooperativeMultiplyAdd { a: _, b: _, c } => past(c)?.clone(),
804+
crate::Expression::CooperativeMultiplyAdd { a: _, b: _, c } => {
805+
match *past(c)?.inner_with(types) {
806+
Ti::Pointer { base, space: _ } => TypeResolution::Handle(base),
807+
ref other => {
808+
log::error!("Pointer type {other:?}");
809+
return Err(ResolveError::InvalidPointer(c));
810+
}
811+
}
812+
}
811813
})
812814
}
813815
}

0 commit comments

Comments
 (0)