[Bug] QwenImageEditPlus series error Expected all tensors to be on the same device on NPU #13015

@zhangtao0408

Describe the bug

While testing Qwen-Image-Edit-2509 on NPU, an 'Expected all tensors to be on the same device' error is raised inside compute_text_seq_len_from_mask, at this line:

per_sample_len = torch.where(has_active, active_positions.max(dim=1).values + 1, torch.as_tensor(text_seq_len))

related to: #12702
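
A likely cause, sketched under assumptions: `torch.as_tensor(text_seq_len)` is called without a `device` argument, so it builds a zero-dim CPU tensor, while `has_active` and `active_positions` live on the NPU. Some backends may tolerate a zero-dim CPU tensor in `torch.where`, but the NPU backend apparently does not. The helper below is hypothetical (its name, signature, and the way `active_positions` is built are not taken from diffusers); it only illustrates creating the fallback value on the same device as the mask.

```python
import torch


def per_sample_text_len(mask: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: index of the last active token + 1, per sample.

    Rows with no active token fall back to the full text_seq_len.
    """
    _, text_seq_len = mask.shape
    mask = mask.bool()

    # Token positions, created on the same device as the mask.
    positions = torch.arange(text_seq_len, device=mask.device)
    # Keep the position where the mask is active, 0 elsewhere.
    active_positions = torch.where(mask, positions, positions.new_zeros(()))
    has_active = mask.any(dim=1)

    # The key change: pass device= so the fallback is not a CPU tensor.
    fallback = torch.as_tensor(text_seq_len, device=mask.device)
    return torch.where(has_active, active_positions.max(dim=1).values + 1, fallback)


if __name__ == "__main__":
    demo_mask = torch.tensor([[1, 1, 0, 0], [0, 0, 0, 0], [1, 0, 1, 0]])
    print(per_sample_text_len(demo_mask))
```

The sketch runs on CPU as written; on an NPU the same code would be called with a mask that already lives on the device, so every operand of `torch.where` shares that device.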

Reproduction

import torch
import torch_npu
import torch.distributed as dist

import os, time
from PIL import Image
from diffusers import QwenImageEditPlusPipeline, ContextParallelConfig
from diffusers.utils import load_image

# Initialize Env
rank = int(os.getenv("RANK", 0))
world_size = int(os.getenv("WORLD_SIZE", 1))

if world_size > 1 and not dist.is_initialized():
    dist.init_process_group(backend="hccl")
    rank = dist.get_rank()
    device = torch.device("npu", rank % torch.npu.device_count())
    torch.npu.set_device(device)
else:
    device = "npu"

image1 = load_image("https://github.com/vipshop/cache-dit/raw/main/examples/data/edit2509_1.jpg")
image2 = load_image("https://github.com/vipshop/cache-dit/raw/main/examples/data/edit2509_2.jpg")
prompt = "The magician bear is on the left, the alchemist bear is on the right, facing each other in the central park square"

pipe = QwenImageEditPlusPipeline.from_pretrained(
    "/PATH/TO/Qwen-Image-Edit-2509",
    torch_dtype=torch.bfloat16
).to(device)
pipe.transformer.set_attention_backend("_native_npu")

pipe.set_progress_bar_config(disable=rank != 0)
pipe.enable_model_cpu_offload(device=device)

if world_size > 1:
    pipe.transformer.enable_parallelism(
        config=ContextParallelConfig(ulysses_degree=world_size)
    )

with torch.inference_mode():
    torch.npu.synchronize()
    start_time = time.time()
    output = pipe(
        image=[image1, image2],
        prompt=prompt,
        generator=torch.Generator(device="cpu").manual_seed(0),
        true_cfg_scale=4.0,
        negative_prompt=" ",
        num_inference_steps=20,
        num_images_per_prompt=1,
        height=1024,
        width=1024,
    )
    torch.npu.synchronize()
    end_time = time.time()
    
    inference_time = end_time - start_time
    if rank == 0:
        output_image = output.images[0]
        output_image.save(f"qwen-image-ulysses{world_size}-time{inference_time:.2f}s.png")
        print(f"image saved at qwen-image-ulysses{world_size}-time{inference_time:.2f}s.png")

Run the script:

python your_script.py

Logs

Traceback (most recent call last):
  File "/home/qwen_image_edit_test.py", line 42, in <module>
    _ = pipe(
        ^^^^^
  File "/usr/local/python3.11.13/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 120, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/diffusers/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_plus.py", line 803, in __call__
    noise_pred = self.transformer(
                 ^^^^^^^^^^^^^^^^^
  File "/usr/local/python3.11.13/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/python3.11.13/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/python3.11.13/lib/python3.11/site-packages/accelerate/hooks.py", line 175, in new_forward
    output = module._old_forward(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/diffusers/src/diffusers/models/transformers/transformer_qwenimage.py", line 923, in forward
    text_seq_len, _, encoder_hidden_states_mask = compute_text_seq_len_from_mask(
                                                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/diffusers/src/diffusers/models/transformers/transformer_qwenimage.py", line 167, in compute_text_seq_len_from_mask
    per_sample_len = torch.where(has_active, active_positions.max(dim=1).values + 1, torch.as_tensor(text_seq_len))
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: Expected all tensors to be on the same device. Expected NPU tensor, please check whether the input tensor device is correct.
[ERROR] 2026-01-22-02:30:34 (PID:633567, Device:0, RankID:-1) ERR01002 OPS invalid type
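
To narrow down which operand of the failing `torch.where` call is on the wrong device, a small diagnostic like the one below can help. It is a minimal sketch, not part of diffusers: `describe_devices` is a hypothetical name, and the example values are made up. It demonstrates that a tensor built from a Python int with `torch.as_tensor` defaults to the CPU unless a `device` is given.

```python
import torch


def describe_devices(**tensors) -> dict:
    """Hypothetical debugging helper: map each tensor's name to its device."""
    return {
        name: str(value.device)
        for name, value in tensors.items()
        if isinstance(value, torch.Tensor)
    }


if __name__ == "__main__":
    has_active = torch.tensor([True, False])  # stand-in for the real operand
    fallback = torch.as_tensor(77)            # no device= given
    print(describe_devices(has_active=has_active, fallback=fallback))
```

Calling it on the three operands just before the failing line should show which one reports `cpu` while the others report an `npu` device.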

System Info

  • 🤗 Diffusers version: 0.37.0.dev0
  • Platform: Linux-5.10.0-216.0.0.115.oe2203sp4.aarch64-aarch64-with-glibc2.35
  • Running on Google Colab?: No
  • Python version: 3.11.13
  • PyTorch version (GPU?): 2.8.0+cpu (False)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Huggingface_hub version: 0.36.0
  • Transformers version: 4.57.6
  • Accelerate version: 1.11.0
  • PEFT version: not installed
  • Bitsandbytes version: not installed
  • Safetensors version: 0.6.2
  • xFormers version: not installed
  • Accelerator: NA
  • Using GPU in script?:
  • Using distributed or parallel set-up in script?:

Who can help?

No response

Labels: bug
