Skip to content

Commit ab2de4b

Browse files
committed
coop: rewire WGSL support using references
1 parent 4bf58f9 commit ab2de4b

File tree

5 files changed

+39
-27
lines changed

5 files changed

+39
-27
lines changed

CHANGELOG.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,7 @@ By @cwfitzgerald in [#8162](https://github.com/gfx-rs/wgpu/pull/8162).
166166

167167
- Added support for external textures based on WebGPU's [`GPUExternalTexture`](https://www.w3.org/TR/webgpu/#gpuexternaltexture). These allow shaders to transparently operate on potentially multiplanar source texture data in either RGB or YCbCr formats via WGSL's `texture_external` type. This is gated behind the `Features::EXTERNAL_TEXTURE` feature, which is currently only supported on DX12. By @jamienicol in [#4386](https://github.com/gfx-rs/wgpu/issues/4386).
168168

169-
- Added support for cooperative load/store operations in shaders. Currently only WGSL on the input and SPIR-V with METAL on the output are supported. By @kvark in [#8251](https://github.com/gfx-rs/wgpu/issues/8251).
169+
- Added support for cooperative load/store operations in shaders. Currently only WGSL on the input and SPIR-V,METAL, and WGSL on the output are supported. By @kvark in [#8251](https://github.com/gfx-rs/wgpu/issues/8251).
170170

171171
### Changes
172172

naga/src/back/wgsl/writer.rs

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -993,13 +993,13 @@ impl<W: Write> Writer<W> {
993993
} => {
994994
let op_str = if store { "Store" } else { "Load" };
995995
let suffix = if row_major { "T" } else { "" };
996-
write!(self.out, "coop{op_str}{suffix}(")?;
997-
self.write_expr(module, target, func_ctx)?;
996+
write!(self.out, "{level}coop{op_str}{suffix}(")?;
997+
self.write_expr_with_indirection(module, target, func_ctx, Indirection::Reference)?;
998998
write!(self.out, ", ")?;
999999
self.write_expr(module, pointer, func_ctx)?;
10001000
write!(self.out, ", ")?;
10011001
self.write_expr(module, stride, func_ctx)?;
1002-
write!(self.out, ")")?
1002+
writeln!(self.out, ");")?
10031003
}
10041004
}
10051005

@@ -1714,11 +1714,11 @@ impl<W: Write> Writer<W> {
17141714
| Expression::WorkGroupUniformLoadResult { .. } => {}
17151715
Expression::CooperativeMultiplyAdd { a, b, c } => {
17161716
write!(self.out, "coopMultiplyAdd(")?;
1717-
self.write_expr(module, a, func_ctx)?;
1717+
self.write_expr_with_indirection(module, a, func_ctx, Indirection::Reference)?;
17181718
write!(self.out, ", ")?;
1719-
self.write_expr(module, b, func_ctx)?;
1719+
self.write_expr_with_indirection(module, b, func_ctx, Indirection::Reference)?;
17201720
write!(self.out, ", ")?;
1721-
self.write_expr(module, c, func_ctx)?;
1721+
self.write_expr_with_indirection(module, c, func_ctx, Indirection::Reference)?;
17221722
write!(self.out, ")")?;
17231723
}
17241724
}

naga/src/front/wgsl/error.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -412,6 +412,7 @@ pub(crate) enum Error<'a> {
412412
TypeTooLarge {
413413
span: Span,
414414
},
415+
InvalidCooperativeMatrix,
415416
UnderspecifiedCooperativeMatrix,
416417
UnsupportedCooperativeScalar(Span),
417418
}
@@ -1388,6 +1389,11 @@ impl<'a> Error<'a> {
13881389
crate::valid::MAX_TYPE_SIZE
13891390
)],
13901391
},
1392+
Error::InvalidCooperativeMatrix => ParseError {
1393+
message: "given type is not a cooperative matrix".into(),
1394+
labels: vec![],
1395+
notes: vec![format!("must be coop_mat")],
1396+
},
13911397
Error::UnderspecifiedCooperativeMatrix => ParseError {
13921398
message: "cooperative matrix constructor is underspecified".into(),
13931399
labels: vec![],

naga/src/front/wgsl/lower/mod.rs

Lines changed: 22 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -955,6 +955,13 @@ impl<T> Typed<T> {
955955
Self::Plain(expr) => Typed::Plain(f(expr)?),
956956
})
957957
}
958+
959+
fn ref_or<E>(self, error: E) -> core::result::Result<T, E> {
960+
match self {
961+
Self::Reference(v) => Ok(v),
962+
Self::Plain(_) => Err(error),
963+
}
964+
}
958965
}
959966

960967
/// A single vector component or swizzle.
@@ -1677,12 +1684,8 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
16771684
.as_expression(block, &mut emitter)
16781685
.interrupt_emitter(ir::Expression::LocalVariable(var), Span::UNDEFINED)?;
16791686
block.extend(emitter.finish(&ctx.function.expressions));
1680-
let typed = if ctx.module.types[ty].inner.is_handle() {
1681-
Typed::Plain(handle)
1682-
} else {
1683-
Typed::Reference(handle)
1684-
};
1685-
ctx.local_table.insert(v.handle, Declared::Runtime(typed));
1687+
ctx.local_table
1688+
.insert(v.handle, Declared::Runtime(Typed::Reference(handle)));
16861689

16871690
match initializer {
16881691
Some(initializer) => ir::Statement::Store {
@@ -1977,12 +1980,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
19771980
let value_span = ctx.ast_expressions.get_span(value);
19781981
let target = self
19791982
.expression_for_reference(value, &mut ctx.as_expression(block, &mut emitter))?;
1980-
let target_handle = match target {
1981-
Typed::Reference(handle) => handle,
1982-
Typed::Plain(_) => {
1983-
return Err(Box::new(Error::BadIncrDecrReferenceType(value_span)))
1984-
}
1985-
};
1983+
let target_handle = target.ref_or(Error::BadIncrDecrReferenceType(value_span))?;
19861984

19871985
let mut ectx = ctx.as_expression(block, &mut emitter);
19881986
let scalar = match *resolve_inner!(ectx, target_handle) {
@@ -2139,10 +2137,8 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
21392137
LoweredGlobalDecl::Var(handle) => {
21402138
let expr = ir::Expression::GlobalVariable(handle);
21412139
let v = &ctx.module.global_variables[handle];
2142-
let force_value = ctx.module.types[v.ty].inner.is_handle();
21432140
match v.space {
21442141
ir::AddressSpace::Handle => Typed::Plain(expr),
2145-
_ if force_value => Typed::Plain(expr),
21462142
_ => Typed::Reference(expr),
21472143
}
21482144
}
@@ -3140,7 +3136,9 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
31403136
let row_major = function.name.ends_with("T");
31413137

31423138
let mut args = ctx.prepare_args(arguments, 2, span);
3143-
let target = self.expression(args.next()?, ctx)?;
3139+
let target = self
3140+
.expression_for_reference(args.next()?, ctx)?
3141+
.ref_or(Error::InvalidCooperativeMatrix)?;
31443142
let pointer = self.expression(args.next()?, ctx)?;
31453143
let stride = if args.total_args > 2 {
31463144
self.expression(args.next()?, ctx)?
@@ -3178,9 +3176,15 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
31783176
}
31793177
"coopMultiplyAdd" => {
31803178
let mut args = ctx.prepare_args(arguments, 3, span);
3181-
let a = self.expression(args.next()?, ctx)?;
3182-
let b = self.expression(args.next()?, ctx)?;
3183-
let c = self.expression(args.next()?, ctx)?;
3179+
let a = self
3180+
.expression_for_reference(args.next()?, ctx)?
3181+
.ref_or(Error::InvalidCooperativeMatrix)?;
3182+
let b = self
3183+
.expression_for_reference(args.next()?, ctx)?
3184+
.ref_or(Error::InvalidCooperativeMatrix)?;
3185+
let c = self
3186+
.expression_for_reference(args.next()?, ctx)?
3187+
.ref_or(Error::InvalidCooperativeMatrix)?;
31843188
args.finish()?;
31853189

31863190
ir::Expression::CooperativeMultiplyAdd { a, b, c }

naga/tests/out/wgsl/wgsl-cooperative-matrix.wgsl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@ fn main() {
88
var c: coop_mat8x8<f32,C> = coop_mat8x8<f32,C>();
99
var d: coop_mat8x8<f32,C>;
1010

11-
coopLoad((&c), (&ext[4]), 8u) d = coopMultiplyAdd((&a), (&b), (&c));
12-
coopStore((&c), (&ext[0]), 8u) return;
11+
coopLoad(c, (&ext[4]), 8u);
12+
d = coopMultiplyAdd(a, b, c);
13+
coopStore(c, (&ext[0]), 8u);
14+
return;
1315
}

0 commit comments

Comments
 (0)