23 changes: 17 additions & 6 deletions scripts/convert_ltx2_to_diffusers.py
@@ -63,6 +63,8 @@
"up_blocks.4": "up_blocks.1",
"up_blocks.5": "up_blocks.2.upsamplers.0",
"up_blocks.6": "up_blocks.2",
"last_time_embedder": "time_embedder",
"last_scale_shift_table": "scale_shift_table",
# Common
# For all 3D ResNets
"res_blocks": "resnets",
@@ -372,7 +374,9 @@ def convert_ltx2_connectors(original_state_dict: Dict[str, Any], version: str) -
return connectors


def get_ltx2_video_vae_config(version: str) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
def get_ltx2_video_vae_config(
version: str, timestep_conditioning: bool = False
) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
if version == "test":
config = {
"model_id": "diffusers-internal-dev/dummy-ltx2",
@@ -396,7 +400,7 @@ def get_ltx2_video_vae_config(version: str) -> Tuple[Dict[str, Any], Dict[str, A
"downsample_type": ("spatial", "temporal", "spatiotemporal", "spatiotemporal"),
"upsample_residual": (True, True, True),
"upsample_factor": (2, 2, 2),
"timestep_conditioning": False,
"timestep_conditioning": timestep_conditioning,
"patch_size": 4,
"patch_size_t": 1,
"resnet_norm_eps": 1e-6,
@@ -433,7 +437,7 @@ def get_ltx2_video_vae_config(version: str) -> Tuple[Dict[str, Any], Dict[str, A
"downsample_type": ("spatial", "temporal", "spatiotemporal", "spatiotemporal"),
"upsample_residual": (True, True, True),
"upsample_factor": (2, 2, 2),
"timestep_conditioning": False,
"timestep_conditioning": timestep_conditioning,

Member:
Where is timestep_conditioning used in the corresponding Video VAE class?

Contributor Author (@rootonchair), Jan 12, 2026:
timestep_conditioning is used here. It depends on the metadata stored in the original checkpoint, and the distilled checkpoint is configured to use timestep_conditioning. You can confirm this with the script below:

from ltx_core.loader.sft_loader import SafetensorsModelStateDictLoader
model_loader = SafetensorsModelStateDictLoader()
model_loader.metadata("weights/ltx-2-19b-distilled.safetensors")["vae"]["timestep_conditioning"] # True

Hence, I got some unexpected keys if I don't set timestep_conditioning to True when converting the distilled weights.

Collaborator:
Hi @rootonchair, there was a recent update on the official Lightricks/LTX-2 repo which updated the VAE for the distilled checkpoint: https://huggingface.co/Lightricks/LTX-2/commit/1931de987c8e265eb64a9123227b903754e3cc68. So I think rootonchair/LTX-2-19b-distilled needs to be updated with the new converted VAE as well.

It looks like timestep_conditioning should be set to False for the new video VAE, according to the config metadata on ltx-2-19b-distilled.safetensors. So we should probably remove the unexpected keys (last_time_embedder, last_scale_shift_table) rather than processing them in the VAE conversion script.

import json
import safetensors
from huggingface_hub import hf_hub_download

ckpt_path = hf_hub_download("Lightricks/LTX-2", "ltx-2-19b-distilled.safetensors")
with safetensors.safe_open(ckpt_path, framework="pt") as f:
    config = json.loads(f.metadata()["config"])
config["vae"]["timestep_conditioning"]  # Should be False

Contributor Author (@rootonchair):
I think I will convert the new weights. For the conversion script, I think we should keep it there until the original repo decides to delete it.

"patch_size": 4,
"patch_size_t": 1,
"resnet_norm_eps": 1e-6,
@@ -450,8 +454,10 @@ def get_ltx2_video_vae_config(version: str) -> Tuple[Dict[str, Any], Dict[str, A
return config, rename_dict, special_keys_remap


def convert_ltx2_video_vae(original_state_dict: Dict[str, Any], version: str) -> Dict[str, Any]:
config, rename_dict, special_keys_remap = get_ltx2_video_vae_config(version)
def convert_ltx2_video_vae(
original_state_dict: Dict[str, Any], version: str, timestep_conditioning: bool
) -> Dict[str, Any]:
config, rename_dict, special_keys_remap = get_ltx2_video_vae_config(version, timestep_conditioning)
diffusers_config = config["diffusers_config"]

with init_empty_weights():
@@ -717,6 +723,9 @@ def get_args():
help="Latent upsampler filename",
)

parser.add_argument(
"--timestep_conditioning", action="store_true", help="Whether to add timestep condition to the video VAE model"
)
parser.add_argument("--vae", action="store_true", help="Whether to convert the video VAE model")
parser.add_argument("--audio_vae", action="store_true", help="Whether to convert the audio VAE model")
parser.add_argument("--dit", action="store_true", help="Whether to convert the DiT model")
@@ -786,7 +795,9 @@ def main(args):
original_vae_ckpt = load_hub_or_local_checkpoint(filename=args.vae_filename)
elif combined_ckpt is not None:
original_vae_ckpt = get_model_state_dict_from_combined_ckpt(combined_ckpt, args.vae_prefix)
vae = convert_ltx2_video_vae(original_vae_ckpt, version=args.version)
vae = convert_ltx2_video_vae(
original_vae_ckpt, version=args.version, timestep_conditioning=args.timestep_conditioning
)
if not args.full_pipeline and not args.upsample_pipeline:
vae.to(vae_dtype).save_pretrained(os.path.join(args.output_path, "vae"))

62 changes: 42 additions & 20 deletions src/diffusers/pipelines/ltx2/pipeline_ltx2.py
@@ -653,6 +653,11 @@ def prepare_latents(
latents: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if latents is not None:
if latents.ndim == 5:

Member:
Is this needed to support precomputed latents?

Contributor Author (@rootonchair), Jan 12, 2026:
I think it adds another layer of support for precomputed latents. The latents returned are of shape [B, C, F, H, W]; however, the model takes input of shape [B, seq_len, C], so passing the returned latents directly would cause a shape mismatch. The same applies to the audio VAE.
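
For reference, a rough standalone sketch of the shape transformations being discussed; the actual _pack_latents / _pack_audio_latents implementations may differ, and patch sizes of 1 are assumed purely for illustration:

import torch

def pack_video_latents(latents, patch_size=1, patch_size_t=1):
    # [B, C, F, H, W] -> [B, (F/p_t)*(H/p)*(W/p), C*p_t*p*p]; with patch sizes of 1 this is
    # simply [B, F*H*W, C], i.e. the [B, seq_len, C] layout the transformer expects.
    b, c, f, h, w = latents.shape
    latents = latents.reshape(b, c, f // patch_size_t, patch_size_t, h // patch_size, patch_size, w // patch_size, patch_size)
    return latents.permute(0, 2, 4, 6, 1, 3, 5, 7).flatten(4, 7).flatten(1, 3)

def pack_audio_latents(latents):
    # [B, C, L, M] -> [B, L, C*M]: fold channels and mel bins into one feature axis.
    return latents.permute(0, 2, 1, 3).flatten(2, 3)

video = torch.randn(1, 128, 8, 16, 24)  # [B, C, F, H, W]
audio = torch.randn(1, 8, 10, 16)       # [B, C, L, M]
print(pack_video_latents(video).shape)  # torch.Size([1, 3072, 128])
print(pack_audio_latents(audio).shape)  # torch.Size([1, 10, 128])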

# latents are of shape [B, C, F, H, W], need to be packed
latents = self._pack_latents(
latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size
)
return latents.to(device=device, dtype=dtype)

height = height // self.vae_spatial_compression_ratio
@@ -677,29 +682,23 @@ def prepare_audio_latents(
self,
batch_size: int = 1,
num_channels_latents: int = 8,
audio_latent_length: int = 1, # 1 is just a dummy value
num_mel_bins: int = 64,
num_frames: int = 121,
frame_rate: float = 25.0,
sampling_rate: int = 16000,
hop_length: int = 160,
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
generator: Optional[torch.Generator] = None,
latents: Optional[torch.Tensor] = None,
) -> torch.Tensor:
duration_s = num_frames / frame_rate
latents_per_second = (
float(sampling_rate) / float(hop_length) / float(self.audio_vae_temporal_compression_ratio)
)
latent_length = round(duration_s * latents_per_second)

if latents is not None:
return latents.to(device=device, dtype=dtype), latent_length
if latents.ndim == 4:

Member:
Same question as above.

# latents are of shape [B, C, L, M], need to be packed
latents = self._pack_audio_latents(latents)

Collaborator (@dg845), Jan 14, 2026:
When the output_type is set to "latent", LTX2Pipeline will return the denormalized audio latents:

audio_latents = self._denormalize_audio_latents(
audio_latents, self.audio_vae.latents_mean, self.audio_vae.latents_std
)
audio_latents = self._unpack_audio_latents(audio_latents, audio_num_frames, num_mel_bins=latent_mel_bins)

Since the DiT expects normalized latents, I think we need to normalize the audio latents here:

                latents = self._pack_audio_latents(latents)
                latents = self._normalize_audio_latents(latents, self.audio_vae.latents_mean, self.audio_vae.latents_std)

where _normalize_audio_latents is something like

    @staticmethod
    def _normalize_audio_latents(
        latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor
    ) -> torch.Tensor:
        # Normalize latents across the combined channel and mel bin dimension [B, L, C * M]
        latents_mean = latents_mean.to(latents.device, latents.dtype)
        latents_std = latents_std.to(latents.device, latents.dtype)
        latents = (latents - latents_mean) / latents_std
        return latents

This should make the Stage 2 audio latents have the expected distribution.
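
To make the round trip concrete, a tiny self-contained sketch (the statistics and shapes are placeholders, not values from the pipeline):

import torch

latents_mean = torch.randn(128) * 0.1   # placeholder per-feature statistics
latents_std = torch.rand(128) + 0.5
packed = torch.randn(2, 10, 128)        # fake packed audio latents of shape [B, L, C * M]

denormalized = packed * latents_std + latents_mean            # what output_type="latent" returns
renormalized = (denormalized - latents_mean) / latents_std    # what the DiT expects as input
assert torch.allclose(renormalized, packed, atol=1e-5)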

Collaborator (@dg845), Jan 14, 2026:
Note that although LTX2Pipeline also returns the denormalized video latents when output_type="latent", LTX2LatentUpsamplePipeline returns the normalized latents, so the Stage 2 video latents are not affected by this issue.

Collaborator:
(However, prepare_latents-type methods usually expect supplied latents to be normalized, since we generally return them as-is without normalizing them, so we might need to think through the design on where we expect latents to be normalized or denormalized. CC @sayakpaul.)

Member:
Hmm, this is a bit of a spiraling situation. I wonder if exposing a flag like normalize_latents (or leveraging the latents_normalized flag used in the upsampling pipeline) could make sense here, kept at a reasonable default, and we educate users about when to use which value for normalize_latents. I think this also comes with the uniqueness of the LTX2 pipeline a bit.

The situation is likely getting confusing because the video latents returned from the upsampling pipeline are always normalized. If the user wants them unnormalized (with a flag normalize_latents=False), then I guess both upsampled_video_latent and audio_latent could be passed as-is to the Stage 2 pipeline?

> (However, prepare_latents-type methods usually expect supplied latents to be normalized, since we generally return them as-is without normalizing them, so we might need to think through the design on where we expect latents to be normalized or denormalized.

I think this is both true and false. Inside Flux Kontext, the image latents are normalized inside prepare_latents():

image_latents = self._encode_vae_image(image=image, generator=generator)

This is what we also follow for LTX2 I2V:

init_latents = self._normalize_latents(init_latents, self.vae.latents_mean, self.vae.latents_std)

Contributor Author (@rootonchair):
One perspective in favor of keeping the returned latents denormalized by default is that there could be a pipeline where each module has a different method of normalizing. Passing only denormalized latents, with each module normalizing the latents itself, helps maintain the plug-and-play ability. Moreover, the computation required for normalizing is acceptable.

return latents.to(device=device, dtype=dtype)

# TODO: confirm whether this logic is correct
latent_mel_bins = num_mel_bins // self.audio_vae_mel_compression_ratio

shape = (batch_size, num_channels_latents, latent_length, latent_mel_bins)
shape = (batch_size, num_channels_latents, audio_latent_length, latent_mel_bins)

if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
Expand All @@ -709,7 +708,7 @@ def prepare_audio_latents(

latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
latents = self._pack_audio_latents(latents)
return latents, latent_length
return latents

@property
def guidance_scale(self):
@@ -750,6 +749,7 @@ def __call__(
num_frames: int = 121,
frame_rate: float = 24.0,
num_inference_steps: int = 40,
sigmas: Optional[List[float]] = None,
timesteps: List[int] = None,
guidance_scale: float = 4.0,
guidance_rescale: float = 0.0,
@@ -788,6 +788,10 @@
num_inference_steps (`int`, *optional*, defaults to 40):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
sigmas (`List[float]`, *optional*):
Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
will be used.
timesteps (`List[int]`, *optional*):
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
@@ -922,6 +926,14 @@ def __call__(
latent_num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1
latent_height = height // self.vae_spatial_compression_ratio
latent_width = width // self.vae_spatial_compression_ratio
if latents is not None:
if latents.ndim == 5:
_, _, latent_num_frames, latent_height, latent_width = latents.shape # [B, C, F, H, W]
else:
logger.warning(
f"You have supplied packed `latents` of shape {latents.shape}, so the latent dims cannot be"
f" inferred. Make sure the supplied `height`, `width`, and `num_frames` are correct."
)
video_sequence_length = latent_num_frames * latent_height * latent_width

num_channels_latents = self.transformer.config.in_channels
@@ -937,28 +949,38 @@
latents,
)

duration_s = num_frames / frame_rate
audio_latents_per_second = (
self.audio_sampling_rate / self.audio_hop_length / float(self.audio_vae_temporal_compression_ratio)
)
audio_num_frames = round(duration_s * audio_latents_per_second)
if audio_latents is not None:
if audio_latents.ndim == 4:
_, _, audio_num_frames, _ = audio_latents.shape # [B, C, L, M]
else:
logger.warning(
f"You have supplied packed `audio_latents` of shape {audio_latents.shape}, so the latent dims"
f" cannot be inferred. Make sure the supplied `num_frames` is correct."
)

num_mel_bins = self.audio_vae.config.mel_bins if getattr(self, "audio_vae", None) is not None else 64
latent_mel_bins = num_mel_bins // self.audio_vae_mel_compression_ratio

num_channels_latents_audio = (
self.audio_vae.config.latent_channels if getattr(self, "audio_vae", None) is not None else 8
)
audio_latents, audio_num_frames = self.prepare_audio_latents(
audio_latents = self.prepare_audio_latents(
batch_size * num_videos_per_prompt,
num_channels_latents=num_channels_latents_audio,
audio_latent_length=audio_num_frames,
num_mel_bins=num_mel_bins,
num_frames=num_frames, # Video frames, audio frames will be calculated from this
frame_rate=frame_rate,
sampling_rate=self.audio_sampling_rate,
hop_length=self.audio_hop_length,
dtype=torch.float32,
device=device,
generator=generator,
latents=audio_latents,
)

# 5. Prepare timesteps
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas

Member:
Let's apply similar changes to the I2V module as well?

Contributor Author (@rootonchair):
Sure, will apply the same change.

mu = calculate_shift(
video_sequence_length,
self.scheduler.config.get("base_image_seq_len", 1024),
72 changes: 47 additions & 25 deletions src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py
@@ -689,6 +689,11 @@ def prepare_latents(
conditioning_mask = self._pack_latents(
conditioning_mask, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size
).squeeze(-1)
if latents.ndim == 5:
# latents are of shape [B, C, F, H, W], need to be packed
latents = self._pack_latents(
latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size
)
if latents.ndim != 3 or latents.shape[:2] != conditioning_mask.shape:
raise ValueError(
f"Provided `latents` tensor has shape {latents.shape}, but the expected shape is {conditioning_mask.shape + (num_channels_latents,)}."
@@ -737,29 +742,23 @@ def prepare_audio_latents(
self,
batch_size: int = 1,
num_channels_latents: int = 8,
audio_latent_length: int = 1, # 1 is just a dummy value
num_mel_bins: int = 64,
num_frames: int = 121,
frame_rate: float = 25.0,
sampling_rate: int = 16000,
hop_length: int = 160,
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
generator: Optional[torch.Generator] = None,
latents: Optional[torch.Tensor] = None,
) -> torch.Tensor:
duration_s = num_frames / frame_rate
latents_per_second = (
float(sampling_rate) / float(hop_length) / float(self.audio_vae_temporal_compression_ratio)
)
latent_length = round(duration_s * latents_per_second)

if latents is not None:
return latents.to(device=device, dtype=dtype), latent_length
if latents.ndim == 4:
# latents are of shape [B, C, L, M], need to be packed
latents = self._pack_audio_latents(latents)
return latents.to(device=device, dtype=dtype)

# TODO: confirm whether this logic is correct
latent_mel_bins = num_mel_bins // self.audio_vae_mel_compression_ratio

shape = (batch_size, num_channels_latents, latent_length, latent_mel_bins)
shape = (batch_size, num_channels_latents, audio_latent_length, latent_mel_bins)

if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
Expand All @@ -769,7 +768,7 @@ def prepare_audio_latents(

latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
latents = self._pack_audio_latents(latents)
return latents, latent_length
return latents

@property
def guidance_scale(self):
@@ -811,6 +810,7 @@ def __call__(
num_frames: int = 121,
frame_rate: float = 24.0,
num_inference_steps: int = 40,
sigmas: Optional[List[float]] = None,
timesteps: List[int] = None,
guidance_scale: float = 4.0,
guidance_rescale: float = 0.0,
@@ -851,6 +851,10 @@
num_inference_steps (`int`, *optional*, defaults to 40):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
sigmas (`List[float]`, *optional*):
Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
will be used.
timesteps (`List[int]`, *optional*):
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
@@ -982,6 +986,19 @@ def __call__(
)

# 4. Prepare latent variables
latent_num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1
latent_height = height // self.vae_spatial_compression_ratio
latent_width = width // self.vae_spatial_compression_ratio
if latents is not None:
if latents.ndim == 5:
_, _, latent_num_frames, latent_height, latent_width = latents.shape # [B, C, F, H, W]
else:
logger.warning(
f"You have supplied packed `latents` of shape {latents.shape}, so the latent dims cannot be"
f" inferred. Make sure the supplied `height`, `width`, and `num_frames` are correct."
)
video_sequence_length = latent_num_frames * latent_height * latent_width

if latents is None:
image = self.video_processor.preprocess(image, height=height, width=width)
image = image.to(device=device, dtype=prompt_embeds.dtype)
@@ -1002,33 +1019,38 @@
if self.do_classifier_free_guidance:
conditioning_mask = torch.cat([conditioning_mask, conditioning_mask])

duration_s = num_frames / frame_rate
audio_latents_per_second = (
self.audio_sampling_rate / self.audio_hop_length / float(self.audio_vae_temporal_compression_ratio)
)
audio_num_frames = round(duration_s * audio_latents_per_second)
if audio_latents is not None:
if audio_latents.ndim == 4:
_, _, audio_num_frames, _ = audio_latents.shape # [B, C, L, M]
else:
logger.warning(
f"You have supplied packed `audio_latents` of shape {audio_latents.shape}, so the latent dims"
f" cannot be inferred. Make sure the supplied `num_frames` is correct."
)

num_mel_bins = self.audio_vae.config.mel_bins if getattr(self, "audio_vae", None) is not None else 64
latent_mel_bins = num_mel_bins // self.audio_vae_mel_compression_ratio

num_channels_latents_audio = (
self.audio_vae.config.latent_channels if getattr(self, "audio_vae", None) is not None else 8
)
audio_latents, audio_num_frames = self.prepare_audio_latents(
audio_latents = self.prepare_audio_latents(
batch_size * num_videos_per_prompt,
num_channels_latents=num_channels_latents_audio,
audio_latent_length=audio_num_frames,
num_mel_bins=num_mel_bins,
num_frames=num_frames, # Video frames, audio frames will be calculated from this
frame_rate=frame_rate,
sampling_rate=self.audio_sampling_rate,
hop_length=self.audio_hop_length,
dtype=torch.float32,
device=device,
generator=generator,
latents=audio_latents,
)

# 5. Prepare timesteps
latent_num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1
latent_height = height // self.vae_spatial_compression_ratio
latent_width = width // self.vae_spatial_compression_ratio
video_sequence_length = latent_num_frames * latent_height * latent_width

sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
mu = calculate_shift(
video_sequence_length,
self.scheduler.config.get("base_image_seq_len", 1024),
6 changes: 6 additions & 0 deletions src/diffusers/pipelines/ltx2/utils.py
@@ -0,0 +1,6 @@
# Pre-trained sigma values for distilled model are taken from
# https://github.com/Lightricks/LTX-2/blob/main/packages/ltx-pipelines/src/ltx_pipelines/utils/constants.py
DISTILLED_SIGMA_VALUES = [1.0, 0.99375, 0.9875, 0.98125, 0.975, 0.909375, 0.725, 0.421875]

Member:
Let's also mention the source for these values.

Contributor Author (@rootonchair):
Sure!


# Reduced schedule for super-resolution stage 2 (subset of distilled values)
STAGE_2_DISTILLED_SIGMA_VALUES = [0.909375, 0.725, 0.421875]
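
For context, a hedged usage sketch of how these constants could flow through the new sigmas argument added in this PR; the repository id, dtype, prompt, and argument values are assumptions for illustration, not values taken from this PR:

import torch
from diffusers import LTX2Pipeline
from diffusers.pipelines.ltx2.utils import DISTILLED_SIGMA_VALUES

# Hypothetical repo id (mentioned earlier in this thread as the converted distilled checkpoint).
pipe = LTX2Pipeline.from_pretrained("rootonchair/LTX-2-19b-distilled", torch_dtype=torch.bfloat16)
pipe.to("cuda")

output = pipe(
    prompt="A red fox trotting through fresh snow at dawn",
    num_frames=121,
    frame_rate=24.0,
    sigmas=DISTILLED_SIGMA_VALUES,                    # overrides the default np.linspace schedule
    num_inference_steps=len(DISTILLED_SIGMA_VALUES),  # one denoising step per sigma
)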