WanImageToVideoPipeline - swap out a limited number of blocks #10999

Open
spezialspezial opened this issue Mar 7, 2025 · 6 comments

@spezialspezial
Contributor

I can fit WanImageToVideoPipeline on a 24 GB card, but it scrapes the ceiling and sits uncomfortably close to OOMing whenever some random system event claims a bit of VRAM.

kijai/ComfyUI-WanVideoWrapper has a nice option to swap a limited, user-defined number of blocks out of VRAM. Can something similar be done right now with sequential_cpu_offload? If not, I would like to request a feature along these lines to shave off 2-4 GB.

I'm open to other ideas for a small VRAM reduction.
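
For reference, a minimal sketch of the two pipeline-level offloading options diffusers already ships, which are coarser than the per-block swapping requested here (the torch_dtype choice is illustrative):

import torch
from diffusers import WanImageToVideoPipeline

pipe = WanImageToVideoPipeline.from_pretrained(
    "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers", torch_dtype=torch.bfloat16
)

# Moves whole components (transformer, text encoder, VAE, ...) on and off the
# GPU as they are needed; coarse-grained, but a one-liner.
pipe.enable_model_cpu_offload()

# Alternatively, offload at the submodule level: much lower VRAM, much slower.
# pipe.enable_sequential_cpu_offload()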

@a-r-r-o-w
Member

Check the apply_group_offloading related docs. It'll be going into the current release soon, and we can work on improving its discoverability/docs. It has the lowest VRAM requirements without much overhead on generation time if you're on a modern CUDA GPU. Currently the RAM requirements are high, but it's a WIP to improve that (happy to accept any improvement PRs 🤗)

Some numbers: #10847 (comment)

If you combine it with #10623, precompute the text embeddings, and do tiled VAE decode, you can run in roughly 7-10 GB.
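
For the "precompute text embeddings" part, a rough, untested sketch (it assumes encode_prompt returns (prompt_embeds, negative_prompt_embeds) and that the pipeline accepts them directly via prompt_embeds/negative_prompt_embeds):

import torch
from diffusers import WanImageToVideoPipeline

pipe = WanImageToVideoPipeline.from_pretrained(
    "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers", torch_dtype=torch.bfloat16
)

# Run the UMT5 text encoder once on the GPU and keep only the embeddings.
# (Assumes encode_prompt returns the (prompt_embeds, negative_prompt_embeds) pair.)
pipe.text_encoder.to("cuda")
prompt_embeds, negative_prompt_embeds = pipe.encode_prompt(
    prompt="a woman in a yellow coat is dancing in the desert",
    negative_prompt="",
    device="cuda",
)
pipe.text_encoder.to("cpu")
torch.cuda.empty_cache()

# Later, call the pipeline with the cached embeddings instead of the prompt:
# video = pipe(image=image, prompt_embeds=prompt_embeds,
#              negative_prompt_embeds=negative_prompt_embeds, ...).frames[0]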

@spezialspezial
Contributor Author

Superb, I'll check it out

@rolux

rolux commented Mar 7, 2025

@a-r-r-o-w: Thanks for the hints!

I did some tests on Google Colab, and the following code saturates the system RAM (83.5 GB) and then hangs.

import torch
from diffusers import WanImageToVideoPipeline

pipe = WanImageToVideoPipeline.from_pretrained(
    "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers"
)
pipe.transformer.enable_group_offload(
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="leaf_level",
    use_stream=True
)

Meanwhile, pipe.vae.enable_tiling() throws an error:

AttributeError: 'AutoencoderKLWan' object has no attribute 'enable_tiling'

@a-r-r-o-w
Member

a-r-r-o-w commented Mar 7, 2025

Group offloading with streams currently has a significant limitation: it pins the weight tensors in CPU memory, which makes it require a lot more RAM than the other methods. Could you try without streams, and also with offload_type="block_level"? That should lower the RAM usage, based on my findings here. We're working on optimizing it and figuring out how to improve it.

Tiling support hasn't been added to the Wan VAE yet, I believe (cc @yiyixuxu).

@rolux

rolux commented Mar 7, 2025

@a-r-r-o-w: Thanks for the info!

The following code currently throws; the traceback is below.

Code:
!pip install git+https://github.com/huggingface/diffusers.git
!pip install ftfy


import os
import torch
from diffusers import (
    WanImageToVideoPipeline,
    WanTransformer3DModel
)
from diffusers.hooks import apply_group_offloading
from diffusers.utils import export_to_video
from PIL import Image


onload_device = torch.device("cuda")
offload_device = torch.device("cpu")
model_id = "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers"
pipe = WanImageToVideoPipeline.from_pretrained(model_id)

pipe.transformer.enable_group_offload(
    onload_device=onload_device,
    offload_device=offload_device,
    offload_type="block_level",
    num_blocks_per_group=2
)
apply_group_offloading(
    pipe.text_encoder,
    onload_device=onload_device,
    offload_type="block_level",
    num_blocks_per_group=2
)
apply_group_offloading(
    pipe.vae,
    onload_device=onload_device,
    offload_type="block_level",
    num_blocks_per_group=2
)


def render(
    filename,
    image,
    prompt,
    seed=0,
    width=832,
    height=480,
    num_frames=81,
    num_inference_steps=30,
    guidance_scale=5.0,
    fps=16
):
    video = pipe(
        image=image,
        prompt=prompt,
        generator=torch.Generator(device=pipe.device).manual_seed(seed),
        width=width,
        height=height,
        num_frames=num_frames,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale
    ).frames[0]
    os.makedirs(os.path.dirname(filename), exist_ok=True)
    export_to_video(video, filename, fps=fps)


render(
    filename="/content/test.mp4",
    image=Image.open("/content/test.png"),
    prompt="a woman in a yellow coat is dancing in the desert",
    seed=42
)
RuntimeError                              Traceback (most recent call last)
<ipython-input-3-41f8c8c49ba8> in <cell line: 0>()
     72 
     73 
---> 74 render(
     75     filename="/content/test.mp4",
     76     image=Image.open("/content/test.png"),

16 frames
<ipython-input-3-41f8c8c49ba8> in render(filename, image, prompt, seed, width, height, num_frames, num_inference_steps, guidance_scale, fps)
     58     fps=16
     59 ):
---> 60     video = pipe(
     61         image=image,
     62         prompt=prompt,

/usr/local/lib/python3.11/dist-packages/torch/utils/_contextlib.py in decorate_context(*args, **kwargs)
    114     def decorate_context(*args, **kwargs):
    115         with ctx_factory():
--> 116             return func(*args, **kwargs)
    117 
    118     return decorate_context

/usr/local/lib/python3.11/dist-packages/diffusers/pipelines/wan/pipeline_wan_i2v.py in __call__(self, image, prompt, negative_prompt, height, width, num_frames, num_inference_steps, guidance_scale, num_videos_per_prompt, generator, latents, prompt_embeds, negative_prompt_embeds, output_type, return_dict, attention_kwargs, callback_on_step_end, callback_on_step_end_tensor_inputs, max_sequence_length)
    585             negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype)
    586 
--> 587         image_embeds = self.encode_image(image)
    588         image_embeds = image_embeds.repeat(batch_size, 1, 1)
    589         image_embeds = image_embeds.to(transformer_dtype)

/usr/local/lib/python3.11/dist-packages/diffusers/pipelines/wan/pipeline_wan_i2v.py in encode_image(self, image)
    222     def encode_image(self, image: PipelineImageInput):
    223         image = self.image_processor(images=image, return_tensors="pt").to(self.device)
--> 224         image_embeds = self.image_encoder(**image, output_hidden_states=True)
    225         return image_embeds.hidden_states[-2]
    226 

/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
   1734             return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735         else:
-> 1736             return self._call_impl(*args, **kwargs)
   1737 
   1738     # torchrec tests the code consistency with the following code

/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
   1745                 or _global_backward_pre_hooks or _global_backward_hooks
   1746                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747             return forward_call(*args, **kwargs)
   1748 
   1749         result = None

/usr/local/lib/python3.11/dist-packages/transformers/models/clip/modeling_clip.py in forward(self, pixel_values, output_attentions, output_hidden_states, interpolate_pos_encoding, return_dict)
   1552         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
   1553 
-> 1554         vision_outputs = self.vision_model(
   1555             pixel_values=pixel_values,
   1556             output_attentions=output_attentions,

/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
   1734             return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735         else:
-> 1736             return self._call_impl(*args, **kwargs)
   1737 
   1738     # torchrec tests the code consistency with the following code

/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
   1745                 or _global_backward_pre_hooks or _global_backward_hooks
   1746                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747             return forward_call(*args, **kwargs)
   1748 
   1749         result = None

/usr/local/lib/python3.11/dist-packages/transformers/models/clip/modeling_clip.py in forward(self, pixel_values, output_attentions, output_hidden_states, return_dict, interpolate_pos_encoding)
   1091             raise ValueError("You have to specify pixel_values")
   1092 
-> 1093         hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
   1094         hidden_states = self.pre_layrnorm(hidden_states)
   1095 

/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
   1734             return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735         else:
-> 1736             return self._call_impl(*args, **kwargs)
   1737 
   1738     # torchrec tests the code consistency with the following code

/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
   1745                 or _global_backward_pre_hooks or _global_backward_hooks
   1746                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747             return forward_call(*args, **kwargs)
   1748 
   1749         result = None

/usr/local/lib/python3.11/dist-packages/transformers/models/clip/modeling_clip.py in forward(self, pixel_values, interpolate_pos_encoding)
    246             )
    247         target_dtype = self.patch_embedding.weight.dtype
--> 248         patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype))  # shape = [*, width, grid, grid]
    249         patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
    250 

/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
   1734             return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735         else:
-> 1736             return self._call_impl(*args, **kwargs)
   1737 
   1738     # torchrec tests the code consistency with the following code

/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
   1745                 or _global_backward_pre_hooks or _global_backward_hooks
   1746                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747             return forward_call(*args, **kwargs)
   1748 
   1749         result = None

/usr/local/lib/python3.11/dist-packages/torch/nn/modules/conv.py in forward(self, input)
    552 
    553     def forward(self, input: Tensor) -> Tensor:
--> 554         return self._conv_forward(input, self.weight, self.bias)
    555 
    556 

/usr/local/lib/python3.11/dist-packages/torch/nn/modules/conv.py in _conv_forward(self, input, weight, bias)
    547                 self.groups,
    548             )
--> 549         return F.conv2d(
    550             input, weight, bias, self.stride, self.padding, self.dilation, self.groups
    551         )

RuntimeError: Input type (torch.cuda.FloatTensor) and weight type (torch.FloatTensor) should be the same
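
The mismatch above points at the CLIP image encoder: it isn't covered by any of the offloading calls, so its weights stay on the CPU while encode_image moves the pixel values to CUDA. A minimal, untested sketch of two possible workarounds (it assumes the image encoder can be handled like the other components):

# Option A: keep the CLIP image encoder resident on the GPU.
pipe.image_encoder.to("cuda")

# Option B: group-offload it like the other components.
# (Assumes group offloading supports this CLIP vision model.)
apply_group_offloading(
    pipe.image_encoder,
    onload_device=onload_device,
    offload_device=offload_device,
    offload_type="block_level",
    num_blocks_per_group=2
)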

@rolux

rolux commented Mar 7, 2025

If I comment out the two instances of apply_group_offloading (t5 + vae), it throws here:

RuntimeError                              Traceback (most recent call last)
<ipython-input-1-edbc9a66ffc4> in <cell line: 0>()
     73 
     74 
---> 75 render(
     76     filename="/content/test.mp4",
     77     image=Image.open("/content/test.png"),

14 frames
<ipython-input-1-edbc9a66ffc4> in render(filename, image, prompt, seed, width, height, num_frames, num_inference_steps, guidance_scale, fps)
     59     fps=16
     60 ):
---> 61     video = pipe(
     62         image=image,
     63         prompt=prompt,

/usr/local/lib/python3.11/dist-packages/torch/utils/_contextlib.py in decorate_context(*args, **kwargs)
    114     def decorate_context(*args, **kwargs):
    115         with ctx_factory():
--> 116             return func(*args, **kwargs)
    117 
    118     return decorate_context

/usr/local/lib/python3.11/dist-packages/diffusers/pipelines/wan/pipeline_wan_i2v.py in __call__(self, image, prompt, negative_prompt, height, width, num_frames, num_inference_steps, guidance_scale, num_videos_per_prompt, generator, latents, prompt_embeds, negative_prompt_embeds, output_type, return_dict, attention_kwargs, callback_on_step_end, callback_on_step_end_tensor_inputs, max_sequence_length)
    568 
    569         # 3. Encode input prompt
--> 570         prompt_embeds, negative_prompt_embeds = self.encode_prompt(
    571             prompt=prompt,
    572             negative_prompt=negative_prompt,

/usr/local/lib/python3.11/dist-packages/diffusers/pipelines/wan/pipeline_wan_i2v.py in encode_prompt(self, prompt, negative_prompt, do_classifier_free_guidance, num_videos_per_prompt, prompt_embeds, negative_prompt_embeds, max_sequence_length, device, dtype)
    273 
    274         if prompt_embeds is None:
--> 275             prompt_embeds = self._get_t5_prompt_embeds(
    276                 prompt=prompt,
    277                 num_videos_per_prompt=num_videos_per_prompt,

/usr/local/lib/python3.11/dist-packages/diffusers/pipelines/wan/pipeline_wan_i2v.py in _get_t5_prompt_embeds(self, prompt, num_videos_per_prompt, max_sequence_length, device, dtype)
    206         seq_lens = mask.gt(0).sum(dim=1).long()
    207 
--> 208         prompt_embeds = self.text_encoder(text_input_ids.to(device), mask.to(device)).last_hidden_state
    209         prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
    210         prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)]

/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
   1734             return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735         else:
-> 1736             return self._call_impl(*args, **kwargs)
   1737 
   1738     # torchrec tests the code consistency with the following code

/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
   1745                 or _global_backward_pre_hooks or _global_backward_hooks
   1746                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747             return forward_call(*args, **kwargs)
   1748 
   1749         result = None

/usr/local/lib/python3.11/dist-packages/transformers/models/umt5/modeling_umt5.py in forward(self, input_ids, attention_mask, head_mask, inputs_embeds, output_attentions, output_hidden_states, return_dict)
   1607         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
   1608 
-> 1609         encoder_outputs = self.encoder(
   1610             input_ids=input_ids,
   1611             attention_mask=attention_mask,

/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
   1734             return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735         else:
-> 1736             return self._call_impl(*args, **kwargs)
   1737 
   1738     # torchrec tests the code consistency with the following code

/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
   1745                 or _global_backward_pre_hooks or _global_backward_hooks
   1746                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747             return forward_call(*args, **kwargs)
   1748 
   1749         result = None

/usr/local/lib/python3.11/dist-packages/transformers/models/umt5/modeling_umt5.py in forward(self, input_ids, attention_mask, encoder_hidden_states, encoder_attention_mask, inputs_embeds, head_mask, cross_attn_head_mask, past_key_values, use_cache, output_attentions, output_hidden_states, return_dict, cache_position)
    686             if self.embed_tokens is None:
    687                 raise ValueError("You have to initialize the model with valid token embeddings")
--> 688             inputs_embeds = self.embed_tokens(input_ids)
    689 
    690         batch_size, seq_length = input_shape

/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
   1734             return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735         else:
-> 1736             return self._call_impl(*args, **kwargs)
   1737 
   1738     # torchrec tests the code consistency with the following code

/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
   1745                 or _global_backward_pre_hooks or _global_backward_hooks
   1746                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747             return forward_call(*args, **kwargs)
   1748 
   1749         result = None

/usr/local/lib/python3.11/dist-packages/torch/nn/modules/sparse.py in forward(self, input)
    188 
    189     def forward(self, input: Tensor) -> Tensor:
--> 190         return F.embedding(
    191             input,
    192             self.weight,

/usr/local/lib/python3.11/dist-packages/torch/nn/functional.py in embedding(input, weight, padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse)
   2549         # remove once script supports set_grad_enabled
   2550         _no_grad_embedding_renorm_(weight, input, max_norm, norm_type)
-> 2551     return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
   2552 
   2553 

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0! (when checking argument for argument index in method wrapper_CUDA__index_select)
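
This second failure is the mirror image of the first one: with the two apply_group_offloading calls commented out, the UMT5 text encoder and the VAE simply stay on the CPU while the pipeline moves its inputs to CUDA. A minimal, untested sketch of the straightforward workaround (note the UMT5 encoder is large, so precomputing the text embeddings as sketched earlier is cheaper on VRAM):

# Keep group offloading on the transformer only, and place the remaining
# components on the GPU explicitly.
pipe.text_encoder.to("cuda")
pipe.image_encoder.to("cuda")
pipe.vae.to("cuda")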
