Ideogram4: Ideogram4MRoPE breaks under torch.autocast: all image positions collapse, producing flat single-color images #13920

@dxqb

Describe the bug

Ideogram4MRoPE.forward computes the rotary frequencies with a matmul:

freqs = inv_freq @ pos.unsqueeze(2)

torch.matmul is on autocast's lower-precision op list, so when the transformer's forward runs inside
torch.autocast("cuda", torch.bfloat16) — common in training code, and users also frequently wrap pipeline calls in
autocast — this matmul executes in bfloat16 even though the module explicitly casts both operands to float32.

Ideogram4's image-token positions are IMAGE_POSITION_OFFSET + (t, h, w), i.e. 65536…65536+grid. bfloat16 has 8 bits of
significand precision (7 stored plus the implicit leading bit), so the spacing between representable values at 65536
is 512. With round-to-nearest, every image position from 65536 to 65536+256 rounds to exactly 65536, which covers any
grid of up to 257 positions per axis. All image tokens then receive identical rotary embeddings, the model loses all
spatial information, and sampling degenerates to a near-uniform velocity field: the decoded image is a single flat
color (with faint patch-sized texture). Text positions (0…2047) are also quantized (step 8 between 1024 and 2047), but
the image-position collapse is total.
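
The rounding arithmetic can be checked without a GPU or torch. The sketch below emulates a float32-to-bfloat16 cast in pure Python, assuming round-to-nearest with ties to even; `to_bf16` is a hypothetical helper written for this illustration, not part of diffusers or torch.

```python
import struct


def to_bf16(x: float) -> float:
    """Round x to the nearest bfloat16 value (ties to even), returned as a float.

    Hypothetical helper for illustration; ignores NaN/inf handling.
    """
    bits = struct.unpack(">I", struct.pack(">f", x))[0]
    # bfloat16 keeps the top 16 bits of a float32. Add half of the dropped
    # range, plus the lowest kept bit so that exact ties round to even.
    bits += 0x7FFF + ((bits >> 16) & 1)
    bits &= 0xFFFF0000
    return struct.unpack(">f", struct.pack(">I", bits))[0]


IMAGE_POSITION_OFFSET = 65536  # value quoted in this issue

# Every grid offset from 0 to 256 lands on the same bfloat16 value.
image_positions = {to_bf16(float(IMAGE_POSITION_OFFSET + k)) for k in range(257)}
print(image_positions)  # {65536.0}

# The next representable value is a full 512 away.
print(to_bf16(float(IMAGE_POSITION_OFFSET + 257)))  # 66048.0

# Text positions just above 1024 are quantized in steps of 8.
print(sorted({to_bf16(float(p)) for p in range(1024, 1041)}))  # [1024.0, 1032.0, 1040.0]
```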

This was painful to find because it is invisible at the weight/IO level: text encoding, packing, scheduler, and the
checkpoint all check out, and the same code produces correct images as soon as the autocast context is removed.
Tracing per-module outputs (autocast vs. no autocast) shows rotary_emb as the first divergent module, with
max |Δcos| ≈ 2.0 (i.e. fully flipped rotations).
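
For anyone chasing a similar problem, the per-module comparison can be sketched roughly as follows. This is not the exact script used for this issue: `record_outputs` and `first_divergent_module` are hypothetical helpers, and the sketch assumes the model can be called twice with the same inputs and produces deterministic outputs.

```python
from typing import Optional

import torch
from torch import nn


def record_outputs(model: nn.Module, *args, **kwargs) -> dict[str, torch.Tensor]:
    """Run the model once and capture one output tensor per named submodule."""
    captured: dict[str, torch.Tensor] = {}
    handles = []

    def make_hook(name: str):
        def hook(_module, _inputs, output):
            # Modules such as a RoPE layer return (cos, sin); keep the first element.
            if isinstance(output, (tuple, list)) and output:
                output = output[0]
            if isinstance(output, torch.Tensor) and name not in captured:
                captured[name] = output.detach().float().cpu()
        return hook

    for name, module in model.named_modules():
        if name:  # skip the root module, whose name is ""
            handles.append(module.register_forward_hook(make_hook(name)))
    try:
        with torch.no_grad():
            model(*args, **kwargs)
    finally:
        for handle in handles:
            handle.remove()
    return captured


def first_divergent_module(
    model: nn.Module, *args, device_type: str = "cuda", tol: float = 1e-1, **kwargs
) -> Optional[tuple[str, float]]:
    """Return (name, max abs diff) of the first module whose output differs under autocast."""
    reference = record_outputs(model, *args, **kwargs)
    with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
        autocast = record_outputs(model, *args, **kwargs)
    # Dicts keep insertion order, i.e. the order in which module forwards finished.
    for name, ref in reference.items():
        out = autocast.get(name)
        if out is not None and out.shape == ref.shape and ref.numel() > 0:
            diff = (out - ref).abs().max().item()
            if diff > tol:
                return name, diff
    return None
```

The default `tol` is an arbitrary choice for this sketch: ordinary bfloat16 rounding makes most layers differ slightly under autocast, so the threshold has to sit above that noise for the first hit to be meaningful.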

Other diffusers RoPE implementations that matmul/outer raw position values are exposed to the same class of bug, but Ideogram4 is uniquely catastrophic because of the 65536 position offset.


Drafted by Claude

The rationale for Ideogram4 using position offsets >= 65536 is here:

  • Definition: venv/src/diffusers/src/diffusers/models/transformers/transformer_ideogram4.py:42 — IMAGE_POSITION_OFFSET = 65536, with the upstream comment "Image grid coordinates start at this offset so they never collide with text token indices."

Reproduction

Weight-free, runs in under a second:

import torch
from diffusers.models.transformers.transformer_ideogram4 import IMAGE_POSITION_OFFSET, Ideogram4MRoPE

rope = Ideogram4MRoPE(head_dim=256, base=5_000_000, mrope_section=(24, 20, 20)).to("cuda")

# three image-grid positions as built by Ideogram4Pipeline._prepare_ids: (t, h, w) + IMAGE_POSITION_OFFSET
pos = torch.tensor([[[0, 0, 0], [0, 0, 1], [0, 63, 63]]], device="cuda") + IMAGE_POSITION_OFFSET

cos_ref, sin_ref = rope(pos)
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    cos_ac, sin_ac = rope(pos)

print("fp32:     token(0,0) == token(0,1):", torch.equal(cos_ref[0, 0], cos_ref[0, 1]))
print("autocast: token(0,0) == token(0,1):", torch.equal(cos_ac[0, 0], cos_ac[0, 1]))
print("autocast: token(0,0) == token(63,63):", torch.equal(cos_ac[0, 0], cos_ac[0, 2]))
print("max |cos_autocast - cos_fp32|:", (cos_ac - cos_ref).abs().max().item())

Output:

fp32:     token(0,0) == token(0,1): False
autocast: token(0,0) == token(0,1): True
autocast: token(0,0) == token(63,63): True
max |cos_autocast - cos_fp32|: 1.933509349822998

End-to-end: running the Ideogram4Pipeline denoising loop inside torch.autocast("cuda", torch.bfloat16) produces a
flat single-color image; the identical call without autocast (same seed, same weights) renders correctly.

Expected behavior

Rotary embeddings should be computed in float32 regardless of an ambient autocast context. transformers guards all of
its rotary embedding modules for exactly this reason, e.g. Qwen3VL
(transformers/models/qwen3_vl/modeling_qwen3_vl.py):

with maybe_autocast(device_type=device_type, enabled=False):  # Force float32

(That guard is also why the Qwen3-VL text encoder used by this very pipeline is unaffected — only the diffusers-side
DiT MRoPE breaks.)

Suggested fix in Ideogram4MRoPE.forward:

def forward(self, position_ids: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    ...
    device_type = position_ids.device.type if position_ids.device.type != "mps" else "cpu"
    with torch.autocast(device_type=device_type, enabled=False):
        pos = position_ids.permute(2, 0, 1).to(dtype=torch.float32)
        inv_freq = self.inv_freq.to(dtype=torch.float32)[None, None, :, None].expand(3, batch_size, -1, 1)
        freqs = inv_freq @ pos.unsqueeze(2)
        ...
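
The effect of the guard can be seen in isolation with a toy version of the frequency matmul. This sketch runs on CPU and assumes CPU autocast downcasts `matmul` the same way CUDA autocast does; `freqs_unguarded` and `freqs_guarded` are hypothetical stand-ins for the relevant lines of `Ideogram4MRoPE.forward`, not the module's actual code.

```python
import torch


def freqs_unguarded(inv_freq: torch.Tensor, positions: torch.Tensor) -> torch.Tensor:
    # The explicit float32 casts do not help: autocast re-casts matmul inputs.
    return inv_freq.float()[:, None] @ positions.float()[None, :]


def freqs_guarded(inv_freq: torch.Tensor, positions: torch.Tensor) -> torch.Tensor:
    device_type = positions.device.type if positions.device.type != "mps" else "cpu"
    # Locally disable autocast so the matmul really runs in float32.
    with torch.autocast(device_type=device_type, enabled=False):
        return inv_freq.float()[:, None] @ positions.float()[None, :]


inv_freq = torch.tensor([1.0, 0.5])       # toy frequencies, not the model's values
positions = torch.tensor([65536, 65537])  # two adjacent image positions

with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    bad = freqs_unguarded(inv_freq, positions)
    good = freqs_guarded(inv_freq, positions)

print("unguarded dtype:", bad.dtype, "| columns identical:", torch.equal(bad[:, 0], bad[:, 1]))
print("guarded dtype:  ", good.dtype, "| columns identical:", torch.equal(good[:, 0], good[:, 1]))
```

With the guard in place, the two positions stay distinguishable even though the caller is still inside an autocast context.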

System Info

  • diffusers 0.39.0.dev0 (main @ 9a0aaba36)
  • torch 2.12.0+cu130, CUDA, NVIDIA GeForce RTX 4070 Ti SUPER
  • transformers 5.9.0
  • Python 3.12, Linux

Who can help?

Transformers/Attention @DN6 @yiyixuxu @sayakpaul

cc @bghira

Labels: bug, models
