
[Request] Optimize HunyuanVideo Inference Speed with ParaAttention #10383

Closed
@chengzeyi


Hi guys,

First and foremost, I would like to commend you for the incredible work on the diffusers library. It has been an invaluable resource for my projects.

I am writing to suggest an enhancement to the inference speed of the HunyuanVideo model. We have found that ParaAttention can significantly speed up HunyuanVideo inference. ParaAttention provides context-parallel attention that works with torch.compile and supports both Ulysses-style and ring-style parallelism. I hope we can add a doc or an introduction on how to make the diffusers HunyuanVideo pipeline run faster with ParaAttention. Besides HunyuanVideo, ParaAttention also supports FLUX, Mochi and CogVideoX.
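To make the ring-style idea concrete, here is a minimal pure-Python sketch for a single query vector. It is an illustration only, not ParaAttention's implementation: the function names and toy data are made up for this example, and real ring attention passes key/value shards between GPUs while this loop simply visits them in order. The point it shows is that attention can be accumulated shard by shard with a running maximum and a running softmax denominator, and still match attention over the full sequence.

```python
import math


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def full_attention(q, keys, values):
    """Reference softmax attention for one query over all keys at once."""
    scores = [dot(q, k) for k in keys]
    top = max(scores)
    weights = [math.exp(s - top) for s in scores]
    total = sum(weights)
    dim = len(values[0])
    return [sum(w * v[d] for w, v in zip(weights, values)) / total for d in range(dim)]


def ring_attention(q, key_shards, value_shards):
    """Accumulate attention one key/value shard at a time.

    Each shard stands in for the block a rank would receive from its ring
    neighbour. The running max, denominator and unnormalised output are
    rescaled whenever a larger score shows up.
    """
    running_max = float("-inf")
    denom = 0.0
    acc = [0.0] * len(value_shards[0][0])
    for keys, values in zip(key_shards, value_shards):
        scores = [dot(q, k) for k in keys]
        new_max = max(running_max, max(scores))
        scale = math.exp(running_max - new_max)  # 0.0 on the first shard
        denom *= scale
        acc = [a * scale for a in acc]
        for s, v in zip(scores, values):
            w = math.exp(s - new_max)
            denom += w
            acc = [a + w * x for a, x in zip(acc, v)]
        running_max = new_max
    return [a / denom for a in acc]


# Toy data (hypothetical): 4 tokens split into 2 shards, as with 2 GPUs.
q = [0.5, -1.0]
keys = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [-1.0, 2.0]]
values = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]

reference = full_attention(q, keys, values)
sharded = ring_attention(q, [keys[:2], keys[2:]], [values[:2], values[2:]])
print(reference, sharded)
```

Ulysses-style parallelism takes a different route (it redistributes attention heads across ranks with all-to-all communication), which this sketch does not cover.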

Steps to Optimize HunyuanVideo Inference with ParaAttention:

Install ParaAttention:

pip3 install para-attn
# Or visit https://github.com/chengzeyi/ParaAttention.git to see detailed instructions
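
Before launching a multi-GPU job, it can help to confirm that the packages the example script imports are visible to the interpreter. The following is a small sketch using only the standard library; the helper name `is_installed` is made up for this example, and the module names are the ones imported in the script below.

```python
import importlib.util


def is_installed(module_name: str) -> bool:
    """Return True if a top-level module can be located without importing it."""
    return importlib.util.find_spec(module_name) is not None


for name in ("torch", "diffusers", "para_attn"):
    print(f"{name}: {'found' if is_installed(name) else 'missing'}")
```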

Example Script:

Here is an example script to run HunyuanVideo with ParaAttention:

import torch
import torch.distributed as dist
from diffusers import HunyuanVideoPipeline, HunyuanVideoTransformer3DModel
from diffusers.utils import export_to_video

dist.init_process_group()

# Disable the cuDNN SDPA backend. With it enabled, the run can fail with:
# [rank1]: RuntimeError: Expected mha_graph->execute(handle, variant_pack, workspace_ptr.get()).is_good() to be true, but got false.  (Could this error message be improved?  If so, please report an enhancement request to PyTorch.)
torch.backends.cuda.enable_cudnn_sdp(False)

model_id = "tencent/HunyuanVideo"
transformer = HunyuanVideoTransformer3DModel.from_pretrained(
    model_id,
    subfolder="transformer",
    torch_dtype=torch.bfloat16,
    revision="refs/pr/18",
)
pipe = HunyuanVideoPipeline.from_pretrained(
    model_id,
    transformer=transformer,
    torch_dtype=torch.float16,
    revision="refs/pr/18",
).to(f"cuda:{dist.get_rank()}")

pipe.vae.enable_tiling(
    # Uncomment the arguments below to make it runnable on GPUs with 48GB memory
    # tile_sample_min_height=128,
    # tile_sample_stride_height=96,
    # tile_sample_min_width=128,
    # tile_sample_stride_width=96,
    # tile_sample_min_num_frames=32,
    # tile_sample_stride_num_frames=24,
)

from para_attn.context_parallel import init_context_parallel_mesh
from para_attn.context_parallel.diffusers_adapters import parallelize_pipe
from para_attn.parallel_vae.diffusers_adapters import parallelize_vae

mesh = init_context_parallel_mesh(
    pipe.device.type,
)
parallelize_pipe(
    pipe,
    mesh=mesh,
)
parallelize_vae(pipe.vae, mesh=mesh._flatten())

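# Optional: offload idle sub-models to the CPU to reduce GPU memory use.
# pipe.enable_model_cpu_offload(gpu_id=dist.get_rank())

# Optional: compile the transformer for additional speed-up.
# torch._inductor.config.reorder_for_compute_comm_overlap = True
# pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune-no-cudagraphs")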

output = pipe(
    prompt="A cat walks on the grass, realistic",
    height=720,
    width=1280,
    num_frames=129,
    num_inference_steps=30,
    output_type="pil" if dist.get_rank() == 0 else "pt",
).frames[0]

if dist.get_rank() == 0:
    print("Saving video to hunyuan_video.mp4")
    export_to_video(output, "hunyuan_video.mp4", fps=15)

dist.destroy_process_group()

Save the above code to run_hunyuan_video.py and run it with torchrun:

torchrun --nproc_per_node=2 run_hunyuan_video.py
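
For context on what torchrun provides to each process: it starts one process per GPU and sets the `RANK`, `LOCAL_RANK` and `WORLD_SIZE` environment variables, which `dist.init_process_group()` then picks up. The sketch below reads them directly; the helper name `torchrun_context` and the fallback defaults are assumptions for this example. The script above selects the device with `dist.get_rank()`, which matches the local rank on a single node; on multiple nodes, `LOCAL_RANK` is the value that indexes the GPUs on each machine.

```python
import os


def torchrun_context(env=None):
    """Collect the per-process settings that torchrun exports.

    Falls back to single-process defaults when the variables are absent,
    for example when the script is started with plain `python`.
    """
    env = os.environ if env is None else env
    local_rank = int(env.get("LOCAL_RANK", 0))
    return {
        "rank": int(env.get("RANK", 0)),
        "local_rank": local_rank,
        "world_size": int(env.get("WORLD_SIZE", 1)),
        "device": f"cuda:{local_rank}",
    }


print(torchrun_context())
```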

The generated video on 2xH100:

hunyuan_video.mp4

By following these steps, users can leverage ParaAttention to achieve faster inference times with HunyuanVideo on multiple GPUs.
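
To check the speed-up on your own hardware, a simple wall-clock comparison between a single-GPU run and a multi-GPU run is usually enough. Below is a minimal timing sketch; `time_call` and its parameters are hypothetical names for this example, not part of diffusers or ParaAttention. CUDA work is asynchronous, so pass a synchronisation function such as `torch.cuda.synchronize` as `sync` when timing GPU code.

```python
import time


def time_call(fn, *, warmup=1, repeats=3, sync=None):
    """Return the mean wall-clock seconds per call of fn().

    Warm-up calls are executed but not included in the mean, which keeps
    one-off costs such as compilation out of the measurement.
    """
    def run_once():
        if sync is not None:
            sync()
        start = time.perf_counter()
        fn()
        if sync is not None:
            sync()
        return time.perf_counter() - start

    for _ in range(warmup):
        run_once()
    timings = [run_once() for _ in range(repeats)]
    return sum(timings) / len(timings)


# Stand-in workload; in the script above this would wrap the pipe(...) call,
# e.g. time_call(lambda: pipe(prompt=...), sync=torch.cuda.synchronize).
mean_seconds = time_call(lambda: sum(range(100_000)))
print(f"mean time per call: {mean_seconds:.6f} s")
```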

Thank you for considering this suggestion. I believe it could greatly benefit the community and enhance the performance of HunyuanVideo. Please let me know if there are any questions or further clarifications needed.

Labels: roadmap (Add to current release roadmap)
Project status: Done