InvalidTerminatorError: dot-broadcast inside while loop causes yield type mismatch #104

@0xtaruhi

Summary

Using Julia's dot-broadcast syntax (.+, .*, ./) on tiles of different shapes inside a while loop triggers an InvalidTerminatorError from IRStructurizer. The error occurs because Julia's type inference sees a Union{Broadcasted{...}, FloatTile{...}} in the loop's phi nodes, and validate_terminators rejects the type mismatch between the then and else branches of the generated IfOp.

Minimal reproducer

using CUDA
import cuTile as ct

function rms_norm_kernel(X::ct.TileArray{Float32, 2}, Y::ct.TileArray{Float32, 2},
                         W::ct.TileArray{Float32, 1}, M::Int, N::Int)
    bid = ct.bid(1)
    upper_bound = cld(M, 16)

    # Load weight outside loop and reshape for broadcasting
    w = ct.load(W, (Int32(1),), (256,))
    w_2d = reshape(w, (1, 256))  # shape: (1, 256)

    num_tile_blocks = ct.num_blocks(1)
    current_bid = bid
    while current_bid <= upper_bound
        x = ct.load(X, (current_bid, Int32(1)), (16, 256))  # shape: (16, 256)

        # This line triggers InvalidTerminatorError:
        # dot-broadcast between (16, 256) and (1, 256) creates a lazy Broadcasted type
        y = x .* w_2d

        ct.store(Y, (current_bid, Int32(1)), y)
        current_bid += num_tile_blocks
    end
    return
end

# Launch (assuming appropriate CuArrays are set up)
# ct.launch(rms_norm_kernel, num_blocks, X_cu, Y_cu, W_cu, ct.Constant(M), ct.Constant(N))

Error message

InvalidTerminatorError: IfOp at %82: yield type mismatch at position 1
  (then: cuTile.FloatTile{Tuple{16, 256}, Float32},
   else: Union{
     Base.Broadcast.Broadcasted{<:Base.Broadcast.BroadcastStyle, Nothing, typeof(*),
       <:Tuple{Any, cuTile.FloatTile{Tuple{1, 256}, Float32}}},
     cuTile.FloatTile{Tuple{16, 256}, Float32}
   })

Stacktrace (truncated):

Stacktrace:
  [1] validate_terminators
    @ IRStructurizer/src/validation.jl:118
  [2] IRStructurizer.StructuredIRCode(ir::Compiler.IRCode; ...)
    @ IRStructurizer/src/ir.jl:458
  [3] emit_ir(cache, mi; const_argtypes)
    @ cuTile src/compiler/interface.jl:342
  ...

Analysis

Julia's dot-broadcast works in two steps:

  1. Construct a lazy Broadcasted object (records the operation without executing it)
  2. Materialize via Base.copy(bc::Broadcasted{TileStyle}) into a concrete Tile
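
The two steps can be observed on ordinary CPU arrays, independent of cuTile. The sketch below uses only Base.Broadcast from Julia itself. The names x and w_2d mirror the reproducer, but plain Arrays stand in for tiles, so the broadcast style here is DefaultArrayStyle rather than cuTile's TileStyle.

```julia
using Base.Broadcast: Broadcasted, broadcasted, materialize

# Plain arrays standing in for the (16, 256) and (1, 256) tiles.
x    = ones(Float32, 16, 256)
w_2d = fill(2.0f0, 1, 256)

# Step 1: `x .* w_2d` lowers to `materialize(broadcasted(*, x, w_2d))`.
# The inner call only records the operation; nothing is computed yet.
bc = broadcasted(*, x, w_2d)
@assert bc isa Broadcasted

# Step 2: materialization calls `copy` on the instantiated Broadcasted and
# produces a concrete array with the broadcast shape.
y = materialize(bc)
@assert y isa Matrix{Float32}
@assert size(y) == (16, 256)
@assert y == x .* w_2d
```

If the lazy object from step 1 ever becomes visible to inference as a possible value of a variable, that variable's type is no longer the concrete result type from step 2.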

cuTile.jl's broadcast.jl correctly implements Base.copy, so materialization works in straight-line code. The problem arises during Julia type inference. When building the SSA IR for a while loop, the compiler must assign a single type to each phi node (a variable that carries a value across loop iterations). For a variable produced by a dot-broadcast expression, inference keeps the pre-materialization type (Broadcasted{...}) as a possibility alongside the post-materialization type (FloatTile{...}), so the phi node gets a Union type. IRStructurizer.validate_terminators then rejects this Union because the then and else branches of the IfOp yield different types.
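
To illustrate how a merge point can end up with a Union of a lazy and a concrete type, here is a hypothetical CPU analog in plain Julia. It does not reproduce the cuTile code path (the function loop_carried is invented for illustration and deliberately leaves the Broadcasted unmaterialized). It only shows the kind of inferred type that a validator requiring identical yield types would reject.

```julia
using Base.Broadcast: Broadcasted, broadcasted

# Hypothetical analog: `y` is a Matrix on loop entry but is rebound to a lazy
# Broadcasted inside the loop, so the value merged at the loop header
# (the phi node) can be either type.
function loop_carried(x::Matrix{Float32}, w::Matrix{Float32}, n::Int)
    y = x
    i = 0
    while i < n
        y = broadcasted(*, x, w)   # lazy, never materialized
        i += 1
    end
    return y
end

# Ask inference for the return type; it is a Union of both possibilities.
inferred = only(Base.return_types(loop_carried,
                                  (Matrix{Float32}, Matrix{Float32}, Int)))
@assert inferred isa Union
@assert Matrix{Float32} <: inferred
println(inferred)
```

In the reported error the Union has the same structure: one member is the concrete FloatTile and the other is a Broadcasted whose arguments include the (1, 256) tile.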

Key observation: the issue only occurs when the broadcast involves different-shaped tiles (e.g., (16, 256) .* (1, 256)). Same-shape operations and reductions (sum, maximum) that always return concrete Tile types work fine inside loops.

Current workaround

Replace dot-broadcast with explicit map + ct.broadcast_to:

# Outside the loop: pre-broadcast to full tile shape
w_full = ct.broadcast_to(w_2d, (16, 256))  # (1, 256) → (16, 256), concrete Tile

while current_bid <= upper_bound
    x = ct.load(X, (current_bid, Int32(1)), (16, 256))

    # map requires same-shape tiles, always returns concrete Tile type
    y = map(*, x, w_full)

    ct.store(Y, (current_bid, Int32(1)), y)
    current_bid += num_tile_blocks
end

This works but is unergonomic: it forces users to manage shapes manually and give up Julia's idiomatic broadcast syntax inside any while loop.
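
As a sanity check that the workaround computes the same values as the dot-broadcast form, the following sketch compares the two on plain CPU arrays. It is an analogy under stated assumptions: Base's repeat stands in for ct.broadcast_to, and Arrays stand in for tiles, so it says nothing about cuTile's compilation behavior.

```julia
# Plain arrays standing in for tiles; `repeat` plays the role that
# ct.broadcast_to plays in the workaround above.
x      = rand(Float32, 16, 256)
w_2d   = rand(Float32, 1, 256)
w_full = repeat(w_2d, 16, 1)             # (1, 256) -> (16, 256)

@assert size(w_full) == (16, 256)
@assert map(*, x, w_full) == x .* w_2d   # same values either way
```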

Expected behavior

The natural dot-broadcast syntax should work inside while loops just as it does in straight-line code:

while current_bid <= upper_bound
    x = ct.load(X, (current_bid, Int32(1)), (16, 256))
    y = x .* w_2d   # should just work
    ct.store(Y, (current_bid, Int32(1)), y)
    current_bid += num_tile_blocks
end

Environment

  • cuTile.jl: from cuTile.jl/ submodule (commit on local dev branch)
  • Julia: 1.12.5
  • IRStructurizer: version from ~/.julia/packages/IRStructurizer/fcUqd/
  • GPU: NVIDIA RTX PRO 6000 Blackwell Server Edition (sm_120)
  • CUDA: 13.1
