Skip to content

Commit

Permalink
[naga msl-out hlsl-out] Improve workaround for infinite loops causing…
Browse files Browse the repository at this point in the history
… undefined behaviour

We must ensure that all loops emitted by the naga backends will
terminate, in order to avoid undefined behaviour. This was previously
implemented for the msl backend in #6545. However, the usage of
`volatile` prevents the compiler from making other important
optimizations. This patch improves the msl workaround and additionally
implements it for hlsl. The spv implementation will be left for a
follow up.

Rather than using volatile, this patch increments a counter on every
loop iteration, breaking from the loop after 2^64 iterations. This
ensures the compiler treats the loop as finite thereby avoiding
undefined behaviour, whilst at the same time allowing for other
optimizations and in reality not actually affecting execution.

Removing the old workaround (using a volatile variable) causes subtest
17 of the subgroup_operations test to fail on the Macos 14 worker on
CI. Adding the new workaround (using a 64-bit counter) has no
additional effect. The test passes locally when tested on an M2
Macbook Pro running Macos 15.
  • Loading branch information
jamienicol committed Jan 16, 2025
1 parent 779261e commit 565fb7c
Show file tree
Hide file tree
Showing 22 changed files with 215 additions and 86 deletions.
4 changes: 4 additions & 0 deletions naga/src/back/hlsl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,9 @@ pub struct Options {
pub zero_initialize_workgroup_memory: bool,
/// Should we restrict indexing of vectors, matrices and arrays?
pub restrict_indexing: bool,
/// If set, loops will have code injected into them, forcing the compiler
/// to think the number of iterations is bounded.
pub force_loop_bounding: bool,
}

impl Default for Options {
Expand All @@ -223,6 +226,7 @@ impl Default for Options {
push_constants_target: None,
zero_initialize_workgroup_memory: true,
restrict_indexing: true,
force_loop_bounding: true,
}
}
}
Expand Down
35 changes: 35 additions & 0 deletions naga/src/back/hlsl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,32 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
self.need_bake_expressions.clear();
}

/// Generates statements to be inserted immediately before and inside the
/// body of each loop, to defeat infinite loop reasoning. The 0th item
/// of the returned tuple should be inserted immediately prior to the loop
/// and the 1st item should be inserted inside the loop body.
///
/// See [`back::msl::Writer::gen_force_bounded_loop_statements`] for details.
fn gen_force_bounded_loop_statements(
&mut self,
level: back::Level,
) -> Option<(String, String)> {
if !self.options.force_loop_bounding {
return None;
}

let loop_bound_name = self.namer.call("loop_bound");
let decl = format!("{level}uint2 {loop_bound_name} = uint2(0u, 0u);");
let level = level.next();
let max = u32::MAX;
let break_and_inc = format!(
"{level}if (all({loop_bound_name} == uint2({max}u, {max}u))) {{ break; }}
{level}{loop_bound_name} += uint2({loop_bound_name}.y == {max}u, 1u);"
);

Some((decl, break_and_inc))
}

/// Helper method used to find which expressions of a given function require baking
///
/// # Notes
Expand Down Expand Up @@ -2048,6 +2074,10 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
ref continuing,
break_if,
} => {
let force_loop_bound_statements = self.gen_force_bounded_loop_statements(level);
if let Some((ref decl, _)) = force_loop_bound_statements {
writeln!(self.out, "{decl}")?;
}
self.continue_ctx.enter_loop();
let l2 = level.next();
if !continuing.is_empty() || break_if.is_some() {
Expand Down Expand Up @@ -2075,6 +2105,11 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
for sta in body.iter() {
self.write_stmt(module, sta, func_ctx, l2)?;
}

if let Some((_, ref break_and_inc)) = force_loop_bound_statements {
writeln!(self.out, "{break_and_inc}")?;
}

writeln!(self.out, "{level}}}")?;
self.continue_ctx.exit_loop();
}
Expand Down
102 changes: 50 additions & 52 deletions naga/src/back/msl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -383,11 +383,6 @@ pub struct Writer<W> {
/// Set of (struct type, struct field index) denoting which fields require
/// padding inserted **before** them (i.e. between fields at index - 1 and index)
struct_member_pads: FastHashSet<(Handle<crate::Type>, u32)>,

/// Name of the force-bounded-loop macro.
///
/// See `emit_force_bounded_loop_macro` for details.
force_bounded_loop_macro_name: String,
}

impl crate::Scalar {
Expand Down Expand Up @@ -601,7 +596,7 @@ struct ExpressionContext<'a> {
/// accesses. These may need to be cached in temporary variables. See
/// `index::find_checked_indexes` for details.
guarded_indices: HandleSet<crate::Expression>,
/// See [`Writer::emit_force_bounded_loop_macro`] for details.
/// See [`Writer::gen_force_bounded_loop_statements`] for details.
force_loop_bounding: bool,
}

Expand Down Expand Up @@ -685,7 +680,6 @@ impl<W: Write> Writer<W> {
#[cfg(test)]
put_block_stack_pointers: Default::default(),
struct_member_pads: FastHashSet::default(),
force_bounded_loop_macro_name: String::default(),
}
}

Expand All @@ -696,17 +690,10 @@ impl<W: Write> Writer<W> {
self.out
}

/// Define a macro to invoke at the bottom of each loop body, to
/// defeat MSL infinite loop reasoning.
///
/// If we haven't done so already, emit the definition of a preprocessor
/// macro to be invoked at the end of each loop body in the generated MSL,
/// to ensure that the MSL compiler's optimizations do not remove bounds
/// checks.
///
/// Only the first call to this function for a given module actually causes
/// the macro definition to be written. Subsequent loops can simply use the
/// prior macro definition, since macros aren't block-scoped.
/// Generates statements to be inserted immediately before and inside the
/// body of each loop, to defeat MSL infinite loop reasoning. The 0th item
/// of the returned tuple should be inserted immediately prior to the loop
/// and the 1st item should be inserted inside the loop body.
///
/// # What is this trying to solve?
///
Expand Down Expand Up @@ -774,7 +761,8 @@ impl<W: Write> Writer<W> {
/// but which in fact generates no instructions. Unfortunately, inline
/// assembly is not handled correctly by some Metal device drivers.
///
/// Instead, we add the following code to the bottom of every loop:
/// A previously used approach was to add the following code to the bottom
/// of every loop:
///
/// ```ignore
/// if (volatile bool unpredictable = false; unpredictable)
Expand All @@ -785,37 +773,47 @@ impl<W: Write> Writer<W> {
/// the `volatile` qualifier prevents the compiler from assuming this. Thus,
/// it must assume that the `break` might be reached, and hence that the
/// loop is not unbounded. This prevents the range analysis impact described
/// above.
/// above. Unfortunately this prevented the compiler from making important,
/// and safe, optimizations such as loop unrolling and was observed to
/// significantly hurt performance.
///
/// Our current approach declares a counter before every loop and
/// increments it every iteration, breaking after 2^64 iterations:
///
/// Unfortunately, what makes this a kludge, not a hack, is that this
/// solution leaves the GPU executing a pointless conditional branch, at
/// runtime, in every iteration of the loop. There's no part of the system
/// that has a global enough view to be sure that `unpredictable` is true,
/// and remove it from the code. Adding the branch also affects
/// optimization: for example, it's impossible to unroll this loop. This
/// transformation has been observed to significantly hurt performance.
/// ```ignore
/// uint2 loop_bound = uint2(0);
/// while (true) {
/// if (metal::all(loop_bound == uint2(4294967295))) { break; }
/// loop_bound += uint2(loop_bound.y == 4294967295, 1);
/// }
/// ```
///
/// To make our output a bit more legible, we pull the condition out into a
/// preprocessor macro defined at the top of the module.
/// This convinces the compiler that the loop is finite and therefore may
/// execute, whilst at the same time allowing optimizations such as loop
/// unrolling. Furthermore the 64-bit counter is large enough it seems
/// implausible that it would affect the execution of any shader.
///
/// This approach is also used by Chromium WebGPU's Dawn shader compiler:
/// <https://dawn.googlesource.com/dawn/+/a37557db581c2b60fb1cd2c01abdb232927dd961/src/tint/lang/msl/writer/printer/printer.cc#222>
fn emit_force_bounded_loop_macro(&mut self) -> BackendResult {
if !self.force_bounded_loop_macro_name.is_empty() {
return Ok(());
/// <https://dawn.googlesource.com/dawn/+/d9e2d1f718678ebee0728b999830576c410cce0a/src/tint/lang/core/ir/transform/prevent_infinite_loops.cc>
fn gen_force_bounded_loop_statements(
&mut self,
level: back::Level,
context: &StatementContext,
) -> Option<(String, String)> {
if !context.expression.force_loop_bounding {
return None;
}

self.force_bounded_loop_macro_name = self.namer.call("LOOP_IS_BOUNDED");
let loop_bounded_volatile_name = self.namer.call("unpredictable_break_from_loop");
writeln!(
self.out,
"#define {} {{ volatile bool {} = false; if ({}) break; }}",
self.force_bounded_loop_macro_name,
loop_bounded_volatile_name,
loop_bounded_volatile_name,
)?;
let loop_bound_name = self.namer.call("loop_bound");
let decl = format!("{level}uint2 {loop_bound_name} = uint2(0u);");
let level = level.next();
let max = u32::MAX;
let break_and_inc = format!(
"{level}if ({NAMESPACE}::all({loop_bound_name} == uint2({max}u))) {{ break; }}
{level}{loop_bound_name} += uint2({loop_bound_name}.y == {max}u, 1u);"
);

Ok(())
Some((decl, break_and_inc))
}

fn put_call_parameters(
Expand Down Expand Up @@ -3083,6 +3081,11 @@ impl<W: Write> Writer<W> {
ref continuing,
break_if,
} => {
let force_loop_bound_statements =
self.gen_force_bounded_loop_statements(level, context);
if let Some((ref decl, _)) = force_loop_bound_statements {
writeln!(self.out, "{decl}")?;
}
if !continuing.is_empty() || break_if.is_some() {
let gate_name = self.namer.call("loop_init");
writeln!(self.out, "{level}bool {gate_name} = true;")?;
Expand All @@ -3104,15 +3107,11 @@ impl<W: Write> Writer<W> {
writeln!(self.out, "{level}while(true) {{",)?;
}
self.put_block(level.next(), body, context)?;
if context.expression.force_loop_bounding {
self.emit_force_bounded_loop_macro()?;
writeln!(
self.out,
"{}{}",
level.next(),
self.force_bounded_loop_macro_name
)?;

if let Some((_, ref break_and_inc)) = force_loop_bound_statements {
writeln!(self.out, "{break_and_inc}")?;
}

writeln!(self.out, "{level}}}")?;
}
crate::Statement::Break => {
Expand Down Expand Up @@ -3606,7 +3605,6 @@ impl<W: Write> Writer<W> {
&[CLAMPED_LOD_LOAD_PREFIX],
&mut self.names,
);
self.force_bounded_loop_macro_name.clear();
self.struct_member_pads.clear();

writeln!(
Expand Down
3 changes: 3 additions & 0 deletions naga/tests/out/hlsl/boids.hlsl
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ void main(uint3 global_invocation_id : SV_DispatchThreadID)
vPos = _e8;
float2 _e14 = asfloat(particlesSrc.Load2(8+index*16+0));
vVel = _e14;
uint2 loop_bound = uint2(0u, 0u);
bool loop_init = true;
while(true) {
if (!loop_init) {
Expand Down Expand Up @@ -91,6 +92,8 @@ void main(uint3 global_invocation_id : SV_DispatchThreadID)
int _e88 = cVelCount;
cVelCount = (_e88 + 1);
}
if (all(loop_bound == uint2(4294967295u, 4294967295u))) { break; }
loop_bound += uint2(loop_bound.y == 4294967295u, 1u);
}
int _e94 = cMassCount;
if ((_e94 > 0)) {
Expand Down
12 changes: 12 additions & 0 deletions naga/tests/out/hlsl/break-if.hlsl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
void breakIfEmpty()
{
uint2 loop_bound = uint2(0u, 0u);
bool loop_init = true;
while(true) {
if (!loop_init) {
Expand All @@ -8,6 +9,8 @@ void breakIfEmpty()
}
}
loop_init = false;
if (all(loop_bound == uint2(4294967295u, 4294967295u))) { break; }
loop_bound += uint2(loop_bound.y == 4294967295u, 1u);
}
return;
}
Expand All @@ -17,6 +20,7 @@ void breakIfEmptyBody(bool a)
bool b = (bool)0;
bool c = (bool)0;

uint2 loop_bound_1 = uint2(0u, 0u);
bool loop_init_1 = true;
while(true) {
if (!loop_init_1) {
Expand All @@ -29,6 +33,8 @@ void breakIfEmptyBody(bool a)
}
}
loop_init_1 = false;
if (all(loop_bound_1 == uint2(4294967295u, 4294967295u))) { break; }
loop_bound_1 += uint2(loop_bound_1.y == 4294967295u, 1u);
}
return;
}
Expand All @@ -38,6 +44,7 @@ void breakIf(bool a_1)
bool d = (bool)0;
bool e = (bool)0;

uint2 loop_bound_2 = uint2(0u, 0u);
bool loop_init_2 = true;
while(true) {
if (!loop_init_2) {
Expand All @@ -50,6 +57,8 @@ void breakIf(bool a_1)
d = a_1;
bool _e2 = d;
e = (a_1 != _e2);
if (all(loop_bound_2 == uint2(4294967295u, 4294967295u))) { break; }
loop_bound_2 += uint2(loop_bound_2.y == 4294967295u, 1u);
}
return;
}
Expand All @@ -58,6 +67,7 @@ void breakIfSeparateVariable()
{
uint counter = 0u;

uint2 loop_bound_3 = uint2(0u, 0u);
bool loop_init_3 = true;
while(true) {
if (!loop_init_3) {
Expand All @@ -69,6 +79,8 @@ void breakIfSeparateVariable()
loop_init_3 = false;
uint _e3 = counter;
counter = (_e3 + 1u);
if (all(loop_bound_3 == uint2(4294967295u, 4294967295u))) { break; }
loop_bound_3 += uint2(loop_bound_3.y == 4294967295u, 1u);
}
return;
}
Expand Down
3 changes: 3 additions & 0 deletions naga/tests/out/hlsl/collatz.hlsl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ uint collatz_iterations(uint n_base)
uint i = 0u;

n = n_base;
uint2 loop_bound = uint2(0u, 0u);
while(true) {
uint _e4 = n;
if ((_e4 > 1u)) {
Expand All @@ -24,6 +25,8 @@ uint collatz_iterations(uint n_base)
uint _e20 = i;
i = (_e20 + 1u);
}
if (all(loop_bound == uint2(4294967295u, 4294967295u))) { break; }
loop_bound += uint2(loop_bound.y == 4294967295u, 1u);
}
uint _e23 = i;
return _e23;
Expand Down
Loading

0 comments on commit 565fb7c

Please sign in to comment.