Description
Summary
Using Julia's dot-broadcast syntax (.+, .*, ./) on tiles of different shapes inside a while loop triggers an InvalidTerminatorError from IRStructurizer. The error occurs because Julia's type inference sees a Union{Broadcasted{...}, FloatTile{...}} in the loop's phi nodes, and validate_terminators rejects the type mismatch between the then and else branches of the generated IfOp.
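The same inference behavior can be reproduced in plain Julia, independent of cuTile: when a loop-carried variable can hold values of two different types, inference assigns the corresponding phi node a Union. This toy example is ours (not from the reproducer), but shows the same mechanism:

```julia
# A loop-carried variable that starts as `nothing` and is reassigned inside
# the loop gets a Union type in inference, analogous to the
# Broadcasted/FloatTile Union seen in the kernel's phi nodes.
function last_index(n)
    x = nothing          # pre-loop value: Nothing
    for i in 1:n
        x = i            # in-loop value: Int
    end
    return x             # phi node carries Union{Nothing, Int}
end

# Inference reports the Union on the return value:
@show Base.return_types(last_index, (Int,))
```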
Minimal reproducer
using CUDA
import cuTile as ct
function rms_norm_kernel(X::ct.TileArray{Float32, 2}, Y::ct.TileArray{Float32, 2},
                         W::ct.TileArray{Float32, 1}, M::Int, N::Int)
bid = ct.bid(1)
upper_bound = cld(M, 16)
# Load weight outside loop and reshape for broadcasting
w = ct.load(W, (Int32(1),), (256,))
w_2d = reshape(w, (1, 256)) # shape: (1, 256)
num_tile_blocks = ct.num_blocks(1)
current_bid = bid
while current_bid <= upper_bound
x = ct.load(X, (current_bid, Int32(1)), (16, 256)) # shape: (16, 256)
# This line triggers InvalidTerminatorError:
# dot-broadcast between (16, 256) and (1, 256) creates a lazy Broadcasted type
y = x .* w_2d
ct.store(Y, (current_bid, Int32(1)), y)
current_bid += num_tile_blocks
end
return
end
# Launch (assuming appropriate CuArrays are set up)
# ct.launch(rms_norm_kernel, num_blocks, X_cu, Y_cu, W_cu, ct.Constant(M), ct.Constant(N))
Error message
InvalidTerminatorError: IfOp at %82: yield type mismatch at position 1
(then: cuTile.FloatTile{Tuple{16, 256}, Float32},
else: Union{
Base.Broadcast.Broadcasted{<:Base.Broadcast.BroadcastStyle, Nothing, typeof(*),
<:Tuple{Any, cuTile.FloatTile{Tuple{1, 256}, Float32}}},
cuTile.FloatTile{Tuple{16, 256}, Float32}
})
Full stacktrace:
Stacktrace:
[1] validate_terminators
@ IRStructurizer/src/validation.jl:118
[2] IRStructurizer.StructuredIRCode(ir::Compiler.IRCode; ...)
@ IRStructurizer/src/ir.jl:458
[3] emit_ir(cache, mi; const_argtypes)
@ cuTile src/compiler/interface.jl:342
...
Analysis
Julia's dot-broadcast works in two steps:
- Construct a lazy Broadcasted object (records the operation without executing it)
- Materialize via Base.copy(bc::Broadcasted{TileStyle}) into a concrete Tile
cuTile.jl's broadcast.jl correctly implements Base.copy — materialization works fine in straight-line code. The problem is at the Julia type inference level: when building the SSA IR for a while loop, the compiler needs to assign a unified type to the phi nodes (variables that carry values across loop iterations). When a variable is produced by a dot-broadcast expression, Julia's type inference sees the pre-materialization type (Broadcasted{...}) as a possibility alongside the post-materialization type (FloatTile{...}), resulting in a Union type. IRStructurizer.validate_terminators then rejects this Union because the IfOp's then and else branches yield different types.
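The two-step mechanism can be observed with plain Base arrays (a sketch using Base.Broadcast internals, not cuTile's TileStyle):

```julia
using Base.Broadcast: broadcasted, materialize

# Step 1: construct the lazy Broadcasted object; nothing is computed yet.
bc = broadcasted(*, [1.0 2.0], ones(3, 2))   # (1, 2) .* (3, 2)
@assert bc isa Base.Broadcast.Broadcasted

# Step 2: materialize (which calls copy) into a concrete Array.
y = materialize(bc)
@assert y == [1.0 2.0; 1.0 2.0; 1.0 2.0]
```

In straight-line code the two steps fuse and only the concrete result type escapes; in a loop, inference can observe the intermediate Broadcasted type at the phi node.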
Key observation: the issue only occurs when the broadcast involves different-shaped tiles (e.g., (16, 256) .* (1, 256)). Same-shape operations and reductions (sum, maximum) that always return concrete Tile types work fine inside loops.
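For reference, the shape combination in question follows standard Julia broadcasting rules: the singleton dimension expands to match. A small plain-Array analogue (Base arrays here, not tiles; shapes shrunk for illustration):

```julia
x = reshape(collect(Float32, 1:8), 4, 2)   # stand-in for the (16, 256) tile
w = Float32[10 100]                        # shape (1, 2), like the reshaped weight

y = x .* w                                 # singleton first dim of w expands to 4

@assert size(y) == (4, 2)
@assert y[3, 1] == x[3, 1] * w[1, 1]
@assert y[2, 2] == x[2, 2] * w[1, 2]
```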
Current workaround
Replace dot-broadcast with explicit map + ct.broadcast_to:
# Outside the loop: pre-broadcast to full tile shape
w_full = ct.broadcast_to(w_2d, (16, 256)) # (1, 256) → (16, 256), concrete Tile
while current_bid <= upper_bound
x = ct.load(X, (current_bid, Int32(1)), (16, 256))
# map requires same-shape tiles, always returns concrete Tile type
y = map(*, x, w_full)
ct.store(Y, (current_bid, Int32(1)), y)
current_bid += num_tile_blocks
end
This works but is unergonomic: it forces users to manually manage shapes and avoid Julia's idiomatic broadcast syntax inside any while loop.
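The workaround pattern can be sketched with plain Arrays, using repeat as a stand-in for ct.broadcast_to (the eager pre-expansion plus same-shape map is the point here, not the exact cuTile API):

```julia
w_2d   = Float32[10 100]             # (1, 2) weight row
w_full = repeat(w_2d, 4, 1)          # eager expansion to (4, 2), like ct.broadcast_to
x = ones(Float32, 4, 2)

# map over same-shape arrays is eager and always yields a concrete Array,
# so no lazy Broadcasted type can appear in a loop's phi nodes.
y = map(*, x, w_full)

@assert size(y) == (4, 2)
@assert y == x .* w_2d               # same result as the dot-broadcast
```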
Expected behavior
The natural dot-broadcast syntax should work inside while loops just as it does in straight-line code:
while current_bid <= upper_bound
x = ct.load(X, (current_bid, Int32(1)), (16, 256))
y = x .* w_2d # should just work
ct.store(Y, (current_bid, Int32(1)), y)
current_bid += num_tile_blocks
endEnvironment
- cuTile.jl: from cuTile.jl/ submodule (commit on local dev branch)
- Julia: 1.12.5
- IRStructurizer: version from ~/.julia/packages/IRStructurizer/fcUqd/
- GPU: NVIDIA RTX PRO 6000 Blackwell Server Edition (sm_120)
- CUDA: 13.1