diff --git a/src/diffusers/models/transformers/transformer_z_image.py b/src/diffusers/models/transformers/transformer_z_image.py
index 5983c34ab640..085e9000d7bd 100644
--- a/src/diffusers/models/transformers/transformer_z_image.py
+++ b/src/diffusers/models/transformers/transformer_z_image.py
@@ -789,9 +789,12 @@ def _prepare_sequence(
         freqs_cis = pad_sequence(freqs_cis, batch_first=True, padding_value=0.0)[:, : feats.shape[1]]
 
         # Attention mask
-        attn_mask = torch.zeros((bsz, max_seqlen), dtype=torch.bool, device=device)
-        for i, seq_len in enumerate(item_seqlens):
-            attn_mask[i, :seq_len] = 1
+        if all(seq == max_seqlen for seq in item_seqlens):
+            attn_mask = None
+        else:
+            attn_mask = torch.zeros((bsz, max_seqlen), dtype=torch.bool, device=device)
+            for i, seq_len in enumerate(item_seqlens):
+                attn_mask[i, :seq_len] = 1
 
         # Noise mask
         noise_mask_tensor = None
@@ -872,9 +875,12 @@ def _build_unified_sequence(
         unified_freqs = pad_sequence(unified_freqs, batch_first=True, padding_value=0.0)
 
         # Attention mask
-        attn_mask = torch.zeros((bsz, max_seqlen), dtype=torch.bool, device=device)
-        for i, seq_len in enumerate(unified_seqlens):
-            attn_mask[i, :seq_len] = 1
+        if all(seq == max_seqlen for seq in unified_seqlens):
+            attn_mask = None
+        else:
+            attn_mask = torch.zeros((bsz, max_seqlen), dtype=torch.bool, device=device)
+            for i, seq_len in enumerate(unified_seqlens):
+                attn_mask[i, :seq_len] = 1
 
         # Noise mask
         noise_mask_tensor = None
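
Below is a minimal standalone sketch of the pattern both hunks apply: build the boolean key-padding mask only when the batch actually contains padding, and return None otherwise so downstream attention (e.g. torch.nn.functional.scaled_dot_product_attention) can take its unmasked fast path. The helper name build_padding_mask and its signature are hypothetical illustrations for this note, not part of the diffusers API.

import torch

def build_padding_mask(item_seqlens: list[int], device: torch.device) -> torch.Tensor | None:
    # Assumption: sequences are padded on the right up to max_seqlen,
    # mirroring the pad_sequence(batch_first=True) usage in the diff.
    max_seqlen = max(item_seqlens)
    # Fast path mirrored from the diff: if every item already spans
    # max_seqlen, no token is padding, so no mask is needed at all.
    if all(seq == max_seqlen for seq in item_seqlens):
        return None
    attn_mask = torch.zeros((len(item_seqlens), max_seqlen), dtype=torch.bool, device=device)
    for i, seq_len in enumerate(item_seqlens):
        attn_mask[i, :seq_len] = 1  # True marks real (non-padded) positions
    return attn_mask

# Usage: uniform lengths yield None; ragged lengths yield a (bsz, max_seqlen) bool mask.
assert build_padding_mask([4, 4, 4], torch.device("cpu")) is None
mask = build_padding_mask([2, 4, 3], torch.device("cpu"))
assert mask.shape == (3, 4) and mask.dtype == torch.bool

Returning None instead of an all-True mask is the point of the change: a present-but-trivial mask still forces the masked attention code path, while None lets the kernel skip mask handling entirely.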