[Bug] CogVideoX 5B compile_pipe, backend=one_flow issue #1123

Open
kursatdinc opened this issue Oct 24, 2024 · 0 comments
Labels
Request-bug Something isn't working


Your current environment information

PyTorch version: 2.4.0+cu121
Is debug build: False
CUDA used to build PyTorch: 12.1
ROCM used to build PyTorch: N/A

OneFlow version: path: ['/home/kursat_dinc/anaconda3/envs/cogvideo_v2/lib/python3.9/site-packages/oneflow'], version: 0.9.1.dev20240802+cu118, git_commit: d23c061, cmake_build_type: Release, rdma: True, mlir: True, enterprise: False
Nexfort version: 0.1.dev275
OneDiff version: 1.2.0
OneDiffX version: 1.2.0

OS: Ubuntu 22.04.5 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: Could not collect
CMake version: Could not collect
Libc version: glibc-2.35

Python version: 3.9.20 (main, Oct 3 2024, 07:27:41) [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-6.8.0-1016-gcp-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: 11.5.119
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA A100-SXM4-40GB
Nvidia driver version: 550.107.02
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.8.9.7
/usr/lib/x86_64-linux-gnu/libcudnn_adv_infer.so.8.9.7
/usr/lib/x86_64-linux-gnu/libcudnn_adv_train.so.8.9.7
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_infer.so.8.9.7
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_train.so.8.9.7
/usr/lib/x86_64-linux-gnu/libcudnn_ops_infer.so.8.9.7
/usr/lib/x86_64-linux-gnu/libcudnn_ops_train.so.8.9.7
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 46 bits physical, 48 bits virtual
Byte Order: Little Endian
CPU(s): 12
On-line CPU(s) list: 0-11
Vendor ID: GenuineIntel
Model name: Intel(R) Xeon(R) CPU @ 2.20GHz
CPU family: 6
Model: 85
Thread(s) per core: 2
Core(s) per socket: 6
Socket(s): 1
Stepping: 7
BogoMIPS: 4400.38
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology nonstop_tsc cpuid tsc_known_freq pni pclmulqdq ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch ssbd ibrs ibpb stibp ibrs_enhanced fsgsbase tsc_adjust bmi1 hle avx2 smep bmi2 erms invpcid rtm mpx avx512f avx512dq rdseed adx smap clflushopt clwb avx512cd avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves arat avx512_vnni md_clear arch_capabilities
Hypervisor vendor: KVM
Virtualization type: full
L1d cache: 192 KiB (6 instances)
L1i cache: 192 KiB (6 instances)
L2 cache: 6 MiB (6 instances)
L3 cache: 38.5 MiB (1 instance)
NUMA node(s): 1
NUMA node0 CPU(s): 0-11
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Vulnerable: Clear CPU buffers attempted, no microcode; SMT Host state unknown
Vulnerability Reg file data sampling: Not affected
Vulnerability Retbleed: Mitigation; Enhanced IBRS
Vulnerability Spec rstack overflow: Not affected
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Enhanced / Automatic IBRS; IBPB conditional; RSB filling; PBRSB-eIBRS SW sequence; BHI SW loop, KVM SW loop
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Vulnerable: Clear CPU buffers attempted, no microcode; SMT Host state unknown

Versions of relevant libraries:
[pip3] diffusers==0.30.3
[pip3] numpy==1.26.4
[pip3] onnx==1.16.1
[pip3] onnx-graphsurgeon==0.5.2
[pip3] onnxconverter-common==1.14.0
[pip3] onnxmltools==1.12.0
[pip3] onnxruntime_extensions==0.12.0
[pip3] onnxruntime-gpu==1.18.0
[pip3] torch==2.4.0
[pip3] torch_tensorrt==2.4.0
[pip3] torchao==0.6.1
[pip3] torchaudio==2.4.0
[pip3] torchprofile==0.0.4
[pip3] torchvision==0.19.0
[pip3] transformers==4.45.2
[pip3] triton==3.0.0
[conda] numpy 1.26.4 pypi_0 pypi
[conda] torch 2.4.0 pypi_0 pypi
[conda] torch-tensorrt 2.4.0 pypi_0 pypi
[conda] torchao 0.6.1 pypi_0 pypi
[conda] torchaudio 2.4.0 pypi_0 pypi
[conda] torchprofile 0.0.4 pypi_0 pypi
[conda] torchvision 0.19.0 pypi_0 pypi
[conda] triton 3.0.0 pypi_0 pypi

🐛 Describe the bug

import time
import json
import os
import torch
from diffusers import (CogVideoXDDIMScheduler, DDIMScheduler,
                       DPMSolverMultistepScheduler,
                       EulerAncestralDiscreteScheduler, EulerDiscreteScheduler,
                       PNDMScheduler, CogVideoXDPMScheduler)
from transformers import T5EncoderModel

from cogvideox.models.transformer3d import CogVideoXTransformer3DModel
from cogvideox.models.autoencoder_magvit import AutoencoderKLCogVideoX
from cogvideox.pipeline.pipeline_cogvideox_inpaint import CogVideoX_Fun_Pipeline_Inpaint
from cogvideox.utils.lora_utils import merge_lora, unmerge_lora
from cogvideox.utils.utils import get_image_to_video_latent, save_videos_grid

from torchao.quantization import quantize_, int8_weight_only
from onediffx import compile_pipe, save_pipe
from onediff.infer_compiler import oneflow_compile, compile, OneflowCompileOptions
import oneflow as flow 

# from sageattention import sageattn
# import torch.nn.functional as F

# F.scaled_dot_product_attention = sageattn

os.environ['NEXFORT_FUSE_TIMESTEP_EMBEDDING'] = '0'
os.environ['NEXFORT_FX_FORCE_TRITON_SDPA'] = '1'
os.environ['NEXFORT_GRAPH_CACHE'] = '1'
os.environ['TORCHINDUCTOR_CACHE_DIR'] = os.path.expanduser('~/torchinductor')  # os.environ stores "~" literally; expand it explicitly


def init_pipeline(
    model_path,
    lora_path,
    lora_weight,
    sampler_name,
    weight_dtype,
    quantization,
    device,
):
    # Text Encoder
    text_encoder = T5EncoderModel.from_pretrained(model_path, subfolder="text_encoder", torch_dtype=weight_dtype).to(device)
    #quantize_(text_encoder, quantization())

    # Transformer
    transformer = CogVideoXTransformer3DModel.from_pretrained_2d(model_path, subfolder="transformer").to(weight_dtype).to(device)
    #quantize_(transformer, quantization())

    # Vae
    vae = AutoencoderKLCogVideoX.from_pretrained(model_path, subfolder="vae").to(weight_dtype).to(device)
    #quantize_(vae, quantization())

    # Scheduler
    scheduler_cls = {
        "Euler": EulerDiscreteScheduler,
        "Euler A": EulerAncestralDiscreteScheduler,
        "DPM++": DPMSolverMultistepScheduler,
        "DPM_Cog": CogVideoXDPMScheduler,
        "PNDM": PNDMScheduler,
        "DDIM_Cog": CogVideoXDDIMScheduler,
        "DDIM_Origin": DDIMScheduler,
        }[sampler_name]
    scheduler = scheduler_cls.from_pretrained(model_path, subfolder="scheduler")

    # Pipeline
    pipeline = CogVideoX_Fun_Pipeline_Inpaint.from_pretrained(
        model_path,
        text_encoder=text_encoder,
        transformer=transformer,
        vae=vae,
        scheduler=scheduler,
        torch_dtype=weight_dtype
        )

    if lora_path is not None:
        pipeline = merge_lora(pipeline, lora_path, lora_weight)

    pipeline.to(device)

    pipeline.vae.enable_slicing()
    pipeline.vae.enable_tiling()
    
    ### TORCH TRANSFORMER COMPILE ###

    # pipeline.transformer = torch.compile(
    #     pipeline.transformer,
    #     mode = 'reduce-overhead',
    #     fullgraph=True,
    #     dynamic=False,
    #     )

    ###


    ### ONEFLOW TRANSFORMER COMPILE ###

    # compile_options = OneflowCompileOptions()
    # compile_options.max_cached_graph_size = 9
    # pipeline.transformer = oneflow_compile(
    #     pipeline.transformer,
    #     options=compile_options,
    # )

    ###


    ### NEXFORT PIPE COMPILE ###

    #options = '{"mode": "max-optimize:max-autotune:low-precision", "memory_format": "channels_last"}'
    #options = '{"mode": "max-autotune:cudagraphs", "memory_format": "channels_last"}'
    #options = '{"mode": "cudagraphs:low-precision", "memory_format": "channels_last"}'
    #options = '{"mode": "O3", "memory_format": "channels_last"}'
    #options = '{"mode": "max-autotune:low-precision:cache-all:cudagraphs", "memory_format": "channels_last", "dynamic": true}'
    #options = '{"mode": "max-autotune:freezing:benchmark:low-precision", "memory_format": "channels_last", "options": {"inductor.optimize_linear_epilogue": false, "triton.fuse_attention_allow_fp16_reduction": false}}'

    # pipeline = compile_pipe(
    #     pipeline,
    #     backend="nexfort",
    #     options= options,
    #     ignores=["vae", "text_encoder"],
    #     fuse_qkv_projections=True,
    # )   
    
    pipeline = compile_pipe(
        pipeline,
        backend="oneflow",  # the compile_pipe default, stated explicitly
        ignores=["vae", "text_encoder"],
        fuse_qkv_projections=True,
    )

    return pipeline, vae


def generate_video(
        pipeline,
        vae,
        lora_path,
        lora_weight,
        prompt,
        negative_prompt,
        input_image_start,
        input_image_end,
        num_frames,
        sample_size,
        guidance_scale,
        num_videos_per_prompt,
        num_inference_steps,
        seed,
):
    video_length = int((num_frames - 1) // vae.config.temporal_compression_ratio * vae.config.temporal_compression_ratio) + 1 if num_frames != 1 else 1
    input_video, input_video_mask, clip_image = get_image_to_video_latent(input_image_start, input_image_end, video_length=video_length, sample_size=sample_size)

    with torch.no_grad():
        generated_video = pipeline(
            prompt = prompt,
            negative_prompt = negative_prompt, 
            num_frames = video_length,
            height = sample_size[0],
            width = sample_size[1],
            guidance_scale = guidance_scale,
            num_videos_per_prompt = num_videos_per_prompt,
            num_inference_steps = num_inference_steps,
            generator = torch.Generator().manual_seed(seed),
            video = input_video,
            mask_video = input_video_mask
        ).videos

    if lora_path is not None:
        pipeline = unmerge_lora(pipeline, lora_path, lora_weight)

    return generated_video


model_path = "models/CogVideoX-Fun-V1.1-5b-InP"
lora_path = None
lora_weight = None
sampler_name = "Euler"
num_frames = 49
fps = 8
weight_dtype = torch.bfloat16
input_image_end = None
quantization = int8_weight_only
device = "cuda"

start_time = time.time()
initialized_pipeline, initialized_vae = init_pipeline(
    model_path = model_path,
    lora_path = lora_path,
    lora_weight = lora_weight,
    sampler_name = sampler_name,
    weight_dtype = weight_dtype,
    quantization = quantization,
    device = device,
)
end_time = time.time()
exec_time = end_time - start_time
print(f"Pipe Initialization Time:{exec_time}s")

prompt_list = {"cakeify_it": "",
               "crush_it": "In this experiment, we witness the fate of a soft, innocent plush creature under the relentless force of a hydraulic press. At the beginning, our subject sits quietly, his large, expressive eyes full of curiosity, perhaps unaware of what awaits him. The scene changes as the high metallic press begins its slow descent, bringing with it an air of tension and inevitability. The metallic press touches the plush creature at its base and begins to crush it. As the pressure increases, the plush toy resists for a moment, flattening slightly but holding its shape, as if clinging to hope. But the force is unstoppable. Slowly but surely, the squishy figure succumbs, its fabric folds and compresses, distorting its once round form. The press continues to exert its force until the toy is completely crushed and its contents spill out dramatically on all sides. The experiment ends when the press is removed to reveal the shattered remains of the plush friend - now unrecognizable, a mere memory of its original form. This mesmerizing moment demonstrates the irresistible power of machines and the fragility of soft objects, leaving us both mesmerized and with some empathy for the oppressed subject.",
               "explode_it": "In a explode of chaos, the object disintegrates, chunks the flying through the around. The scene transitions to a slow-motion spectacle, capturing each piece in mid-flight, their textures and colors vivid against. Finally, the remnants settle in a haphazard pile, creating a whimsical, messy aftermath that contrasts sharply with the initial perfection.",
               "inflate_it": "The subject in the photo inflates with air like a balloon, with the camera fixed in a static shot. As the subject begins to fill with air, they rise vertically upwards, detaching from the ground and the image. Each part of the subject swells proportionally, expanding gently, as they continue to ascend. The subject becomes lighter and larger, floating higher and faster, rising gracefully toward the top of the frame, where they appear fully inflated and airborne.",
               "melt_it": "The subject in the image begins to melt, turning into liquid with static shot. Starting from the edges, the subject’s form starts to melting, their features blending together as they liquefy. Each part of the subject gradually turns into liquid, dripping downwards and pooling onto the ground. The liquid spreads and flows across the frame,  flowing liquid state. Eventually, the entire subject completely dissolves, leaving no trace behind as the liquid merges into the surface and vanishes.",
               "squish_it": ""}

style = "melt_it"
prompt = prompt_list.get(style)
negative_prompt = ""
sample_size = [512, 512]
cfg = 7.0
num_videos = 1
steps = 50
seed = 42

input_image_start = "samples/spider.png"

start_time = time.time()
generated_video = generate_video(
    pipeline = initialized_pipeline,
    vae = initialized_vae,
    lora_path = lora_path,
    lora_weight = lora_weight,
    prompt = prompt,
    negative_prompt = negative_prompt,
    input_image_start = input_image_start,
    input_image_end = input_image_end,
    num_frames = num_frames,
    sample_size = sample_size,
    guidance_scale = cfg,
    num_videos_per_prompt = num_videos,
    num_inference_steps = steps,
    seed = seed,
)
end_time = time.time()
exec_time = end_time - start_time

output_path = f"outputs/{style}-{os.path.splitext(os.path.basename(input_image_start))[0]}-res:{sample_size}-steps:{steps}-time:{exec_time:.2f}s.mp4"
save_videos_grid(generated_video, output_path, fps=fps)
print(f'Video saved, Path:{output_path}')
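The `video_length` line in `generate_video` rounds `num_frames - 1` down to a multiple of the VAE's temporal compression ratio and then adds one frame back (a single frame is passed through as-is). A minimal pure-Python sketch of that arithmetic follows; the helper name `snap_video_length` is hypothetical, and the ratio of 4 is an assumption for illustration only (the script reads the real value from `vae.config.temporal_compression_ratio`).

```python
def snap_video_length(num_frames: int, temporal_compression_ratio: int) -> int:
    """Mirror of the video_length expression in generate_video (hypothetical helper).

    Rounds (num_frames - 1) down to a multiple of the ratio and adds one,
    keeping the first frame plus whole groups of compressed frames.
    A single frame is returned unchanged.
    """
    if num_frames == 1:
        return 1
    return (num_frames - 1) // temporal_compression_ratio * temporal_compression_ratio + 1


if __name__ == "__main__":
    # Assumed ratio of 4; the script takes it from vae.config.temporal_compression_ratio.
    for frames in (1, 49, 50, 52, 53):
        print(frames, "->", snap_video_length(frames, 4))
```

Under that assumption, the script's `num_frames = 49` already satisfies the constraint and passes through unchanged, while requests for 50 to 52 frames would all be trimmed to 49.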

I can run CogVideoX-5B with this script when I use nexfort as the compile backend, but the nexfort backend does not let me save the compiled graphs. I therefore want to switch to the oneflow backend, but then I get the error below. Please help.
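The message `expected Tensor as element 0 in argument 0, but got Tensor` looks self-contradictory. In the first traceback below, the innermost frames show OneFlow's graph builder (`oneflow/nn/graph/proxy.py`) running diffusers' own `FusedCogVideoXAttnProcessor2_0`, which calls `torch.cat`. One plausible reading, not confirmed anywhere in this issue, is that `torch.cat` is receiving OneFlow tensors, and both classes are reported by their bare name `Tensor`. The pure-Python sketch below reproduces only that naming effect; `TorchLikeTensor`, `FlowLikeTensor`, and `strict_cat` are hypothetical stand-ins, not the real torch or oneflow API.

```python
def make_tensor_class(module_name: str) -> type:
    """Create a class named "Tensor" that reports module_name as its module."""
    return type("Tensor", (), {"__module__": module_name})


# Hypothetical stand-ins for torch.Tensor and oneflow.Tensor.
TorchLikeTensor = make_tensor_class("torch")
FlowLikeTensor = make_tensor_class("oneflow")


def strict_cat(tensors):
    """Hypothetical cat that, like torch.cat, accepts only its own Tensor type."""
    for i, t in enumerate(tensors):
        if not isinstance(t, TorchLikeTensor):
            # Both classes have __name__ == "Tensor", so the message is ambiguous.
            raise TypeError(
                f"expected {TorchLikeTensor.__name__} as element {i} in argument 0, "
                f"but got {type(t).__name__}"
            )
    return list(tensors)


if __name__ == "__main__":
    try:
        strict_cat([FlowLikeTensor(), TorchLikeTensor()])
    except TypeError as err:
        print(err)  # expected Tensor as element 0 in argument 0, but got Tensor
        # Module-qualified names tell the two classes apart.
        print(f"{TorchLikeTensor.__module__}.Tensor vs {FlowLikeTensor.__module__}.Tensor")
```

Printing `type(t).__module__` alongside the class name, as in the last line, is a quick way to tell which framework an offending tensor actually belongs to.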

{
	"name": "TypeError",
	"message": "expected Tensor as element 0 in argument 0, but got Tensor",
	"stack": "---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
File ~/anaconda3/envs/cogvideo_v2/lib/python3.9/site-packages/onediff/infer_compiler/backends/oneflow/deployable_module.py:42, in handle_deployable_exception.<locals>.wrapper(self, *args, **kwargs)
     41 try:
---> 42     return func(self, *args, **kwargs)
     43 except Exception as e:

File ~/anaconda3/envs/cogvideo_v2/lib/python3.9/site-packages/onediff/infer_compiler/backends/oneflow/online_quantization_utils.py:65, in quantize_and_deploy_wrapper.<locals>.wrapper(self, *args, **kwargs)
     64     self._deployable_module_quant_config = None
---> 65 output = func(self, *args, **kwargs)
     66 return output

File ~/anaconda3/envs/cogvideo_v2/lib/python3.9/site-packages/onediff/infer_compiler/backends/oneflow/graph_management_utils.py:123, in graph_file_management.<locals>.wrapper(self, *args, **kwargs)
    122 else:
--> 123     ret = func(self, *args, **kwargs)
    125 return ret

File ~/anaconda3/envs/cogvideo_v2/lib/python3.9/site-packages/onediff/infer_compiler/backends/oneflow/args_tree_util.py:70, in input_output_processor.<locals>.wrapper(self, *args, **kwargs)
     68         self._load_graph_first_run = True
---> 70 output = func(self, *mapped_args, **mapped_kwargs)
     71 return process_output(output)

File ~/anaconda3/envs/cogvideo_v2/lib/python3.9/site-packages/onediff/infer_compiler/backends/oneflow/deployable_module.py:154, in OneflowDeployableModule.forward(self, *args, **kwargs)
    153     with oneflow_exec_mode():
--> 154         output = dpl_graph(*args, **kwargs)
    155 else:

File ~/anaconda3/envs/cogvideo_v2/lib/python3.9/site-packages/oneflow/nn/graph/graph.py:295, in Graph.__call__(self, *args, **kwargs)
    294 if not self._is_compiled:
--> 295     self._compile(*args, **kwargs)
    297 return self.__run(*args, **kwargs)

File ~/anaconda3/envs/cogvideo_v2/lib/python3.9/site-packages/oneflow/nn/graph/graph.py:861, in Graph._compile(self, *args, **kwargs)
    860 if self._run_with_cache:
--> 861     return self._dynamic_input_graph_cache._compile(*args, **kwargs)
    863 if not self._is_compiled:

File ~/anaconda3/envs/cogvideo_v2/lib/python3.9/site-packages/oneflow/nn/graph/cache.py:121, in GraphCache._compile(self, *args, **kwargs)
    120 with AvoidRecursiveCacheCall(graph):
--> 121     return graph._compile(*args, **kwargs)

File ~/anaconda3/envs/cogvideo_v2/lib/python3.9/site-packages/oneflow/nn/graph/graph.py:865, in Graph._compile(self, *args, **kwargs)
    864 if not self._build_with_shared_graph:
--> 865     return self._compile_new(*args, **kwargs)
    866 else:

File ~/anaconda3/envs/cogvideo_v2/lib/python3.9/site-packages/oneflow/nn/graph/graph.py:884, in Graph._compile_new(self, *args, **kwargs)
    883 self.__ensure_input_tensors_contiguous(*args, **kwargs)
--> 884 _, eager_outputs = self.build_graph(*args, **kwargs)
    885 if isinstance(eager_outputs, (tuple, list)) and all(
    886     isinstance(arg, Tensor) for arg in eager_outputs
    887 ):

File ~/anaconda3/envs/cogvideo_v2/lib/python3.9/site-packages/oneflow/nn/graph/graph.py:1429, in Graph.build_graph(self, *args, **kwargs)
   1422 with graph_build_util.DebugScopeContext(
   1423     self._debug_min_s_level,
   1424     self._debug_max_v_level,
   (...)
   1427     self._debug_only_user_py_stack,
   1428 ):
-> 1429     outputs = self.__build_graph(*args, **kwargs)
   1430 build_graph_end = time.perf_counter()

File ~/anaconda3/envs/cogvideo_v2/lib/python3.9/site-packages/oneflow/nn/graph/graph.py:1577, in Graph.__build_graph(self, *args, **kwargs)
   1576 self._is_user_mode = True
-> 1577 outputs = self.build(*lazy_args, **lazy_kwargs)
   1578 self._is_user_mode = False

File ~/anaconda3/envs/cogvideo_v2/lib/python3.9/site-packages/onediff/infer_compiler/backends/oneflow/graph.py:19, in OneflowGraph.build(self, *args, **kwargs)
     18 def build(self, *args, **kwargs):
---> 19     return self.model(*args, **kwargs)

File ~/anaconda3/envs/cogvideo_v2/lib/python3.9/site-packages/oneflow/nn/graph/proxy.py:188, in ProxyModule.__call__(self, *args, **kwargs)
    181 with graph_build_util.DebugScopeContext(
    182     self.to(GraphModule)._debug_min_s_level,
    183     self.to(GraphModule)._debug_max_v_level,
   (...)
    186     self.to(GraphModule)._debug_only_user_py_stack,
    187 ):
--> 188     result = self.__block_forward(*args, **kwargs)
    190 outputs = ()

File ~/anaconda3/envs/cogvideo_v2/lib/python3.9/site-packages/oneflow/nn/graph/proxy.py:239, in ProxyModule.__block_forward(self, *args, **kwargs)
    238 try:
--> 239     result = unbound_forward_of_module_instance(self, *args, **kwargs)
    240 # for callback to torch.Module.forward when forward is not implemented

File ~/desktop/video_model/CogVideoX-Fun/cogvideox/models/transformer3d.py:518, in CogVideoXTransformer3DModel.forward(self, hidden_states, encoder_hidden_states, timestep, timestep_cond, inpaint_latents, control_latents, image_rotary_emb, return_dict)
    517     else:
--> 518         hidden_states, encoder_hidden_states = block(
    519             hidden_states=hidden_states,
    520             encoder_hidden_states=encoder_hidden_states,
    521             temb=emb,
    522             image_rotary_emb=image_rotary_emb,
    523         )
    525 if not self.config.use_rotary_positional_embeddings:
    526     # CogVideoX-2B

File ~/anaconda3/envs/cogvideo_v2/lib/python3.9/site-packages/oneflow/nn/graph/proxy.py:188, in ProxyModule.__call__(self, *args, **kwargs)
    181 with graph_build_util.DebugScopeContext(
    182     self.to(GraphModule)._debug_min_s_level,
    183     self.to(GraphModule)._debug_max_v_level,
   (...)
    186     self.to(GraphModule)._debug_only_user_py_stack,
    187 ):
--> 188     result = self.__block_forward(*args, **kwargs)
    190 outputs = ()

File ~/anaconda3/envs/cogvideo_v2/lib/python3.9/site-packages/oneflow/nn/graph/proxy.py:239, in ProxyModule.__block_forward(self, *args, **kwargs)
    238 try:
--> 239     result = unbound_forward_of_module_instance(self, *args, **kwargs)
    240 # for callback to torch.Module.forward when forward is not implemented

File ~/desktop/video_model/CogVideoX-Fun/cogvideox/models/transformer3d.py:173, in CogVideoXBlock.forward(self, hidden_states, encoder_hidden_states, temb, image_rotary_emb)
    172 # attention
--> 173 attn_hidden_states, attn_encoder_hidden_states = self.attn1(
    174     hidden_states=norm_hidden_states,
    175     encoder_hidden_states=norm_encoder_hidden_states,
    176     image_rotary_emb=image_rotary_emb,
    177 )
    179 hidden_states = hidden_states + gate_msa * attn_hidden_states

File ~/anaconda3/envs/cogvideo_v2/lib/python3.9/site-packages/oneflow/nn/graph/proxy.py:188, in ProxyModule.__call__(self, *args, **kwargs)
    181 with graph_build_util.DebugScopeContext(
    182     self.to(GraphModule)._debug_min_s_level,
    183     self.to(GraphModule)._debug_max_v_level,
   (...)
    186     self.to(GraphModule)._debug_only_user_py_stack,
    187 ):
--> 188     result = self.__block_forward(*args, **kwargs)
    190 outputs = ()

File ~/anaconda3/envs/cogvideo_v2/lib/python3.9/site-packages/oneflow/nn/graph/proxy.py:239, in ProxyModule.__block_forward(self, *args, **kwargs)
    238 try:
--> 239     result = unbound_forward_of_module_instance(self, *args, **kwargs)
    240 # for callback to torch.Module.forward when forward is not implemented

File ~/anaconda3/envs/cogvideo_v2/lib/python3.9/site-packages/infer_compiler_registry/register_diffusers/attention_processor_oflow.py:364, in Attention.forward(self, hidden_states, encoder_hidden_states, attention_mask, **cross_attention_kwargs)
    353 def forward(
    354     self,
    355     hidden_states,
   (...)
    361     # here we simply pass along all tensors to the selected processor class
    362     # For standard processors that are defined here, `**cross_attention_kwargs` is empty
--> 364     return self.processor(
    365         self,
    366         hidden_states,
    367         encoder_hidden_states=encoder_hidden_states,
    368         attention_mask=attention_mask,
    369         **cross_attention_kwargs,
    370     )

File ~/anaconda3/envs/cogvideo_v2/lib/python3.9/site-packages/diffusers/models/attention_processor.py:1962, in FusedCogVideoXAttnProcessor2_0.__call__(self, attn, hidden_states, encoder_hidden_states, attention_mask, image_rotary_emb)
   1960 text_seq_length = encoder_hidden_states.size(1)
-> 1962 hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
   1964 batch_size, sequence_length, _ = (
   1965     hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
   1966 )

TypeError: expected Tensor as element 0 in argument 0, but got Tensor

During handling of the above exception, another exception occurred:

TypeError                                 Traceback (most recent call last)
File /home/kursat_dinc/desktop/video_model/CogVideoX-Fun/cog-fun-i2v.py:2
      1 start_time = time.time()
----> 2 generated_video = generate_video(
      3     pipeline = initialized_pipeline,
      4     vae = initialized_vae,
      5     lora_path = lora_path,
      6     lora_weight = lora_weight,
      7     prompt = prompt,
      8     negative_prompt = negative_prompt,
      9     input_image_start = input_image_start,
     10     input_image_end = input_image_end,
     11     num_frames = num_frames,
     12     sample_size = sample_size,
     13     guidance_scale = cfg,
     14     num_videos_per_prompt = num_videos,
     15     num_inference_steps = steps,
     16     seed = seed,
     17 )
     18 end_time = time.time()
     19 exec_time = end_time - start_time

File /home/kursat_dinc/desktop/video_model/CogVideoX-Fun/cog-fun-i2v.py:153
    150 input_video, input_video_mask, clip_image = get_image_to_video_latent(input_image_start, input_image_end, video_length=video_length, sample_size=sample_size)
    152 with torch.no_grad():
--> 153     generated_video = pipeline(
    154         prompt = prompt,
    155         negative_prompt = negative_prompt, 
    156         num_frames = video_length,
    157         height = sample_size[0],
    158         width = sample_size[1],
    159         guidance_scale = guidance_scale,
    160         num_videos_per_prompt = num_videos_per_prompt,
    161         num_inference_steps = num_inference_steps,
    162         generator = torch.Generator().manual_seed(seed),
    163         video = input_video,
    164         mask_video = input_video_mask
    165     ).videos
    167 if lora_path is not None:
    168     pipeline = unmerge_lora(pipeline, lora_path, lora_weight)

File ~/anaconda3/envs/cogvideo_v2/lib/python3.9/site-packages/torch/utils/_contextlib.py:116, in context_decorator.<locals>.decorate_context(*args, **kwargs)
    113 @functools.wraps(func)
    114 def decorate_context(*args, **kwargs):
    115     with ctx_factory():
--> 116         return func(*args, **kwargs)

File ~/desktop/video_model/CogVideoX-Fun/cogvideox/pipeline/pipeline_cogvideox_inpaint.py:956, in CogVideoX_Fun_Pipeline_Inpaint.__call__(self, prompt, negative_prompt, height, width, video, mask_video, masked_video_latents, num_frames, num_inference_steps, timesteps, guidance_scale, use_dynamic_cfg, num_videos_per_prompt, eta, generator, latents, prompt_embeds, negative_prompt_embeds, output_type, return_dict, callback_on_step_end, callback_on_step_end_tensor_inputs, max_sequence_length, strength, noise_aug_strength, comfyui_progressbar)
    953 timestep = t.expand(latent_model_input.shape[0])
    955 # predict noise model_output
--> 956 noise_pred = self.transformer(
    957     hidden_states=latent_model_input,
    958     encoder_hidden_states=prompt_embeds,
    959     timestep=timestep,
    960     image_rotary_emb=image_rotary_emb,
    961     return_dict=False,
    962     inpaint_latents=inpaint_latents,
    963 )[0]
    964 noise_pred = noise_pred.float()
    966 # perform guidance

File ~/anaconda3/envs/cogvideo_v2/lib/python3.9/site-packages/torch/nn/modules/module.py:1553, in Module._wrapped_call_impl(self, *args, **kwargs)
   1551     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1552 else:
-> 1553     return self._call_impl(*args, **kwargs)

File ~/anaconda3/envs/cogvideo_v2/lib/python3.9/site-packages/torch/nn/modules/module.py:1603, in Module._call_impl(self, *args, **kwargs)
   1600     bw_hook = hooks.BackwardHook(self, full_backward_hooks, backward_pre_hooks)
   1601     args = bw_hook.setup_input_hook(args)
-> 1603 result = forward_call(*args, **kwargs)
   1604 if _global_forward_hooks or self._forward_hooks:
   1605     for hook_id, hook in (
   1606         *_global_forward_hooks.items(),
   1607         *self._forward_hooks.items(),
   1608     ):
   1609         # mark that always called hook is run

File ~/anaconda3/envs/cogvideo_v2/lib/python3.9/site-packages/onediff/infer_compiler/backends/oneflow/deployable_module.py:48, in handle_deployable_exception.<locals>.wrapper(self, *args, **kwargs)
     46 del self._deployable_module_model.oneflow_module
     47 self._deployable_module_dpl_graph = None
---> 48 return func(self, *args, **kwargs)

File ~/anaconda3/envs/cogvideo_v2/lib/python3.9/site-packages/onediff/infer_compiler/backends/oneflow/online_quantization_utils.py:65, in quantize_and_deploy_wrapper.<locals>.wrapper(self, *args, **kwargs)
     56     torch_model, _ = online_quantize_model(
     57         torch_model,
     58         args,
   (...)
     62         inplace=True,
     63     )
     64     self._deployable_module_quant_config = None
---> 65 output = func(self, *args, **kwargs)
     66 return output

File ~/anaconda3/envs/cogvideo_v2/lib/python3.9/site-packages/onediff/infer_compiler/backends/oneflow/graph_management_utils.py:123, in graph_file_management.<locals>.wrapper(self, *args, **kwargs)
    121     handle_graph_saving()
    122 else:
--> 123     ret = func(self, *args, **kwargs)
    125 return ret

File ~/anaconda3/envs/cogvideo_v2/lib/python3.9/site-packages/onediff/infer_compiler/backends/oneflow/args_tree_util.py:70, in input_output_processor.<locals>.wrapper(self, *args, **kwargs)
     67         self._deployable_module_input_structure_key = None
     68         self._load_graph_first_run = True
---> 70 output = func(self, *mapped_args, **mapped_kwargs)
     71 return process_output(output)

File ~/anaconda3/envs/cogvideo_v2/lib/python3.9/site-packages/onediff/infer_compiler/backends/oneflow/deployable_module.py:154, in OneflowDeployableModule.forward(self, *args, **kwargs)
    152     dpl_graph = self.get_graph()
    153     with oneflow_exec_mode():
--> 154         output = dpl_graph(*args, **kwargs)
    155 else:
    156     with oneflow_exec_mode():

File ~/anaconda3/envs/cogvideo_v2/lib/python3.9/site-packages/oneflow/nn/graph/graph.py:295, in Graph.__call__(self, *args, **kwargs)
    292     return self._dynamic_input_graph_cache(*args, **kwargs)
    294 if not self._is_compiled:
--> 295     self._compile(*args, **kwargs)
    297 return self.__run(*args, **kwargs)

File ~/anaconda3/envs/cogvideo_v2/lib/python3.9/site-packages/oneflow/nn/graph/graph.py:861, in Graph._compile(self, *args, **kwargs)
    859 def _compile(self, *args, **kwargs):
    860     if self._run_with_cache:
--> 861         return self._dynamic_input_graph_cache._compile(*args, **kwargs)
    863     if not self._is_compiled:
    864         if not self._build_with_shared_graph:

File ~/anaconda3/envs/cogvideo_v2/lib/python3.9/site-packages/oneflow/nn/graph/cache.py:121, in GraphCache._compile(self, *args, **kwargs)
    119 graph = self.get_graph(*args, **kwargs)
    120 with AvoidRecursiveCacheCall(graph):
--> 121     return graph._compile(*args, **kwargs)

File ~/anaconda3/envs/cogvideo_v2/lib/python3.9/site-packages/oneflow/nn/graph/graph.py:865, in Graph._compile(self, *args, **kwargs)
    863 if not self._is_compiled:
    864     if not self._build_with_shared_graph:
--> 865         return self._compile_new(*args, **kwargs)
    866     else:
    867         return self._compile_from_shared(*args, **kwargs)

File ~/anaconda3/envs/cogvideo_v2/lib/python3.9/site-packages/oneflow/nn/graph/graph.py:884, in Graph._compile_new(self, *args, **kwargs)
    881     self._is_simple_tuple_input = True
    883 self.__ensure_input_tensors_contiguous(*args, **kwargs)
--> 884 _, eager_outputs = self.build_graph(*args, **kwargs)
    885 if isinstance(eager_outputs, (tuple, list)) and all(
    886     isinstance(arg, Tensor) for arg in eager_outputs
    887 ):
    888     self._is_simple_tuple_output = True

File ~/anaconda3/envs/cogvideo_v2/lib/python3.9/site-packages/oneflow/nn/graph/graph.py:1429, in Graph.build_graph(self, *args, **kwargs)
   1421 build_graph_start = time.perf_counter()
   1422 with graph_build_util.DebugScopeContext(
   1423     self._debug_min_s_level,
   1424     self._debug_max_v_level,
   (...)
   1427     self._debug_only_user_py_stack,
   1428 ):
-> 1429     outputs = self.__build_graph(*args, **kwargs)
   1430 build_graph_end = time.perf_counter()
   1431 self.__print(
   1432     0,
   1433     0,
   (...)
   1438     + "\n",
   1439 )

File ~/anaconda3/envs/cogvideo_v2/lib/python3.9/site-packages/oneflow/nn/graph/graph.py:1577, in Graph.__build_graph(self, *args, **kwargs)
   1575 self.__print(0, 1, self._shallow_repr() + " start building graph modules.")
   1576 self._is_user_mode = True
-> 1577 outputs = self.build(*lazy_args, **lazy_kwargs)
   1578 self._is_user_mode = False
   1579 self.__print(0, 1, self._shallow_repr() + " end building graph modules.")

File ~/anaconda3/envs/cogvideo_v2/lib/python3.9/site-packages/onediff/infer_compiler/backends/oneflow/graph.py:19, in OneflowGraph.build(self, *args, **kwargs)
     18 def build(self, *args, **kwargs):
---> 19     return self.model(*args, **kwargs)

File ~/anaconda3/envs/cogvideo_v2/lib/python3.9/site-packages/oneflow/nn/graph/proxy.py:188, in ProxyModule.__call__(self, *args, **kwargs)
    178 # NOTE: The original nn.Module's __call__ method is ignored, which means
    179 # that hooks of nn.Modules are ignored. It is not recommended
    180 # to use hooks of nn.Module in nn.Graph for the moment.
    181 with graph_build_util.DebugScopeContext(
    182     self.to(GraphModule)._debug_min_s_level,
    183     self.to(GraphModule)._debug_max_v_level,
   (...)
    186     self.to(GraphModule)._debug_only_user_py_stack,
    187 ):
--> 188     result = self.__block_forward(*args, **kwargs)
    190 outputs = ()
    191 if not (type(result) is tuple or type(result) is list):

File ~/anaconda3/envs/cogvideo_v2/lib/python3.9/site-packages/oneflow/nn/graph/proxy.py:239, in ProxyModule.__block_forward(self, *args, **kwargs)
    237 unbound_forward_of_module_instance = self.to(Module).forward.__func__
    238 try:
--> 239     result = unbound_forward_of_module_instance(self, *args, **kwargs)
    240 # for callback to torch.Module.forward when forward is not implemented
    241 except NotImplementedError:

File ~/desktop/video_model/CogVideoX-Fun/cogvideox/models/transformer3d.py:518, in CogVideoXTransformer3DModel.forward(self, hidden_states, encoder_hidden_states, timestep, timestep_cond, inpaint_latents, control_latents, image_rotary_emb, return_dict)
    509         hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
    510             create_custom_forward(block),
    511             hidden_states,
   (...)
    515             **ckpt_kwargs,
    516         )
    517     else:
--> 518         hidden_states, encoder_hidden_states = block(
    519             hidden_states=hidden_states,
    520             encoder_hidden_states=encoder_hidden_states,
    521             temb=emb,
    522             image_rotary_emb=image_rotary_emb,
    523         )
    525 if not self.config.use_rotary_positional_embeddings:
    526     # CogVideoX-2B
    527     hidden_states = self.norm_final(hidden_states)

File ~/anaconda3/envs/cogvideo_v2/lib/python3.9/site-packages/oneflow/nn/graph/proxy.py:188, in ProxyModule.__call__(self, *args, **kwargs)
    178 # NOTE: The original nn.Module's __call__ method is ignored, which means
    179 # that hooks of nn.Modules are ignored. It is not recommended
    180 # to use hooks of nn.Module in nn.Graph for the moment.
    181 with graph_build_util.DebugScopeContext(
    182     self.to(GraphModule)._debug_min_s_level,
    183     self.to(GraphModule)._debug_max_v_level,
   (...)
    186     self.to(GraphModule)._debug_only_user_py_stack,
    187 ):
--> 188     result = self.__block_forward(*args, **kwargs)
    190 outputs = ()
    191 if not (type(result) is tuple or type(result) is list):

File ~/anaconda3/envs/cogvideo_v2/lib/python3.9/site-packages/oneflow/nn/graph/proxy.py:239, in ProxyModule.__block_forward(self, *args, **kwargs)
    237 unbound_forward_of_module_instance = self.to(Module).forward.__func__
    238 try:
--> 239     result = unbound_forward_of_module_instance(self, *args, **kwargs)
    240 # for callback to torch.Module.forward when forward is not implemented
    241 except NotImplementedError:

File ~/desktop/video_model/CogVideoX-Fun/cogvideox/models/transformer3d.py:173, in CogVideoXBlock.forward(self, hidden_states, encoder_hidden_states, temb, image_rotary_emb)
    168 norm_hidden_states, norm_encoder_hidden_states, gate_msa, enc_gate_msa = self.norm1(
    169     hidden_states, encoder_hidden_states, temb
    170 )
    172 # attention
--> 173 attn_hidden_states, attn_encoder_hidden_states = self.attn1(
    174     hidden_states=norm_hidden_states,
    175     encoder_hidden_states=norm_encoder_hidden_states,
    176     image_rotary_emb=image_rotary_emb,
    177 )
    179 hidden_states = hidden_states + gate_msa * attn_hidden_states
    180 encoder_hidden_states = encoder_hidden_states + enc_gate_msa * attn_encoder_hidden_states

File ~/anaconda3/envs/cogvideo_v2/lib/python3.9/site-packages/oneflow/nn/graph/proxy.py:188, in ProxyModule.__call__(self, *args, **kwargs)
    178 # NOTE: The original nn.Module's __call__ method is ignored, which means
    179 # that hooks of nn.Modules are ignored. It is not recommended
    180 # to use hooks of nn.Module in nn.Graph for the moment.
    181 with graph_build_util.DebugScopeContext(
    182     self.to(GraphModule)._debug_min_s_level,
    183     self.to(GraphModule)._debug_max_v_level,
   (...)
    186     self.to(GraphModule)._debug_only_user_py_stack,
    187 ):
--> 188     result = self.__block_forward(*args, **kwargs)
    190 outputs = ()
    191 if not (type(result) is tuple or type(result) is list):

File ~/anaconda3/envs/cogvideo_v2/lib/python3.9/site-packages/oneflow/nn/graph/proxy.py:239, in ProxyModule.__block_forward(self, *args, **kwargs)
    237 unbound_forward_of_module_instance = self.to(Module).forward.__func__
    238 try:
--> 239     result = unbound_forward_of_module_instance(self, *args, **kwargs)
    240 # for callback to torch.Module.forward when forward is not implemented
    241 except NotImplementedError:

File ~/anaconda3/envs/cogvideo_v2/lib/python3.9/site-packages/infer_compiler_registry/register_diffusers/attention_processor_oflow.py:364, in Attention.forward(self, hidden_states, encoder_hidden_states, attention_mask, **cross_attention_kwargs)
    353 def forward(
    354     self,
    355     hidden_states,
   (...)
    361     # here we simply pass along all tensors to the selected processor class
    362     # For standard processors that are defined here, `**cross_attention_kwargs` is empty
--> 364     return self.processor(
    365         self,
    366         hidden_states,
    367         encoder_hidden_states=encoder_hidden_states,
    368         attention_mask=attention_mask,
    369         **cross_attention_kwargs,
    370     )

File ~/anaconda3/envs/cogvideo_v2/lib/python3.9/site-packages/diffusers/models/attention_processor.py:1962, in FusedCogVideoXAttnProcessor2_0.__call__(self, attn, hidden_states, encoder_hidden_states, attention_mask, image_rotary_emb)
   1952 def __call__(
   1953     self,
   1954     attn: Attention,
   (...)
   1958     image_rotary_emb: Optional[torch.Tensor] = None,
   1959 ) -> torch.Tensor:
   1960     text_seq_length = encoder_hidden_states.size(1)
-> 1962     hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
   1964     batch_size, sequence_length, _ = (
   1965         hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
   1966     )
   1968     if attention_mask is not None:

TypeError: expected Tensor as element 0 in argument 0, but got Tensor


kursatdinc added the Request-bug label on Oct 24, 2024