
About I2V finetune in wan #116

Closed
svjack opened this issue Mar 8, 2025 · 4 comments

svjack commented Mar 8, 2025

Hi, when I try the following:

#### Pre Compute
python wan_cache_latents.py --dataset_config pixel_video_config.toml --vae Wan2.1_VAE.pth --clip models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth
python wan_cache_text_encoder_outputs.py --dataset_config pixel_video_config.toml --t5 models_t5_umt5-xxl-enc-bf16.pth --batch_size 16

wget https://huggingface.co/Comfy-Org/Wan_2.1_ComfyUI_repackaged/resolve/main/split_files/diffusion_models/wan2.1_i2v_480p_14B_fp8_e4m3fn.safetensors
wget https://huggingface.co/Comfy-Org/Wan_2.1_ComfyUI_repackaged/resolve/main/split_files/diffusion_models/wan2.1_i2v_480p_14B_bf16.safetensors

accelerate launch --num_cpu_threads_per_process 1 --mixed_precision bf16 wan_train_network.py \
    --task i2v-14B --t5 models_t5_umt5-xxl-enc-bf16.pth --clip models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth \
    --dit wan2.1_i2v_480p_14B_fp8_e4m3fn.safetensors \
    --dataset_config pixel_video_config.toml --sdpa --mixed_precision bf16 --fp8_base \
    --optimizer_type adamw8bit --learning_rate 2e-4 --gradient_checkpointing \
    --max_data_loader_n_workers 2 --persistent_data_loader_workers \
    --network_module networks.lora_wan --network_dim 32 \
    --timestep_sampling shift --discrete_flow_shift 3.0 \
    --max_train_epochs 16 --save_every_n_epochs 1 --seed 42 \
    --output_dir pixel_outputs --output_name pixel_w14_lora

I get a channel mismatch error (the weight has 36 channels, because of in_dim in the script, but the tensor has 21).
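As far as I can tell, the 36 comes from the noisy video latents (16 channels) concatenated with the I2V conditioning y (4 mask channels plus 16 image latent channels). A rough sketch of that shape check, using the sizes reported in the error; this is just my understanding, not the actual musubi-tuner code:

import torch

latent_ch = 16        # noisy video latents
mask_ch = 4           # conditioning mask (assumed)
image_latent_ch = 16  # VAE latents of the conditioning image (assumed)

noisy = torch.randn(latent_ch, 1, 68, 120)               # (C, F, H, W)
y = torch.randn(mask_ch + image_latent_ch, 1, 68, 120)   # I2V conditioning, 20 channels

x = torch.cat([noisy, y], dim=0).unsqueeze(0)            # (1, 36, 1, 68, 120)
patch_embedding = torch.nn.Conv3d(36, 5120, kernel_size=(1, 2, 2), stride=(1, 2, 2))
print(patch_embedding(x).shape)  # only works when the concatenated input has 36 channels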
The dataset config looks like this:

# general configurations
[general]
resolution = [960, 544]
caption_extension = ".txt"
batch_size = 1
enable_bucket = true
bucket_no_upscale = false

[[datasets]]
video_directory = "test-HunyuanVideo-pixelart-videos_960x544x6"
cache_directory = "test-HunyuanVideo-pixelart-videos_960x544x6_cache" # recommended to set cache directory
target_frames = [25, 45]
frame_extraction = "head"
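(For what it's worth, my reading of frame_extraction = "head" with target_frames = [25, 45] is roughly the following; the real slicing happens inside the tuner's dataset code, and the 49-frame length here is just a made-up example:)

# Hypothetical illustration of "head" extraction (not the tuner's actual code).
video_len = 49                      # made-up example length
target_frames = [25, 45]

clips = [(0, n) for n in target_frames if n <= video_len]
print(clips)                        # [(0, 25), (0, 45)] -> one clip per target, taken from the start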

Error log

epoch 1/16
INFO:dataset.image_video_dataset:epoch is incremented. current_epoch: 0, epoch: 1
INFO:dataset.image_video_dataset:epoch is incremented. current_epoch: 0, epoch: 1
Traceback (most recent call last):
  File "/home/featurize/musubi-tuner/wan_train_network.py", line 414, in <module>
    trainer.train(args)
  File "/home/featurize/musubi-tuner/hv_train_network.py", line 1847, in train
    model_pred, target = self.call_dit(
                         ^^^^^^^^^^^^^^
  File "/home/featurize/musubi-tuner/wan_train_network.py", line 376, in call_dit
    model_pred = model(noisy_model_input, t=timesteps, context=context, clip_fea=clip_fea, seq_len=seq_len, y=image_latents)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/environment/miniconda3/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/environment/miniconda3/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/environment/miniconda3/lib/python3.11/site-packages/accelerate/utils/operations.py", line 823, in forward
    return model_forward(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/environment/miniconda3/lib/python3.11/site-packages/accelerate/utils/operations.py", line 811, in __call__
    return convert_to_fp32(self.model_forward(*args, **kwargs))
                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/environment/miniconda3/lib/python3.11/site-packages/torch/amp/autocast_mode.py", line 16, in decorate_autocast
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/featurize/musubi-tuner/wan/modules/model.py", line 698, in forward
    x = [self.patch_embedding(u.unsqueeze(0)) for u in x]
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/featurize/musubi-tuner/wan/modules/model.py", line 698, in <listcomp>
    x = [self.patch_embedding(u.unsqueeze(0)) for u in x]
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/environment/miniconda3/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/environment/miniconda3/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/environment/miniconda3/lib/python3.11/site-packages/torch/nn/modules/conv.py", line 610, in forward
    return self._conv_forward(input, self.weight, self.bias)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/environment/miniconda3/lib/python3.11/site-packages/torch/nn/modules/conv.py", line 605, in _conv_forward
    return F.conv3d(
           ^^^^^^^^^
RuntimeError: Given groups=1, weight of size [5120, 36, 1, 2, 2], expected input[1, 21, 1, 68, 120] to have 36 channels, but got 21 channels instead
steps:   0%|                                                                                                                                                     | 0/9264 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "/environment/miniconda3/bin/accelerate", line 8, in <module>
    sys.exit(main())
             ^^^^^^
  File "/environment/miniconda3/lib/python3.11/site-packages/accelerate/commands/accelerate_cli.py", line 48, in main
    args.func(args)
  File "/environment/miniconda3/lib/python3.11/site-packages/accelerate/commands/launch.py", line 1168, in launch_command
    simple_launcher(args)
  File "/environment/miniconda3/lib/python3.11/site-packages/accelerate/commands/launch.py", line 763, in simple_launcher
    raise subprocess.CalledProcessError(returncode=process.returncode, cmd=cmd)
subprocess.CalledProcessError: Command '['/environment/miniconda3/bin/python', 'wan_train_network.py', '--task', 'i2v-14B', '--t5', 'models_t5_umt5-xxl-enc-bf16.pth', '--clip', 'models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth', '--dit', 'wan2.1_i2v_480p_14B_fp8_e4m3fn.safetensors', '--dataset_config', 'pixel_video_config.toml', '--sdpa', '--mixed_precision', 'bf16', '--fp8_base', '--optimizer_type', 'adamw8bit', '--learning_rate', '2e-4', '--gradient_checkpointing', '--max_data_loader_n_workers', '2', '--persistent_data_loader_workers', '--network_module', 'networks.lora_wan', '--network_dim', '32', '--timestep_sampling', 'shift', '--discrete_flow_shift', '3.0', '--max_train_epochs', '16', '--save_every_n_epochs', '1', '--seed', '42', '--output_dir', 'pixel_outputs', '--output_name', 'pixel_w14_lora']' returned non-zero exit status 1.

Can you give me some suggestions? 🙂


luckystrike23 commented Mar 8, 2025

I had the same errors and still don't know how to solve them. I tried both the fp16 and fp8 models; none seem to work.

@orssorbit

I don't think I2V works with images or single-frame latents. You can see it's using F.conv3d, which expects an additional dimension, which is length (number of frames) in this case. I had the same issue, but it worked once I made sure frames were > 1.
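Roughly, this is what the patch embedding is doing, so the latents need both the frame axis and the full 36 channels (a sketch using the shapes from the error message, not the actual model code):

import torch
import torch.nn.functional as F

weight = torch.randn(5120, 36, 1, 2, 2)       # (out_ch, in_ch, kF, kH, kW), as in the error message
ok = torch.randn(1, 36, 1, 68, 120)           # (N, C, F, H, W) with 36 channels -> fine
bad = torch.randn(1, 21, 1, 68, 120)          # only 21 channels -> the RuntimeError above

print(F.conv3d(ok, weight, stride=(1, 2, 2)).shape)
try:
    F.conv3d(bad, weight, stride=(1, 2, 2))
except RuntimeError as e:
    print(e)                                  # "... to have 36 channels, but got 21 channels instead"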

@luckystrike23

> I don't think I2V works with images or single-frame latents. You can see it's using F.conv3d, which expects an additional dimension, which is length (number of frames) in this case. I had the same issue, but it worked once I made sure frames were > 1.

But my input dataset only has videos that are at least 48 frames long; there are no images in the dataset.

kohya-ss (Owner) commented Mar 8, 2025

I believe this issue has been fixed. Please run wan_cache_latents.py again, and re-open if any issue persists.
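If you want to double-check the rebuilt cache, something like this should print the shapes of the stored tensors (assuming the cache files are .safetensors under the cache_directory from the config above; adjust the glob and paths if your layout differs):

import glob
from safetensors import safe_open

# Hypothetical cache location taken from the config above; adjust to your cache_directory.
for path in sorted(glob.glob("test-HunyuanVideo-pixelart-videos_960x544x6_cache/*.safetensors")):
    with safe_open(path, framework="pt") as f:
        shapes = {k: tuple(f.get_tensor(k).shape) for k in f.keys()}
    print(path, shapes)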
