Skip to content

Commit 4bf58f9

Browse files
committed
coop: make stride non-optional
1 parent a02e7b5 commit 4bf58f9

17 files changed

+105
-120
lines changed

naga/src/back/dot/mod.rs

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -412,9 +412,7 @@ impl StatementGraph {
412412
} => {
413413
self.dependencies.push((id, target, "target"));
414414
self.dependencies.push((id, pointer, "pointer"));
415-
if let Some(stride) = stride {
416-
self.dependencies.push((id, stride, "stride"));
417-
}
415+
self.dependencies.push((id, stride, "stride"));
418416
if store {
419417
"Store"
420418
} else {

naga/src/back/msl/writer.rs

Lines changed: 2 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -4243,30 +4243,8 @@ impl<W: Write> Writer<W> {
42434243
self.put_expression(target, &context.expression, true)?;
42444244
write!(self.out, ", ")?;
42454245
self.put_expression(pointer, &context.expression, true)?;
4246-
if stride.is_some() || row_major {
4247-
write!(self.out, ", ")?;
4248-
match stride {
4249-
Some(expression) => {
4250-
self.put_expression(expression, &context.expression, true)?;
4251-
}
4252-
None => {
4253-
let default_stride = match *context.expression.resolve_type(target)
4254-
{
4255-
crate::TypeInner::CooperativeMatrix {
4256-
columns, rows, ..
4257-
} => {
4258-
if row_major {
4259-
columns as u32
4260-
} else {
4261-
rows as u32
4262-
}
4263-
}
4264-
_ => 0,
4265-
};
4266-
write!(self.out, "{default_stride}")?;
4267-
}
4268-
}
4269-
}
4246+
write!(self.out, ", ")?;
4247+
self.put_expression(stride, &context.expression, true)?;
42704248
if row_major {
42714249
let matrix_origin = "0";
42724250
let transpose = true;

naga/src/back/pipeline_constants.rs

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -853,9 +853,7 @@ fn adjust_stmt(new_pos: &HandleVec<Expression, Handle<Expression>>, stmt: &mut S
853853
} => {
854854
adjust(target);
855855
adjust(pointer);
856-
if let Some(ref mut stride) = *stride {
857-
adjust(stride);
858-
}
856+
adjust(stride);
859857
}
860858
Statement::Break
861859
| Statement::Continue

naga/src/back/spv/block.rs

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3709,13 +3709,12 @@ impl BlockContext<'_> {
37093709
spirv::CooperativeMatrixLayout::ColumnMajorKHR
37103710
};
37113711
let layout_id = self.get_index_constant(layout as u32);
3712-
let stride_id = stride.map(|exp| self.cached[exp]);
37133712
if store {
37143713
block.body.push(Instruction::coop_store(
37153714
self.cached[target],
37163715
pointer_id,
37173716
layout_id,
3718-
stride_id,
3717+
self.cached[stride],
37193718
));
37203719
} else {
37213720
let result_type_id = self.get_expression_type_id(&self.fun_info[target].ty);
@@ -3725,7 +3724,7 @@ impl BlockContext<'_> {
37253724
id,
37263725
pointer_id,
37273726
layout_id,
3728-
stride_id,
3727+
self.cached[stride],
37293728
));
37303729
block
37313730
.body

naga/src/back/spv/instructions.rs

Lines changed: 4 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1252,33 +1252,22 @@ impl super::Instruction {
12521252
id: Word,
12531253
pointer_id: Word,
12541254
layout_id: Word,
1255-
stride_id: Option<Word>,
1255+
stride_id: Word,
12561256
) -> Self {
12571257
let mut instruction = Self::new(Op::CooperativeMatrixLoadKHR);
12581258
instruction.set_type(result_type_id);
12591259
instruction.set_result(id);
12601260
instruction.add_operand(pointer_id);
12611261
instruction.add_operand(layout_id);
1262-
if let Some(stride_id) = stride_id {
1263-
instruction.add_operand(stride_id);
1264-
}
1265-
1262+
instruction.add_operand(stride_id);
12661263
instruction
12671264
}
1268-
pub(super) fn coop_store(
1269-
id: Word,
1270-
pointer_id: Word,
1271-
layout_id: Word,
1272-
stride_id: Option<Word>,
1273-
) -> Self {
1265+
pub(super) fn coop_store(id: Word, pointer_id: Word, layout_id: Word, stride_id: Word) -> Self {
12741266
let mut instruction = Self::new(Op::CooperativeMatrixStoreKHR);
12751267
instruction.add_operand(pointer_id);
12761268
instruction.add_operand(id);
12771269
instruction.add_operand(layout_id);
1278-
if let Some(stride_id) = stride_id {
1279-
instruction.add_operand(stride_id);
1280-
}
1281-
1270+
instruction.add_operand(stride_id);
12821271
instruction
12831272
}
12841273
pub(super) fn coop_mul_add(result_type_id: Word, id: Word, a: Word, b: Word, c: Word) -> Self {

naga/src/back/wgsl/writer.rs

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -997,10 +997,8 @@ impl<W: Write> Writer<W> {
997997
self.write_expr(module, target, func_ctx)?;
998998
write!(self.out, ", ")?;
999999
self.write_expr(module, pointer, func_ctx)?;
1000-
if let Some(stride) = stride {
1001-
write!(self.out, ", ")?;
1002-
self.write_expr(module, stride, func_ctx)?;
1003-
}
1000+
write!(self.out, ", ")?;
1001+
self.write_expr(module, stride, func_ctx)?;
10041002
write!(self.out, ")")?
10051003
}
10061004
}

naga/src/compact/statements.rs

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -161,9 +161,7 @@ impl FunctionTracer<'_> {
161161
} => {
162162
self.expressions_used.insert(target);
163163
self.expressions_used.insert(pointer);
164-
if let Some(stride) = stride {
165-
self.expressions_used.insert(stride);
166-
}
164+
self.expressions_used.insert(stride);
167165
}
168166

169167
// Trivial statements.
@@ -393,9 +391,7 @@ impl FunctionMap {
393391
} => {
394392
adjust(target);
395393
adjust(pointer);
396-
if let Some(ref mut stride) = *stride {
397-
adjust(stride);
398-
}
394+
adjust(stride);
399395
}
400396

401397
// Trivial statements.

naga/src/front/wgsl/lower/mod.rs

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3136,19 +3136,33 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
31363136
return Ok(Some(result));
31373137
}
31383138
"coopLoad" | "coopLoadT" | "coopStore" | "coopStoreT" => {
3139+
let store = function.name.contains("Store");
3140+
let row_major = function.name.ends_with("T");
3141+
31393142
let mut args = ctx.prepare_args(arguments, 2, span);
31403143
let target = self.expression(args.next()?, ctx)?;
31413144
let pointer = self.expression(args.next()?, ctx)?;
31423145
let stride = if args.total_args > 2 {
3143-
Some(self.expression(args.next()?, ctx)?)
3146+
self.expression(args.next()?, ctx)?
31443147
} else {
3145-
None
3148+
// Infer the stride from the matrix type
3149+
let stride = match *resolve_inner!(ctx, target) {
3150+
ir::TypeInner::CooperativeMatrix { columns, rows, .. } => {
3151+
if row_major {
3152+
columns as u32
3153+
} else {
3154+
rows as u32
3155+
}
3156+
}
3157+
_ => 0,
3158+
};
3159+
ctx.append_expression(
3160+
ir::Expression::Literal(ir::Literal::U32(stride)),
3161+
Span::UNDEFINED,
3162+
)?
31463163
};
31473164
args.finish()?;
31483165

3149-
let store = function.name.contains("Store");
3150-
let row_major = function.name.ends_with("T");
3151-
31523166
let rctx = ctx.runtime_expression_ctx(span)?;
31533167
rctx.block.push(
31543168
crate::Statement::CooperativeLoadStore {

naga/src/ir/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2216,7 +2216,7 @@ pub enum Statement {
22162216
store: bool,
22172217
target: Handle<Expression>,
22182218
pointer: Handle<Expression>,
2219-
stride: Option<Handle<Expression>>,
2219+
stride: Handle<Expression>,
22202220
row_major: bool,
22212221
},
22222222
}

naga/src/valid/analyzer.rs

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1162,18 +1162,16 @@ impl FunctionInfo {
11621162
pointer,
11631163
stride,
11641164
row_major: _,
1165-
} => {
1166-
if let Some(stride) = stride {
1167-
let _ = self.add_ref(stride);
1168-
}
1169-
FunctionUniformity {
1170-
result: Uniformity {
1171-
non_uniform_result: self.add_ref(target).or(self.add_ref(pointer)),
1172-
requirements: UniformityRequirements::COOP_OPS,
1173-
},
1174-
exit: ExitFlags::empty(),
1175-
}
1176-
}
1165+
} => FunctionUniformity {
1166+
result: Uniformity {
1167+
non_uniform_result: self
1168+
.add_ref(target)
1169+
.or(self.add_ref(pointer))
1170+
.or(self.add_ref(stride)),
1171+
requirements: UniformityRequirements::COOP_OPS,
1172+
},
1173+
exit: ExitFlags::empty(),
1174+
},
11771175
};
11781176

11791177
disruptor = disruptor.or(uniformity.exit_disruptor());

0 commit comments

Comments
 (0)