Describe the bug
While testing Qwen-Image-Edit-2509, an 'Expected all tensors to be on the same device' error was raised inside compute_text_seq_len_from_mask, at the following line:
    per_sample_len = torch.where(has_active, active_positions.max(dim=1).values + 1, torch.as_tensor(text_seq_len))
Related to: #12702
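A possible cause, offered as a guess rather than a confirmed diagnosis: torch.as_tensor(text_seq_len) is called without a device argument, so the fallback tensor is created on the default device (normally CPU), while has_active and active_positions live on the NPU. The sketch below shows one way to keep all three operands on the same device. The helper name per_sample_text_len and the toy mask are hypothetical and are not part of diffusers; only standard PyTorch calls are used, and the example runs on CPU.

```python
import torch


def per_sample_text_len(mask: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper (not the diffusers implementation).

    Returns, per sample, the index of the last active token plus one,
    falling back to the full sequence length when a row has no active token.
    Every tensor is created on mask.device to avoid mixed-device operands.
    """
    mask = mask.to(torch.bool)
    text_seq_len = mask.shape[1]

    position_ids = torch.arange(text_seq_len, device=mask.device, dtype=torch.long)
    # Inactive positions are replaced by 0 so they never win the max below.
    active_positions = torch.where(mask, position_ids, position_ids.new_zeros(()))
    has_active = mask.any(dim=1)

    # Assumption: building the fallback on mask.device is what avoids the error.
    fallback = torch.as_tensor(text_seq_len, device=mask.device, dtype=torch.long)
    return torch.where(has_active, active_positions.max(dim=1).values + 1, fallback)


# Toy mask: the first row has two active tokens, the second row has none.
mask = torch.tensor([[1, 1, 0, 0], [0, 0, 0, 0]])
lengths = per_sample_text_len(mask)
print(lengths.tolist(), lengths.device)
```

On an NPU or GPU run, the mask would already be on the accelerator, so the fallback would follow it there without any explicit device name in the helper.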
Reproduction
import torch
import torch_npu
import torch.distributed as dist
import os, time
from PIL import Image
from diffusers import QwenImageEditPlusPipeline, ContextParallelConfig
from diffusers.utils import load_image

# Initialize Env
rank = int(os.getenv("RANK", 0))
world_size = int(os.getenv("WORLD_SIZE", 1))
if world_size > 1 and not dist.is_initialized():
    dist.init_process_group(backend="hccl")
    rank = dist.get_rank()
    device = torch.device("npu", rank % torch.npu.device_count())
    torch.npu.set_device(device)
else:
    device = 'npu'

image1 = load_image("https://github.com/vipshop/cache-dit/raw/main/examples/data/edit2509_1.jpg")
image2 = load_image("https://github.com/vipshop/cache-dit/raw/main/examples/data/edit2509_2.jpg")
prompt = "The magician bear is on the left, the alchemist bear is on the right, facing each other in the central park square"

pipe = QwenImageEditPlusPipeline.from_pretrained(
    "/PATH/TO/Qwen-Image-Edit-2509",
    torch_dtype=torch.bfloat16
).to(device)
pipe.transformer.set_attention_backend("_native_npu")
pipe.set_progress_bar_config(disable=rank != 0)
pipe.enable_model_cpu_offload(device=device)

if world_size > 1:
    pipe.transformer.enable_parallelism(
        config=ContextParallelConfig(ulysses_degree=world_size)
    )

with torch.inference_mode():
    torch.npu.synchronize()
    start_time = time.time()
    output = pipe(
        image=[image1, image2],
        prompt=prompt,
        generator=torch.Generator(device="cpu").manual_seed(0),
        true_cfg_scale=4.0,
        negative_prompt=" ",
        num_inference_steps=20,
        num_images_per_prompt=1,
        height=1024,
        width=1024,
    )
    torch.npu.synchronize()
    end_time = time.time()

inference_time = end_time - start_time
if rank == 0:
    output_image = output.images[0]
    output_image.save(f"qwen-image-ulysses{world_size}-time{inference_time:.2f}s.png")
    print(f"image saved at qwen-image-ulysses{world_size}-time{inference_time:.2f}s.png")

Running Python Script:
python your_script.py

Logs
Traceback (most recent call last):
  File "/home/qwen_image_edit_test.py", line 42, in <module>
    _ = pipe(
        ^^^^^
  File "/usr/local/python3.11.13/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 120, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/diffusers/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_plus.py", line 803, in __call__
    noise_pred = self.transformer(
                 ^^^^^^^^^^^^^^^^^
  File "/usr/local/python3.11.13/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/python3.11.13/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/python3.11.13/lib/python3.11/site-packages/accelerate/hooks.py", line 175, in new_forward
    output = module._old_forward(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/diffusers/src/diffusers/models/transformers/transformer_qwenimage.py", line 923, in forward
    text_seq_len, _, encoder_hidden_states_mask = compute_text_seq_len_from_mask(
                                                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/diffusers/src/diffusers/models/transformers/transformer_qwenimage.py", line 167, in compute_text_seq_len_from_mask
    per_sample_len = torch.where(has_active, active_positions.max(dim=1).values + 1, torch.as_tensor(text_seq_len))
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: Expected all tensors to be on the same device. Expected NPU tensor, please check whether the input tensor device is correct.
[ERROR] 2026-01-22-02:30:34 (PID:633567, Device:0, RankID:-1) ERR01002 OPS invalid type

System Info
Copy-and-paste the text below in your GitHub issue and FILL OUT the two last points.
- 🤗 Diffusers version: 0.37.0.dev0
- Platform: Linux-5.10.0-216.0.0.115.oe2203sp4.aarch64-aarch64-with-glibc2.35
- Running on Google Colab?: No
- Python version: 3.11.13
- PyTorch version (GPU?): 2.8.0+cpu (False)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Huggingface_hub version: 0.36.0
- Transformers version: 4.57.6
- Accelerate version: 1.11.0
- PEFT version: not installed
- Bitsandbytes version: not installed
- Safetensors version: 0.6.2
- xFormers version: not installed
- Accelerator: NA
- Using GPU in script?:
- Using distributed or parallel set-up in script?:
Who can help?
No response