Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[naga msl-out hlsl-out] Improve workaround for infinite loops causing undefined behaviour #6929

Open
wants to merge 1 commit into
base: trunk
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions naga/src/back/hlsl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,9 @@ pub struct Options {
pub zero_initialize_workgroup_memory: bool,
/// Should we restrict indexing of vectors, matrices and arrays?
pub restrict_indexing: bool,
/// If set, loops will have code injected into them, forcing the compiler
/// to think the number of iterations is bounded.
pub force_loop_bounding: bool,
}

impl Default for Options {
Expand All @@ -223,6 +226,7 @@ impl Default for Options {
push_constants_target: None,
zero_initialize_workgroup_memory: true,
restrict_indexing: true,
force_loop_bounding: true,
}
}
}
Expand Down
50 changes: 44 additions & 6 deletions naga/src/back/hlsl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,33 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
self.need_bake_expressions.clear();
}

/// Builds the two HLSL snippets used to make every emitted loop provably
/// finite, so the compiler cannot reason about it as an infinite loop.
///
/// Returns `None` when `options.force_loop_bounding` is disabled. Otherwise
/// returns `Some((declaration, guard_and_step))`:
///
/// * `declaration` (item 0) declares a fresh `uint2` counter, zero
///   initialized, and must be written immediately before the loop, at
///   `level`.
/// * `guard_and_step` (item 1) must be written as the very first statements
///   of the loop body. It is two lines, both indented at `level.next()`:
///   a `break` taken once both counter components equal `u32::MAX`, followed
///   by an increment that always bumps `.y` and bumps `.x` only when `.y`
///   is at `u32::MAX`.
///
/// See [`back::msl::Writer::gen_force_bounded_loop_statements`] for details.
fn gen_force_bounded_loop_statements(
    &mut self,
    level: back::Level,
) -> Option<(String, String)> {
    if self.options.force_loop_bounding {
        // Reserve a unique identifier for this loop's counter.
        let counter = self.namer.call("loop_bound");
        let limit = u32::MAX;
        // The guard and increment live inside the loop body, one level deeper
        // than the declaration.
        let body_level = level.next();

        let declaration = format!("{level}uint2 {counter} = uint2(0u, 0u);");
        let guard = format!(
            "{body_level}if (all({counter} == uint2({limit}u, {limit}u))) {{ break; }}"
        );
        let step = format!("{body_level}{counter} += uint2({counter}.y == {limit}u, 1u);");

        // The caller emits the second item with a single `writeln!`, so the
        // two body statements are joined by exactly one newline.
        Some((declaration, format!("{guard}\n{step}")))
    } else {
        None
    }
}

/// Helper method used to find which expressions of a given function require baking
///
/// # Notes
Expand Down Expand Up @@ -2048,12 +2075,24 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
ref continuing,
break_if,
} => {
let force_loop_bound_statements = self.gen_force_bounded_loop_statements(level);
let gate_name = (!continuing.is_empty() || break_if.is_some())
.then(|| self.namer.call("loop_init"));

if let Some((ref decl, _)) = force_loop_bound_statements {
writeln!(self.out, "{decl}")?;
}
if let Some(ref gate_name) = gate_name {
writeln!(self.out, "{level}bool {gate_name} = true;")?;
}

self.continue_ctx.enter_loop();
writeln!(self.out, "{level}while(true) {{")?;
if let Some((_, ref break_and_inc)) = force_loop_bound_statements {
writeln!(self.out, "{break_and_inc}")?;
}
let l2 = level.next();
if !continuing.is_empty() || break_if.is_some() {
let gate_name = self.namer.call("loop_init");
writeln!(self.out, "{level}bool {gate_name} = true;")?;
writeln!(self.out, "{level}while(true) {{")?;
if let Some(gate_name) = gate_name {
writeln!(self.out, "{l2}if (!{gate_name}) {{")?;
let l3 = l2.next();
for sta in continuing.iter() {
Expand All @@ -2068,13 +2107,12 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
}
writeln!(self.out, "{l2}}}")?;
writeln!(self.out, "{l2}{gate_name} = false;")?;
} else {
writeln!(self.out, "{level}while(true) {{")?;
}

for sta in body.iter() {
self.write_stmt(module, sta, func_ctx, l2)?;
}

writeln!(self.out, "{level}}}")?;
self.continue_ctx.exit_loop();
}
Expand Down
117 changes: 59 additions & 58 deletions naga/src/back/msl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -383,11 +383,6 @@ pub struct Writer<W> {
/// Set of (struct type, struct field index) denoting which fields require
/// padding inserted **before** them (i.e. between fields at index - 1 and index)
struct_member_pads: FastHashSet<(Handle<crate::Type>, u32)>,

/// Name of the force-bounded-loop macro.
///
/// See `emit_force_bounded_loop_macro` for details.
force_bounded_loop_macro_name: String,
}

impl crate::Scalar {
Expand Down Expand Up @@ -601,7 +596,7 @@ struct ExpressionContext<'a> {
/// accesses. These may need to be cached in temporary variables. See
/// `index::find_checked_indexes` for details.
guarded_indices: HandleSet<crate::Expression>,
/// See [`Writer::emit_force_bounded_loop_macro`] for details.
/// See [`Writer::gen_force_bounded_loop_statements`] for details.
force_loop_bounding: bool,
}

Expand Down Expand Up @@ -685,7 +680,6 @@ impl<W: Write> Writer<W> {
#[cfg(test)]
put_block_stack_pointers: Default::default(),
struct_member_pads: FastHashSet::default(),
force_bounded_loop_macro_name: String::default(),
}
}

Expand All @@ -696,17 +690,11 @@ impl<W: Write> Writer<W> {
self.out
}

/// Define a macro to invoke at the bottom of each loop body, to
/// defeat MSL infinite loop reasoning.
///
/// If we haven't done so already, emit the definition of a preprocessor
/// macro to be invoked at the end of each loop body in the generated MSL,
/// to ensure that the MSL compiler's optimizations do not remove bounds
/// checks.
///
/// Only the first call to this function for a given module actually causes
/// the macro definition to be written. Subsequent loops can simply use the
/// prior macro definition, since macros aren't block-scoped.
/// Generates statements to be inserted immediately before and at the very
/// start of the body of each loop, to defeat MSL infinite loop reasoning.
/// The 0th item of the returned tuple should be inserted immediately prior
/// to the loop and the 1st item should be inserted at the very start of
/// the loop body.
///
/// # What is this trying to solve?
///
Expand Down Expand Up @@ -774,7 +762,8 @@ impl<W: Write> Writer<W> {
/// but which in fact generates no instructions. Unfortunately, inline
/// assembly is not handled correctly by some Metal device drivers.
///
/// Instead, we add the following code to the bottom of every loop:
/// A previously used approach was to add the following code to the bottom
/// of every loop:
///
/// ```ignore
/// if (volatile bool unpredictable = false; unpredictable)
Expand All @@ -785,37 +774,47 @@ impl<W: Write> Writer<W> {
/// the `volatile` qualifier prevents the compiler from assuming this. Thus,
/// it must assume that the `break` might be reached, and hence that the
/// loop is not unbounded. This prevents the range analysis impact described
/// above.
/// above. Unfortunately this prevented the compiler from making important,
/// and safe, optimizations such as loop unrolling and was observed to
/// significantly hurt performance.
///
/// Unfortunately, what makes this a kludge, not a hack, is that this
/// solution leaves the GPU executing a pointless conditional branch, at
/// runtime, in every iteration of the loop. There's no part of the system
/// that has a global enough view to be sure that `unpredictable` is true,
/// and remove it from the code. Adding the branch also affects
/// optimization: for example, it's impossible to unroll this loop. This
/// transformation has been observed to significantly hurt performance.
/// Our current approach declares a counter before every loop and
/// increments it every iteration, breaking after 2^64 iterations:
///
/// ```ignore
/// uint2 loop_bound = uint2(0);
/// while (true) {
/// if (metal::all(loop_bound == uint2(4294967295))) { break; }
/// loop_bound += uint2(loop_bound.y == 4294967295, 1);
/// }
/// ```
///
/// To make our output a bit more legible, we pull the condition out into a
/// preprocessor macro defined at the top of the module.
/// This convinces the compiler that the loop is finite and therefore may
/// execute, whilst at the same time allowing optimizations such as loop
/// unrolling. Furthermore the 64-bit counter is large enough it seems
/// implausible that it would affect the execution of any shader.
///
/// This approach is also used by Chromium WebGPU's Dawn shader compiler:
/// <https://dawn.googlesource.com/dawn/+/a37557db581c2b60fb1cd2c01abdb232927dd961/src/tint/lang/msl/writer/printer/printer.cc#222>
fn emit_force_bounded_loop_macro(&mut self) -> BackendResult {
if !self.force_bounded_loop_macro_name.is_empty() {
return Ok(());
/// <https://dawn.googlesource.com/dawn/+/d9e2d1f718678ebee0728b999830576c410cce0a/src/tint/lang/core/ir/transform/prevent_infinite_loops.cc>
fn gen_force_bounded_loop_statements(
&mut self,
level: back::Level,
context: &StatementContext,
) -> Option<(String, String)> {
if !context.expression.force_loop_bounding {
return None;
}

self.force_bounded_loop_macro_name = self.namer.call("LOOP_IS_BOUNDED");
let loop_bounded_volatile_name = self.namer.call("unpredictable_break_from_loop");
writeln!(
self.out,
"#define {} {{ volatile bool {} = false; if ({}) break; }}",
self.force_bounded_loop_macro_name,
loop_bounded_volatile_name,
loop_bounded_volatile_name,
)?;
let loop_bound_name = self.namer.call("loop_bound");
let decl = format!("{level}uint2 {loop_bound_name} = uint2(0u);");
let level = level.next();
let max = u32::MAX;
let break_and_inc = format!(
"{level}if ({NAMESPACE}::all({loop_bound_name} == uint2({max}u))) {{ break; }}
{level}{loop_bound_name} += uint2({loop_bound_name}.y == {max}u, 1u);"
);

Ok(())
Some((decl, break_and_inc))
}

fn put_call_parameters(
Expand Down Expand Up @@ -3201,10 +3200,23 @@ impl<W: Write> Writer<W> {
ref continuing,
break_if,
} => {
if !continuing.is_empty() || break_if.is_some() {
let gate_name = self.namer.call("loop_init");
let force_loop_bound_statements =
self.gen_force_bounded_loop_statements(level, context);
let gate_name = (!continuing.is_empty() || break_if.is_some())
.then(|| self.namer.call("loop_init"));

if let Some((ref decl, _)) = force_loop_bound_statements {
writeln!(self.out, "{decl}")?;
}
if let Some(ref gate_name) = gate_name {
writeln!(self.out, "{level}bool {gate_name} = true;")?;
writeln!(self.out, "{level}while(true) {{",)?;
}

writeln!(self.out, "{level}while(true) {{",)?;
if let Some((_, ref break_and_inc)) = force_loop_bound_statements {
writeln!(self.out, "{break_and_inc}")?;
}
if let Some(ref gate_name) = gate_name {
let lif = level.next();
let lcontinuing = lif.next();
writeln!(self.out, "{lif}if (!{gate_name}) {{")?;
Expand All @@ -3218,19 +3230,9 @@ impl<W: Write> Writer<W> {
}
writeln!(self.out, "{lif}}}")?;
writeln!(self.out, "{lif}{gate_name} = false;")?;
} else {
writeln!(self.out, "{level}while(true) {{",)?;
}
self.put_block(level.next(), body, context)?;
if context.expression.force_loop_bounding {
self.emit_force_bounded_loop_macro()?;
writeln!(
self.out,
"{}{}",
level.next(),
self.force_bounded_loop_macro_name
)?;
}

writeln!(self.out, "{level}}}")?;
}
crate::Statement::Break => {
Expand Down Expand Up @@ -3724,7 +3726,6 @@ impl<W: Write> Writer<W> {
&[CLAMPED_LOD_LOAD_PREFIX],
&mut self.names,
);
self.force_bounded_loop_macro_name.clear();
self.struct_member_pads.clear();

writeln!(
Expand Down
3 changes: 3 additions & 0 deletions naga/tests/out/hlsl/boids.hlsl
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,11 @@ void main(uint3 global_invocation_id : SV_DispatchThreadID)
vPos = _e8;
float2 _e14 = asfloat(particlesSrc.Load2(8+index*16+0));
vVel = _e14;
uint2 loop_bound = uint2(0u, 0u);
bool loop_init = true;
while(true) {
if (all(loop_bound == uint2(4294967295u, 4294967295u))) { break; }
loop_bound += uint2(loop_bound.y == 4294967295u, 1u);
Comment on lines +47 to +48
Copy link
Member

@cwfitzgerald cwfitzgerald Jan 21, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This construct feels weird to me. @jimblandy, is there a reason this was the original suggestion? I would expect this to take the form of:

loop_bound.x += 1;
if (loop_bound.x == 4294967295u) {
    loop_bound.y += 1;
    if (loop_bound.y == 4294967295u) {
        break;
    }
}

I don't think this brings us into any uniformity issues compared to the above. Lifespan of variables should be the same too.

This brings us from 3 comparisons (4 as-written) and 2 additions in the hot path to just 1 comparison and 1 addition. While sure this isn't that big of a deal, we're going to be stacking this bad boy on every single loop. Additionally, it may help loop bound analysis eliminate this if the first condition is simpler.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I thought I'd get some numbers to help verify this. Obvious caveat this is just one testcase on a couple of devices, but better than nothing.

I modified the hello_compute example to do 10 million loop iterations like so, and timed its duration:

@compute @workgroup_size(64)
fn doubleMe(@builtin(global_invocation_id) global_id: vec3<u32>) {
    var x: u32 = input[global_id.x];
    for (var i = 1u; i <= 10000000u; i++) {
      x = x + 1u;
    }
    output[global_id.x] = x;
}
| | No loop bounding | Existing volatile workaround | Current patch workaround | Connor's suggestion |
|---|---|---|---|---|
| MSL M2 Macbook Pro | 271ms | 923ms | 351ms | 375ms |
| HLSL AMD Radeon Pro W6400 | 124ms | N/A | 284ms | 220ms |

Metal seems to have a slight preference for the way the PR is currently written; DXC prefers Connor's suggestion. Both are significantly better than the current situation, but still significantly worse than no loop bounding at all.

The current construct makes sense to me as "emulate a u64 counter with a vec2<u32>". But I think these results give me a slight preference for switching to @cwfitzgerald's suggestion. (Though I think we should be doing == 0u rather than == 4294967295u as the comparison occurs after the increment. Not that it will really matter in practice). Can anyone think of any specific shader constructs we should test this further with? Or are we happy enough to proceed based on this - it's clearly still a performance hit, but much better than the current situation.

One thing I noticed looking at Tint's code is they do some analysis of whether a loop is finite, and only emit the workaround if required. Perhaps long term we need to do something similar to really solve the performance issues.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So I put together a gpu side benchmark and unfortunately I don't think these numbers are representative 😅 The benchmark is #6987 and can be run with cargo bench "Loop Workaround". You currently need to divide the time by 100 to get the real number.

I'm currently getting 4.45us GPU time on my AMD laptop with clock speeds locked. I get the same number for both of these shaders:

@group(0) @binding(0) var<storage, read_write> data: array<u32>;

@compute @workgroup_size(64)
fn addABunch(@builtin(global_invocation_id) global_id: vec3<u32>) {
    var x: u32 = data[global_id.x];
    for (var i = 1u; i <= 10000000u; i++) {
      x = x + 1u;
    }
    data[global_id.x] = x;
}
@group(0) @binding(0) var<storage, read_write> data: array<u32>;

@compute @workgroup_size(64)
fn addABunch(@builtin(global_invocation_id) global_id: vec3<u32>) {
    var x: u32 = data[global_id.x];
    data[global_id.x] = x * 2;
}

Ooops....

Going to look into this a smidge more to make sure the numbers are really what's going on and to see if there's some math we can do to preserve the loop...

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Alright, I've updated the PR to have this shader which the compiler can't see through the body of, but should be able to easily see through the bounds check of, and now it takes a nice rock solid 74ms on my machine. I've pushed these changes to the PR I linked. If you rebase/cherry-pick this on top of your changes, you should be able to see the difference.

@group(0) @binding(0) var<storage, read_write> data: array<u32>;

@compute @workgroup_size(64)
fn addABunch(@builtin(global_invocation_id) global_id: vec3<u32>) {
    var x: u32 = data[global_id.x];
    for (var i = 1u; i <= 100000u; i++) {
      x = u32(sin(f32(x * 120u)));
    }
    data[global_id.x] = x;
}

Copy link
Member

@cwfitzgerald cwfitzgerald Jan 25, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Running it without locked clocks maxes out my GPU clocks and I get a stable 18ms runtime (locked clocks are 700 MHz, boost clocks ~2800 MHz).

Copy link
Member

@cwfitzgerald cwfitzgerald Jan 25, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ran this benchmark on my M1:

| | No Workaround | Current Workaround | This PR | My Suggestion |
|---|---|---|---|---|
| M1 Mini | 143ms | 157ms | 162ms | 183ms |
| RTX 4070 | 7.24ms | - | 9.5ms | 10.1ms |

So it seems like both this PR and my idea are significantly worse!

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have pushed to my fork to make this easier to reproduce:

if (!loop_init) {
uint _e91 = i;
i = (_e91 + 1u);
Expand Down
12 changes: 12 additions & 0 deletions naga/tests/out/hlsl/break-if.hlsl
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
void breakIfEmpty()
{
uint2 loop_bound = uint2(0u, 0u);
bool loop_init = true;
while(true) {
if (all(loop_bound == uint2(4294967295u, 4294967295u))) { break; }
loop_bound += uint2(loop_bound.y == 4294967295u, 1u);
if (!loop_init) {
if (true) {
break;
Expand All @@ -17,8 +20,11 @@ void breakIfEmptyBody(bool a)
bool b = (bool)0;
bool c = (bool)0;

uint2 loop_bound_1 = uint2(0u, 0u);
bool loop_init_1 = true;
while(true) {
if (all(loop_bound_1 == uint2(4294967295u, 4294967295u))) { break; }
loop_bound_1 += uint2(loop_bound_1.y == 4294967295u, 1u);
if (!loop_init_1) {
b = a;
bool _e2 = b;
Expand All @@ -38,8 +44,11 @@ void breakIf(bool a_1)
bool d = (bool)0;
bool e = (bool)0;

uint2 loop_bound_2 = uint2(0u, 0u);
bool loop_init_2 = true;
while(true) {
if (all(loop_bound_2 == uint2(4294967295u, 4294967295u))) { break; }
loop_bound_2 += uint2(loop_bound_2.y == 4294967295u, 1u);
if (!loop_init_2) {
bool _e5 = e;
if ((a_1 == _e5)) {
Expand All @@ -58,8 +67,11 @@ void breakIfSeparateVariable()
{
uint counter = 0u;

uint2 loop_bound_3 = uint2(0u, 0u);
bool loop_init_3 = true;
while(true) {
if (all(loop_bound_3 == uint2(4294967295u, 4294967295u))) { break; }
loop_bound_3 += uint2(loop_bound_3.y == 4294967295u, 1u);
if (!loop_init_3) {
uint _e5 = counter;
if ((_e5 == 5u)) {
Expand Down
3 changes: 3 additions & 0 deletions naga/tests/out/hlsl/collatz.hlsl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,10 @@ uint collatz_iterations(uint n_base)
uint i = 0u;

n = n_base;
uint2 loop_bound = uint2(0u, 0u);
while(true) {
if (all(loop_bound == uint2(4294967295u, 4294967295u))) { break; }
loop_bound += uint2(loop_bound.y == 4294967295u, 1u);
uint _e4 = n;
if ((_e4 > 1u)) {
} else {
Expand Down
Loading
Loading