[Feature] Support QwenImageEditPlus series attention mask for NPU #13016

@zhangtao0408

Description

Problem related

    # Construct joint attention mask once to avoid reconstructing in every block
    # This eliminates 60 GPU syncs during training while maintaining torch.compile compatibility
    block_attention_kwargs = attention_kwargs.copy() if attention_kwargs is not None else {}
    if encoder_hidden_states_mask is not None:
        # Build joint mask: [text_mask, all_ones_for_image]
        batch_size, image_seq_len = hidden_states.shape[:2]
        image_mask = torch.ones((batch_size, image_seq_len), dtype=torch.bool, device=hidden_states.device)
        joint_attention_mask = torch.cat([encoder_hidden_states_mask, image_mask], dim=1)
        block_attention_kwargs["attention_mask"] = joint_attention_mask

PR #12702 introduced the attention mask shown above to the QwenImageEditPlus series, but the current _native_npu attention backend does not support passing an attention mask, which now causes an error:

    def _native_npu_attention(
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
        dropout_p: float = 0.0,
        scale: Optional[float] = None,
        return_lse: bool = False,
        _parallel_config: Optional["ParallelConfig"] = None,
    ) -> torch.Tensor:
        if attn_mask is not None:
            raise ValueError("`attn_mask` is not supported for NPU attention")
        if return_lse:
            raise ValueError("NPU attention backend does not support setting `return_lse=True`.")

We would like to enable NPU support for QwenImageEditPlus. Based on a printed check, the mask currently contains all 1s (full attention). Is there a workaround for this limitation so that Qwen-Image-Edit-Plus can run normally on the NPU?
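
As a rough sketch of that printed check (assuming the joint mask built in the snippet above, where 1 means "attend"):

    # Rough sketch of the printed check mentioned above.
    is_full_attention = bool(torch.all(joint_attention_mask.bool()))
    print("mask is all ones (full attention):", is_full_attention)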

Solution

The npu_fusion_attention function supports several attention-mask shapes. We need to add a check that verifies whether the attention_mask is supported by the NPU backend (a minimal sketch follows the list below):

  1. If the mask consists entirely of 1s (indicating full attention), we can pass None as the attention_mask to get full attention in FA.
  2. Otherwise, we should follow the official documentation, add a validation check, and determine whether the attention mask is valid for npu_fusion_attention.
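
Below is a minimal sketch of the proposed guard inside _native_npu_attention, assuming a 0/1 or boolean mask where 1 means "attend". The _is_supported_npu_attn_mask helper is hypothetical; a real implementation would validate the mask shapes and dtypes that the npu_fusion_attention documentation lists as supported.

    import torch

    def _is_supported_npu_attn_mask(attn_mask: torch.Tensor) -> bool:
        # Hypothetical helper: only the trivial full-attention case is accepted
        # here. A complete check would follow the shapes/dtypes documented for
        # npu_fusion_attention.
        return bool(torch.all(attn_mask.bool()))

    # Inside _native_npu_attention, before dispatching to the NPU kernel:
    if attn_mask is not None:
        if _is_supported_npu_attn_mask(attn_mask):
            # An all-ones mask encodes full attention, so dropping it is
            # equivalent and avoids the unsupported code path.
            attn_mask = None
        else:
            raise ValueError(
                "`attn_mask` with masked-out positions is not supported by the "
                "NPU attention backend."
            )

This keeps the existing behaviour (a clear error) for genuinely restrictive masks while letting the all-ones masks produced by QwenImageEditPlus pass through.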
