LTX2 distilled checkpoint support #12934
`src/diffusers/pipelines/ltx2/pipeline_ltx2.py`:

```diff
@@ -653,6 +653,11 @@ def prepare_latents(
         latents: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         if latents is not None:
+            if latents.ndim == 5:
```
**Member:** Is this needed to support precomputed latents?

**Contributor (Author):** I think it will add another layer of support for precomputed latents. The latents returned are in shape …
```diff
+                # latents are of shape [B, C, F, H, W], need to be packed
+                latents = self._pack_latents(
+                    latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size
+                )
             return latents.to(device=device, dtype=dtype)

         height = height // self.vae_spatial_compression_ratio
```
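For reference, a minimal standalone sketch of what this packing step does, assuming a patchify-style `_pack_latents` like the one used in the existing LTX pipelines (patch sizes and shapes here are illustrative, not read from a real config):

```python
import torch

def pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int = 1) -> torch.Tensor:
    # [B, C, F, H, W] -> [B, (F/pt) * (H/p) * (W/p), C * pt * p * p]
    b, c, f, h, w = latents.shape
    latents = latents.reshape(
        b, c, f // patch_size_t, patch_size_t, h // patch_size, patch_size, w // patch_size, patch_size
    )
    # Move the patch grid in front of the channel/patch dims, then flatten into a token sequence.
    latents = latents.permute(0, 2, 4, 6, 1, 3, 5, 7).flatten(4, 7).flatten(1, 3)
    return latents

packed = pack_latents(torch.randn(1, 128, 8, 16, 16))
print(packed.shape)  # torch.Size([1, 2048, 128])
```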
```diff
@@ -677,29 +682,23 @@ def prepare_audio_latents(
         self,
         batch_size: int = 1,
         num_channels_latents: int = 8,
+        audio_latent_length: int = 1,  # 1 is just a dummy value
         num_mel_bins: int = 64,
-        num_frames: int = 121,
-        frame_rate: float = 25.0,
-        sampling_rate: int = 16000,
-        hop_length: int = 160,
         dtype: Optional[torch.dtype] = None,
         device: Optional[torch.device] = None,
         generator: Optional[torch.Generator] = None,
         latents: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-        duration_s = num_frames / frame_rate
-        latents_per_second = (
-            float(sampling_rate) / float(hop_length) / float(self.audio_vae_temporal_compression_ratio)
-        )
-        latent_length = round(duration_s * latents_per_second)
-
         if latents is not None:
-            return latents.to(device=device, dtype=dtype), latent_length
+            if latents.ndim == 4:
```
**Member:** Same question as above.
```diff
+                # latents are of shape [B, C, L, M], need to be packed
+                latents = self._pack_audio_latents(latents)
```
**Collaborator:** (Referencing `diffusers/src/diffusers/pipelines/ltx2/pipeline_ltx2.py`, lines 1112 to 1115 at 18f1603.) Since the DiT expects normalized latents, I think we need to normalize the audio latents here:

```python
latents = self._pack_audio_latents(latents)
latents = self._normalize_audio_latents(latents, self.audio_vae.latents_mean, self.audio_vae.latents_std)
```

where

```python
@staticmethod
def _normalize_audio_latents(
    latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor
) -> torch.Tensor:
    # Normalize latents across the combined channel and mel bin dimension [B, L, C * M]
    latents_mean = latents_mean.to(latents.device, latents.dtype)
    latents_std = latents_std.to(latents.device, latents.dtype)
    latents = (latents - latents_mean) / latents_std
    return latents
```

This should make the Stage 2 audio latents have the expected distribution.
**Collaborator:** Note that although …

**Collaborator:** (However, …)
**Member:** Hmm, this is a bit of a spiraling situation, and I wonder if exposing a flag like `…` would help. The situation is getting confusing likely because the video latents returned from the upsampling pipeline are always normalized. If the user wants them unnormalized (with a flag …) …

I think this is both true and false. Inside Flux Kontext, the image latents are normalized inside …

This is what we also follow for LTX2 I2V: …

**Contributor (Author):** One perspective in favor of keeping the returned latents denormalized by default: there could be a pipeline where each module has a different method of normalizing. Passing only denormalized latents, with each module normalizing them itself, helps the modules stay plug-and-play. Moreover, the computation required for normalizing is acceptable.
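To make the shapes in this thread concrete, here is a minimal sketch of packing `[B, C, L, M]` audio latents into the `[B, L, C * M]` layout mentioned above and normalizing them. The dimension values and the exact permute order are illustrative assumptions, and the zero/one statistics stand in for `audio_vae.latents_mean` / `latents_std`:

```python
import torch

B, C, L, M = 1, 8, 126, 16  # illustrative sizes, not the real VAE config
latents = torch.randn(B, C, L, M)

# Pack: move the latent-length axis forward and merge channels with mel bins.
packed = latents.permute(0, 2, 1, 3).reshape(B, L, C * M)

# Normalize with per-feature statistics broadcast over the last dimension.
latents_mean = torch.zeros(C * M)
latents_std = torch.ones(C * M)
normalized = (packed - latents_mean) / latents_std
print(normalized.shape)  # torch.Size([1, 126, 128])
```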
```diff
+            return latents.to(device=device, dtype=dtype)

         # TODO: confirm whether this logic is correct
         latent_mel_bins = num_mel_bins // self.audio_vae_mel_compression_ratio

-        shape = (batch_size, num_channels_latents, latent_length, latent_mel_bins)
+        shape = (batch_size, num_channels_latents, audio_latent_length, latent_mel_bins)

         if isinstance(generator, list) and len(generator) != batch_size:
             raise ValueError(
```
```diff
@@ -709,7 +708,7 @@ def prepare_audio_latents(

         latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
         latents = self._pack_audio_latents(latents)
-        return latents, latent_length
+        return latents

     @property
     def guidance_scale(self):
```
```diff
@@ -750,6 +749,7 @@ def __call__(
         num_frames: int = 121,
         frame_rate: float = 24.0,
         num_inference_steps: int = 40,
+        sigmas: Optional[List[float]] = None,
         timesteps: List[int] = None,
         guidance_scale: float = 4.0,
         guidance_rescale: float = 0.0,
```
```diff
@@ -788,6 +788,10 @@ def __call__(
             num_inference_steps (`int`, *optional*, defaults to 40):
                 The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                 expense of slower inference.
+            sigmas (`List[float]`, *optional*):
+                Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
+                their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
+                passed will be used.
             timesteps (`List[int]`, *optional*):
                 Custom timesteps to use for the denoising process with schedulers which support a `timesteps`
                 argument in their `set_timesteps` method. If not defined, the default behavior when
                 `num_inference_steps` is passed will be used.
```
```diff
@@ -922,6 +926,14 @@ def __call__(
         latent_num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1
         latent_height = height // self.vae_spatial_compression_ratio
         latent_width = width // self.vae_spatial_compression_ratio
+        if latents is not None:
+            if latents.ndim == 5:
+                _, _, latent_num_frames, latent_height, latent_width = latents.shape  # [B, C, F, H, W]
+            else:
+                logger.warning(
+                    f"You have supplied packed `latents` of shape {latents.shape}, so the latent dims cannot be"
+                    f" inferred. Make sure the supplied `height`, `width`, and `num_frames` are correct."
+                )
         video_sequence_length = latent_num_frames * latent_height * latent_width

         num_channels_latents = self.transformer.config.in_channels
```
```diff
@@ -937,28 +949,38 @@ def __call__(
             latents,
         )

+        duration_s = num_frames / frame_rate
+        audio_latents_per_second = (
+            self.audio_sampling_rate / self.audio_hop_length / float(self.audio_vae_temporal_compression_ratio)
+        )
+        audio_num_frames = round(duration_s * audio_latents_per_second)
+        if audio_latents is not None:
+            if audio_latents.ndim == 4:
+                _, _, audio_num_frames, _ = audio_latents.shape  # [B, C, L, M]
+            else:
+                logger.warning(
+                    f"You have supplied packed `audio_latents` of shape {audio_latents.shape}, so the latent dims"
+                    f" cannot be inferred. Make sure the supplied `num_frames` is correct."
+                )
+
         num_mel_bins = self.audio_vae.config.mel_bins if getattr(self, "audio_vae", None) is not None else 64
         latent_mel_bins = num_mel_bins // self.audio_vae_mel_compression_ratio

         num_channels_latents_audio = (
             self.audio_vae.config.latent_channels if getattr(self, "audio_vae", None) is not None else 8
         )
-        audio_latents, audio_num_frames = self.prepare_audio_latents(
+        audio_latents = self.prepare_audio_latents(
             batch_size * num_videos_per_prompt,
             num_channels_latents=num_channels_latents_audio,
+            audio_latent_length=audio_num_frames,
             num_mel_bins=num_mel_bins,
-            num_frames=num_frames,  # Video frames, audio frames will be calculated from this
-            frame_rate=frame_rate,
-            sampling_rate=self.audio_sampling_rate,
-            hop_length=self.audio_hop_length,
             dtype=torch.float32,
             device=device,
             generator=generator,
             latents=audio_latents,
         )

         # 5. Prepare timesteps
-        sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
+        sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
```
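As a sanity check on the audio-latent length computed above, a worked example with the defaults shown in this diff (the temporal compression ratio is an assumed illustrative value; the real one comes from the loaded audio VAE):

```python
num_frames, frame_rate = 121, 24.0
sampling_rate, hop_length = 16000, 160
audio_vae_temporal_compression_ratio = 4  # assumed for illustration

duration_s = num_frames / frame_rate  # ~5.042 s of video
audio_latents_per_second = sampling_rate / hop_length / float(audio_vae_temporal_compression_ratio)  # 25.0
audio_num_frames = round(duration_s * audio_latents_per_second)
print(audio_num_frames)  # 126
```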
**Member:** Let's apply similar changes to the I2V module as well?

**Contributor (Author):** Sure, will apply the same change.
```diff
         mu = calculate_shift(
             video_sequence_length,
             self.scheduler.config.get("base_image_seq_len", 1024),
```
New file:

```diff
@@ -0,0 +1,6 @@
+# Pre-trained sigma values for distilled model are taken from
+# https://github.com/Lightricks/LTX-2/blob/main/packages/ltx-pipelines/src/ltx_pipelines/utils/constants.py
+DISTILLED_SIGMA_VALUES = [1.0, 0.99375, 0.9875, 0.98125, 0.975, 0.909375, 0.725, 0.421875]
```
**Member:** Let's also mention the source for these values.

**Contributor (Author):** Sure!
```diff
+
+# Reduced schedule for super-resolution stage 2 (subset of distilled values)
+STAGE_2_DISTILLED_SIGMA_VALUES = [0.909375, 0.725, 0.421875]
```
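A hypothetical end-to-end invocation using these constants together with the new `sigmas` argument. The pipeline class name, checkpoint id, and `.frames` output attribute are assumptions drawn from this PR discussion, not a confirmed released API:

```python
import torch
from diffusers import LTX2Pipeline  # class name assumed from this PR

# From the new constants file above
DISTILLED_SIGMA_VALUES = [1.0, 0.99375, 0.9875, 0.98125, 0.975, 0.909375, 0.725, 0.421875]

pipe = LTX2Pipeline.from_pretrained("rootonchair/LTX-2-19b-distilled", torch_dtype=torch.bfloat16)
pipe.to("cuda")

# Supplying `sigmas` overrides the pipeline's default np.linspace schedule.
video = pipe(
    prompt="A red vintage car driving along a coastal road",
    sigmas=DISTILLED_SIGMA_VALUES,
    num_inference_steps=len(DISTILLED_SIGMA_VALUES),
).frames[0]
```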
**Member:** Where is `timestep_conditioning` used in the corresponding Video VAE class?
**Contributor (Author):** `timestep_conditioning` is used here. It depends on the metadata stored in the original checkpoint, and the distilled checkpoint is configured to use `timestep_conditioning`. Hence, I got some unexpected keys if I don't set `timestep_conditioning` to True when converting the distilled weights. You can confirm it with a script along the lines below.
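The original embedded script wasn't captured here; a minimal stand-in that inspects the config metadata embedded in the checkpoint header via `safetensors` (the file path and metadata key are assumptions):

```python
from safetensors import safe_open

# Path assumed; point this at the downloaded distilled checkpoint.
with safe_open("ltx-2-19b-distilled.safetensors", framework="pt") as f:
    metadata = f.metadata()  # str -> str dict from the safetensors header

# The key under which the config is stored is assumed to be "config".
print(metadata.get("config"))  # should reveal whether timestep_conditioning is set
```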
timestep_conditioningto True when converting distilled weightsThere was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hi @rootonchair, there was a recent update on the official
Lightricks/LTX-2repo which updated the VAE for the distilled checkpoint: https://huggingface.co/Lightricks/LTX-2/commit/1931de987c8e265eb64a9123227b903754e3cc68. So I thinkrootonchair/LTX-2-19b-distilledneeds to be updated with the new converted VAE as well.It looks like
timestep_conditioningshould be set toFalsefor the new video VAE, according to theconfigmetadata onltx-2-19b-distilled.safetensors. So we should probably remove the unexpected keys (last_time_embedder,last_scale_shift_table) rather than processing them in the VAE conversion script.There was a problem hiding this comment.
**Contributor (Author):** I think I will convert the new weights. As for the conversion script, I think we should keep it until the original repo decides to delete it.