[core] LTX Video 0.9.1 #10330

Merged on Dec 23, 2024 (19 commits)
42 changes: 40 additions & 2 deletions docs/source/en/api/pipelines/ltx_video.md
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License. -->

# LTX
# LTX Video

[LTX Video](https://huggingface.co/Lightricks/LTX-Video) is the first DiT-based video generation model capable of generating high-quality videos in real time. It produces 24 FPS videos at 768x512 resolution faster than they can be watched. Trained on a large-scale dataset of diverse videos, the model generates high-resolution videos with realistic and varied content. We provide a model for both text-to-video and image + text-to-video use cases.

@@ -22,14 +22,24 @@ Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers.m

</Tip>

Available models:

| Model name | Recommended dtype |
|:-------------:|:-----------------:|
| [`LTX Video 0.9.0`](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltx-video-2b-v0.9.safetensors) | `torch.bfloat16` |
| [`LTX Video 0.9.1`](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltx-video-2b-v0.9.1.safetensors) | `torch.bfloat16` |

Note: The recommended dtype is for the transformer component. The VAE and text encoders can be `torch.float32`, `torch.bfloat16`, or `torch.float16`, but the recommended dtype is `torch.bfloat16`, as used in the original repository.
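Because only the transformer strictly needs `torch.bfloat16`, the components can also be loaded individually and assembled by hand. A minimal sketch (the repository id is the community 0.9.1 conversion used later on this page; swap in your own as needed):

```python
import torch
from diffusers import AutoencoderKLLTXVideo, LTXPipeline, LTXVideoTransformer3DModel

repo_id = "a-r-r-o-w/LTX-Video-0.9.1-diffusers"

# Transformer in the recommended bfloat16.
transformer = LTXVideoTransformer3DModel.from_pretrained(
    repo_id, subfolder="transformer", torch_dtype=torch.bfloat16
)
# The VAE accepts float32/bfloat16/float16; bfloat16 matches the original repository.
vae = AutoencoderKLLTXVideo.from_pretrained(repo_id, subfolder="vae", torch_dtype=torch.bfloat16)

pipe = LTXPipeline.from_pretrained(repo_id, transformer=transformer, vae=vae, torch_dtype=torch.bfloat16)
```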

## Loading Single Files

Loading the original LTX Video checkpoints is also possible with [`~ModelMixin.from_single_file`].
Loading the original LTX Video checkpoints is also possible with [`~ModelMixin.from_single_file`]. We recommend `from_single_file` for the Lightricks series of models, as they plan to release future models in the single-file format.

```python
import torch
from diffusers import AutoencoderKLLTXVideo, LTXImageToVideoPipeline, LTXVideoTransformer3DModel

# `single_file_url` could also be https://huggingface.co/Lightricks/LTX-Video/ltx-video-2b-v0.9.1.safetensors
single_file_url = "https://huggingface.co/Lightricks/LTX-Video/ltx-video-2b-v0.9.safetensors"
transformer = LTXVideoTransformer3DModel.from_single_file(
    single_file_url, torch_dtype=torch.bfloat16
)
vae = AutoencoderKLLTXVideo.from_single_file(single_file_url, torch_dtype=torch.bfloat16)
pipe = LTXImageToVideoPipeline.from_pretrained(
    "Lightricks/LTX-Video", transformer=transformer, vae=vae, torch_dtype=torch.bfloat16
)
```

@@ -99,6 +109,34 @@ export_to_video(video, "output_gguf_ltx.mp4", fps=24)

Make sure to read the [documentation on GGUF](../../quantization/gguf) to learn more about our GGUF support.
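The GGUF example is partially collapsed in this diff. As a rough sketch of the pattern (the quantized checkpoint URL below is illustrative, not an official release), a GGUF transformer is loaded through `from_single_file` with a `GGUFQuantizationConfig`:

```python
import torch
from diffusers import GGUFQuantizationConfig, LTXPipeline, LTXVideoTransformer3DModel

# Illustrative community GGUF checkpoint; substitute the file you actually want.
gguf_url = "https://huggingface.co/city96/LTX-Video-gguf/blob/main/ltx-video-2b-v0.9-Q3_K_S.gguf"

transformer = LTXVideoTransformer3DModel.from_single_file(
    gguf_url,
    quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
    torch_dtype=torch.bfloat16,
)
pipe = LTXPipeline.from_pretrained(
    "Lightricks/LTX-Video", transformer=transformer, torch_dtype=torch.bfloat16
)
```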

<!-- TODO(aryan): Update this when official weights are supported -->

Loading and running inference with [LTX Video 0.9.1](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltx-video-2b-v0.9.1.safetensors) weights.

```python
import torch
from diffusers import LTXPipeline
from diffusers.utils import export_to_video

pipe = LTXPipeline.from_pretrained("a-r-r-o-w/LTX-Video-0.9.1-diffusers", torch_dtype=torch.bfloat16)
pipe.to("cuda")

prompt = "A woman with long brown hair and light skin smiles at another woman with long blonde hair. The woman with brown hair wears a black jacket and has a small, barely noticeable mole on her right cheek. The camera angle is a close-up, focused on the woman with brown hair's face. The lighting is warm and natural, likely from the setting sun, casting a soft glow on the scene. The scene appears to be real-life footage"
negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted"

video = pipe(
    prompt=prompt,
    negative_prompt=negative_prompt,
    width=768,
    height=512,
    num_frames=161,
    decode_timestep=0.03,
    decode_noise_scale=0.025,
    num_inference_steps=50,
).frames[0]
export_to_video(video, "output.mp4", fps=24)
```

Refer to [this section](https://huggingface.co/docs/diffusers/main/en/api/pipelines/cogvideox#memory-optimization) to learn more about optimizing memory consumption.
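For instance, two common levers that apply here as well (a sketch; exact savings depend on hardware and the installed diffusers version):

```python
# Keep submodules on CPU and move them to the GPU only for their forward pass.
pipe.enable_model_cpu_offload()

# Decode latents in tiles to cut peak VAE memory (assumes the LTX VAE exposes tiling).
pipe.vae.enable_tiling()
```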

## LTXPipeline
110 changes: 99 additions & 11 deletions scripts/convert_ltx_to_diffusers.py
@@ -1,7 +1,9 @@
import argparse
from pathlib import Path
from typing import Any, Dict

import torch
from accelerate import init_empty_weights
from safetensors.torch import load_file
from transformers import T5EncoderModel, T5Tokenizer

@@ -21,7 +23,9 @@ def remove_keys_(key: str, state_dict: Dict[str, Any]):
"k_norm": "norm_k",
}

TRANSFORMER_SPECIAL_KEYS_REMAP = {}
TRANSFORMER_SPECIAL_KEYS_REMAP = {
    "vae": remove_keys_,
}

VAE_KEYS_RENAME_DICT = {
    # decoder
@@ -54,10 +58,31 @@ def remove_keys_(key: str, state_dict: Dict[str, Any]):
"per_channel_statistics.std-of-means": "latents_std",
}

VAE_091_RENAME_DICT = {
    # decoder
    "up_blocks.0": "mid_block",
    "up_blocks.1": "up_blocks.0.upsamplers.0",
    "up_blocks.2": "up_blocks.0",
    "up_blocks.3": "up_blocks.1.upsamplers.0",
    "up_blocks.4": "up_blocks.1",
    "up_blocks.5": "up_blocks.2.upsamplers.0",
    "up_blocks.6": "up_blocks.2",
    "up_blocks.7": "up_blocks.3.upsamplers.0",
    "up_blocks.8": "up_blocks.3",
    # common
    "last_time_embedder": "time_embedder",
    "last_scale_shift_table": "scale_shift_table",
}

VAE_SPECIAL_KEYS_REMAP = {
    "per_channel_statistics.channel": remove_keys_,
    "per_channel_statistics.mean-of-means": remove_keys_,
    "per_channel_statistics.mean-of-stds": remove_keys_,
    "model.diffusion_model": remove_keys_,
}

VAE_091_SPECIAL_KEYS_REMAP = {
    "timestep_scale_multiplier": remove_keys_,
}
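The 0.9.1 table above renumbers the checkpoint's flat `up_blocks` list into diffusers' nested block/upsampler layout. As a small illustration, runnable when pasted below the dicts (the key suffix is hypothetical; the dict is ordered so a replacement's output, e.g. `up_blocks.0.upsamplers.0`, is never re-matched by a later entry):

```python
key = "decoder.up_blocks.7.res_blocks.0.conv1.weight"  # hypothetical checkpoint key
for old, new in VAE_091_RENAME_DICT.items():
    key = key.replace(old, new)
print(key)  # decoder.up_blocks.3.upsamplers.0.res_blocks.0.conv1.weight
```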


@@ -80,13 +105,16 @@ def convert_transformer(
    ckpt_path: str,
    dtype: torch.dtype,
):
    PREFIX_KEY = ""
    PREFIX_KEY = "model.diffusion_model."

    original_state_dict = get_state_dict(load_file(ckpt_path))
    transformer = LTXVideoTransformer3DModel().to(dtype=dtype)
    with init_empty_weights():
        transformer = LTXVideoTransformer3DModel()

    for key in list(original_state_dict.keys()):
        new_key = key[len(PREFIX_KEY) :]
        new_key = key[:]
        if new_key.startswith(PREFIX_KEY):
            new_key = key[len(PREFIX_KEY) :]
        for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
            new_key = new_key.replace(replace_key, rename_key)
        update_state_dict_inplace(original_state_dict, key, new_key)

@@ -97,16 +125,21 @@ def convert_transformer(
            continue
        handler_fn_inplace(key, original_state_dict)

    transformer.load_state_dict(original_state_dict, strict=True)
    transformer.load_state_dict(original_state_dict, strict=True, assign=True)
    return transformer


def convert_vae(ckpt_path: str, dtype: torch.dtype):
def convert_vae(ckpt_path: str, config, dtype: torch.dtype):
    PREFIX_KEY = "vae."

    original_state_dict = get_state_dict(load_file(ckpt_path))
    vae = AutoencoderKLLTXVideo().to(dtype=dtype)
    with init_empty_weights():
        vae = AutoencoderKLLTXVideo(**config)

    for key in list(original_state_dict.keys()):
        new_key = key[:]
        if new_key.startswith(PREFIX_KEY):
            new_key = key[len(PREFIX_KEY) :]
        for replace_key, rename_key in VAE_KEYS_RENAME_DICT.items():
            new_key = new_key.replace(replace_key, rename_key)
        update_state_dict_inplace(original_state_dict, key, new_key)

@@ -117,10 +150,60 @@ def convert_vae(ckpt_path: str, dtype: torch.dtype):
            continue
        handler_fn_inplace(key, original_state_dict)

    vae.load_state_dict(original_state_dict, strict=True)
    vae.load_state_dict(original_state_dict, strict=True, assign=True)
    return vae


def get_vae_config(version: str) -> Dict[str, Any]:
    if version == "0.9.0":
        config = {
            "in_channels": 3,
            "out_channels": 3,
            "latent_channels": 128,
            "block_out_channels": (128, 256, 512, 512),
            "decoder_block_out_channels": (128, 256, 512, 512),
            "layers_per_block": (4, 3, 3, 3, 4),
            "decoder_layers_per_block": (4, 3, 3, 3, 4),
            "spatio_temporal_scaling": (True, True, True, False),
            "decoder_spatio_temporal_scaling": (True, True, True, False),
            "decoder_inject_noise": (False, False, False, False, False),
            "upsample_residual": (False, False, False, False),
            "upsample_factor": (1, 1, 1, 1),
            "patch_size": 4,
            "patch_size_t": 1,
            "resnet_norm_eps": 1e-6,
            "scaling_factor": 1.0,
            "encoder_causal": True,
            "decoder_causal": False,
            "timestep_conditioning": False,
        }
    elif version == "0.9.1":
        config = {
            "in_channels": 3,
            "out_channels": 3,
            "latent_channels": 128,
            "block_out_channels": (128, 256, 512, 512),
            "decoder_block_out_channels": (256, 512, 1024),
            "layers_per_block": (4, 3, 3, 3, 4),
            "decoder_layers_per_block": (5, 6, 7, 8),
            "spatio_temporal_scaling": (True, True, True, False),
            "decoder_spatio_temporal_scaling": (True, True, True),
            "decoder_inject_noise": (True, True, True, False),
            "upsample_residual": (True, True, True),
            "upsample_factor": (2, 2, 2),
            "timestep_conditioning": True,
            "patch_size": 4,
            "patch_size_t": 1,
            "resnet_norm_eps": 1e-6,
            "scaling_factor": 1.0,
            "encoder_causal": True,
            "decoder_causal": False,
        }
        VAE_KEYS_RENAME_DICT.update(VAE_091_RENAME_DICT)
        VAE_SPECIAL_KEYS_REMAP.update(VAE_091_SPECIAL_KEYS_REMAP)
    return config


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
@@ -139,6 +222,9 @@ def get_args():
    parser.add_argument("--save_pipeline", action="store_true")
    parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved")
    parser.add_argument("--dtype", default="fp32", help="Torch dtype to save the model in.")
    parser.add_argument(
        "--version", type=str, default="0.9.0", choices=["0.9.0", "0.9.1"], help="Version of the LTX model"
    )
    return parser.parse_args()


@@ -161,6 +247,7 @@ def get_args():
    transformer = None
    dtype = DTYPE_MAPPING[args.dtype]
    variant = VARIANT_MAPPING[args.dtype]
    output_path = Path(args.output_path)

    if args.save_pipeline:
        assert args.transformer_ckpt_path is not None and args.vae_ckpt_path is not None

@@ -169,13 +256,14 @@ def get_args():
        transformer: LTXVideoTransformer3DModel = convert_transformer(args.transformer_ckpt_path, dtype)
        if not args.save_pipeline:
            transformer.save_pretrained(
                args.output_path, safe_serialization=True, max_shard_size="5GB", variant=variant
                output_path / "transformer", safe_serialization=True, max_shard_size="5GB", variant=variant
            )

    if args.vae_ckpt_path is not None:
        vae: AutoencoderKLLTXVideo = convert_vae(args.vae_ckpt_path, dtype)
        config = get_vae_config(args.version)
        vae: AutoencoderKLLTXVideo = convert_vae(args.vae_ckpt_path, config, dtype)
        if not args.save_pipeline:
            vae.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB", variant=variant)
            vae.save_pretrained(output_path / "vae", safe_serialization=True, max_shard_size="5GB", variant=variant)

    if args.save_pipeline:
        text_encoder_id = "google/t5-v1_1-xxl"
28 changes: 26 additions & 2 deletions src/diffusers/loaders/single_file_utils.py
@@ -157,7 +157,8 @@
"flux-fill": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-Fill-dev"},
"flux-depth": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-Depth-dev"},
"flux-schnell": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-schnell"},
"ltx-video": {"pretrained_model_name_or_path": "Lightricks/LTX-Video"},
"ltx-video": {"pretrained_model_name_or_path": "diffusers/LTX-Video-0.9.0"},
"ltx-video-0.9.1": {"pretrained_model_name_or_path": "diffusers/LTX-Video-0.9.1"},
"autoencoder-dc-f128c512": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f128c512-mix-1.0-diffusers"},
"autoencoder-dc-f64c128": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f64c128-mix-1.0-diffusers"},
"autoencoder-dc-f32c32": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f32c32-mix-1.0-diffusers"},
@@ -605,7 +606,10 @@ def infer_diffusers_model_type(checkpoint):
model_type = "flux-schnell"

elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["ltx-video"]):
model_type = "ltx-video"
if "vae.decoder.last_time_embedder.timestep_embedder.linear_1.weight" in checkpoint:
model_type = "ltx-video-0.9.1"
else:
model_type = "ltx-video"

elif CHECKPOINT_KEY_NAMES["autoencoder-dc"] in checkpoint:
encoder_key = "encoder.project_in.conv.conv.bias"
@@ -2338,12 +2342,32 @@ def remove_keys_(key: str, state_dict):
"per_channel_statistics.std-of-means": "latents_std",
}

VAE_091_RENAME_DICT = {
# decoder
"up_blocks.0": "mid_block",
"up_blocks.1": "up_blocks.0.upsamplers.0",
"up_blocks.2": "up_blocks.0",
"up_blocks.3": "up_blocks.1.upsamplers.0",
"up_blocks.4": "up_blocks.1",
"up_blocks.5": "up_blocks.2.upsamplers.0",
"up_blocks.6": "up_blocks.2",
"up_blocks.7": "up_blocks.3.upsamplers.0",
"up_blocks.8": "up_blocks.3",
# common
"last_time_embedder": "time_embedder",
"last_scale_shift_table": "scale_shift_table",
}

VAE_SPECIAL_KEYS_REMAP = {
"per_channel_statistics.channel": remove_keys_,
"per_channel_statistics.mean-of-means": remove_keys_,
"per_channel_statistics.mean-of-stds": remove_keys_,
"timestep_scale_multiplier": remove_keys_,
}

if "vae.decoder.last_time_embedder.timestep_embedder.linear_1.weight" in converted_state_dict:
VAE_KEYS_RENAME_DICT.update(VAE_091_RENAME_DICT)

for key in list(converted_state_dict.keys()):
new_key = key
for replace_key, rename_key in VAE_KEYS_RENAME_DICT.items():
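With this detection in place, `from_single_file` resolves the matching diffusers config repo for either checkpoint version automatically. A minimal sketch using the 0.9.1 single-file URL from the docs above:

```python
import torch
from diffusers import LTXVideoTransformer3DModel

# The 0.9.1-only VAE decoder key in this checkpoint selects the
# "ltx-video-0.9.1" branch above (config repo diffusers/LTX-Video-0.9.1).
url = "https://huggingface.co/Lightricks/LTX-Video/blob/main/ltx-video-2b-v0.9.1.safetensors"
transformer = LTXVideoTransformer3DModel.from_single_file(url, torch_dtype=torch.bfloat16)
```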