diff --git a/scripts/convert_ltx2_to_diffusers.py b/scripts/convert_ltx2_to_diffusers.py index 5367113365a2..72b334b71e71 100644 --- a/scripts/convert_ltx2_to_diffusers.py +++ b/scripts/convert_ltx2_to_diffusers.py @@ -63,6 +63,8 @@ "up_blocks.4": "up_blocks.1", "up_blocks.5": "up_blocks.2.upsamplers.0", "up_blocks.6": "up_blocks.2", + "last_time_embedder": "time_embedder", + "last_scale_shift_table": "scale_shift_table", # Common # For all 3D ResNets "res_blocks": "resnets", @@ -372,7 +374,9 @@ def convert_ltx2_connectors(original_state_dict: Dict[str, Any], version: str) - return connectors -def get_ltx2_video_vae_config(version: str) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]: +def get_ltx2_video_vae_config( + version: str, timestep_conditioning: bool = False +) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]: if version == "test": config = { "model_id": "diffusers-internal-dev/dummy-ltx2", @@ -396,7 +400,7 @@ def get_ltx2_video_vae_config(version: str) -> Tuple[Dict[str, Any], Dict[str, A "downsample_type": ("spatial", "temporal", "spatiotemporal", "spatiotemporal"), "upsample_residual": (True, True, True), "upsample_factor": (2, 2, 2), - "timestep_conditioning": False, + "timestep_conditioning": timestep_conditioning, "patch_size": 4, "patch_size_t": 1, "resnet_norm_eps": 1e-6, @@ -433,7 +437,7 @@ def get_ltx2_video_vae_config(version: str) -> Tuple[Dict[str, Any], Dict[str, A "downsample_type": ("spatial", "temporal", "spatiotemporal", "spatiotemporal"), "upsample_residual": (True, True, True), "upsample_factor": (2, 2, 2), - "timestep_conditioning": False, + "timestep_conditioning": timestep_conditioning, "patch_size": 4, "patch_size_t": 1, "resnet_norm_eps": 1e-6, @@ -450,8 +454,10 @@ def get_ltx2_video_vae_config(version: str) -> Tuple[Dict[str, Any], Dict[str, A return config, rename_dict, special_keys_remap -def convert_ltx2_video_vae(original_state_dict: Dict[str, Any], version: str) -> Dict[str, Any]: - config, rename_dict, special_keys_remap = get_ltx2_video_vae_config(version) +def convert_ltx2_video_vae( + original_state_dict: Dict[str, Any], version: str, timestep_conditioning: bool +) -> Dict[str, Any]: + config, rename_dict, special_keys_remap = get_ltx2_video_vae_config(version, timestep_conditioning) diffusers_config = config["diffusers_config"] with init_empty_weights(): @@ -717,6 +723,9 @@ def get_args(): help="Latent upsampler filename", ) + parser.add_argument( + "--timestep_conditioning", action="store_true", help="Whether to add timestep condition to the video VAE model" + ) parser.add_argument("--vae", action="store_true", help="Whether to convert the video VAE model") parser.add_argument("--audio_vae", action="store_true", help="Whether to convert the audio VAE model") parser.add_argument("--dit", action="store_true", help="Whether to convert the DiT model") @@ -786,7 +795,9 @@ def main(args): original_vae_ckpt = load_hub_or_local_checkpoint(filename=args.vae_filename) elif combined_ckpt is not None: original_vae_ckpt = get_model_state_dict_from_combined_ckpt(combined_ckpt, args.vae_prefix) - vae = convert_ltx2_video_vae(original_vae_ckpt, version=args.version) + vae = convert_ltx2_video_vae( + original_vae_ckpt, version=args.version, timestep_conditioning=args.timestep_conditioning + ) if not args.full_pipeline and not args.upsample_pipeline: vae.to(vae_dtype).save_pretrained(os.path.join(args.output_path, "vae")) diff --git a/src/diffusers/pipelines/ltx2/pipeline_ltx2.py b/src/diffusers/pipelines/ltx2/pipeline_ltx2.py index 9cf847926347..54f1061da5c9 100644 --- a/src/diffusers/pipelines/ltx2/pipeline_ltx2.py +++ b/src/diffusers/pipelines/ltx2/pipeline_ltx2.py @@ -653,6 +653,11 @@ def prepare_latents( latents: Optional[torch.Tensor] = None, ) -> torch.Tensor: if latents is not None: + if latents.ndim == 5: + # latents are of shape [B, C, F, H, W], need to be packed + latents = self._pack_latents( + latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size + ) return latents.to(device=device, dtype=dtype) height = height // self.vae_spatial_compression_ratio @@ -677,29 +682,23 @@ def prepare_audio_latents( self, batch_size: int = 1, num_channels_latents: int = 8, + audio_latent_length: int = 1, # 1 is just a dummy value num_mel_bins: int = 64, - num_frames: int = 121, - frame_rate: float = 25.0, - sampling_rate: int = 16000, - hop_length: int = 160, dtype: Optional[torch.dtype] = None, device: Optional[torch.device] = None, generator: Optional[torch.Generator] = None, latents: Optional[torch.Tensor] = None, ) -> torch.Tensor: - duration_s = num_frames / frame_rate - latents_per_second = ( - float(sampling_rate) / float(hop_length) / float(self.audio_vae_temporal_compression_ratio) - ) - latent_length = round(duration_s * latents_per_second) - if latents is not None: - return latents.to(device=device, dtype=dtype), latent_length + if latents.ndim == 4: + # latents are of shape [B, C, L, M], need to be packed + latents = self._pack_audio_latents(latents) + return latents.to(device=device, dtype=dtype) # TODO: confirm whether this logic is correct latent_mel_bins = num_mel_bins // self.audio_vae_mel_compression_ratio - shape = (batch_size, num_channels_latents, latent_length, latent_mel_bins) + shape = (batch_size, num_channels_latents, audio_latent_length, latent_mel_bins) if isinstance(generator, list) and len(generator) != batch_size: raise ValueError( @@ -709,7 +708,7 @@ def prepare_audio_latents( latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) latents = self._pack_audio_latents(latents) - return latents, latent_length + return latents @property def guidance_scale(self): @@ -750,6 +749,7 @@ def __call__( num_frames: int = 121, frame_rate: float = 24.0, num_inference_steps: int = 40, + sigmas: Optional[List[float]] = None, timesteps: List[int] = None, guidance_scale: float = 4.0, guidance_rescale: float = 0.0, @@ -788,6 +788,10 @@ def __call__( num_inference_steps (`int`, *optional*, defaults to 40): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. + sigmas (`List[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. timesteps (`List[int]`, *optional*): Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is @@ -922,6 +926,14 @@ def __call__( latent_num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1 latent_height = height // self.vae_spatial_compression_ratio latent_width = width // self.vae_spatial_compression_ratio + if latents is not None: + if latents.ndim == 5: + _, _, latent_num_frames, latent_height, latent_width = latents.shape # [B, C, F, H, W] + else: + logger.warning( + f"You have supplied packed `latents` of shape {latents.shape}, so the latent dims cannot be" + f" inferred. Make sure the supplied `height`, `width`, and `num_frames` are correct." + ) video_sequence_length = latent_num_frames * latent_height * latent_width num_channels_latents = self.transformer.config.in_channels @@ -937,20 +949,30 @@ def __call__( latents, ) + duration_s = num_frames / frame_rate + audio_latents_per_second = ( + self.audio_sampling_rate / self.audio_hop_length / float(self.audio_vae_temporal_compression_ratio) + ) + audio_num_frames = round(duration_s * audio_latents_per_second) + if audio_latents is not None: + if audio_latents.ndim == 4: + _, _, audio_num_frames, _ = audio_latents.shape # [B, C, L, M] + else: + logger.warning( + f"You have supplied packed `audio_latents` of shape {audio_latents.shape}, so the latent dims" + f" cannot be inferred. Make sure the supplied `num_frames` is correct." + ) + num_mel_bins = self.audio_vae.config.mel_bins if getattr(self, "audio_vae", None) is not None else 64 latent_mel_bins = num_mel_bins // self.audio_vae_mel_compression_ratio - num_channels_latents_audio = ( self.audio_vae.config.latent_channels if getattr(self, "audio_vae", None) is not None else 8 ) - audio_latents, audio_num_frames = self.prepare_audio_latents( + audio_latents = self.prepare_audio_latents( batch_size * num_videos_per_prompt, num_channels_latents=num_channels_latents_audio, + audio_latent_length=audio_num_frames, num_mel_bins=num_mel_bins, - num_frames=num_frames, # Video frames, audio frames will be calculated from this - frame_rate=frame_rate, - sampling_rate=self.audio_sampling_rate, - hop_length=self.audio_hop_length, dtype=torch.float32, device=device, generator=generator, @@ -958,7 +980,7 @@ def __call__( ) # 5. Prepare timesteps - sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas mu = calculate_shift( video_sequence_length, self.scheduler.config.get("base_image_seq_len", 1024), diff --git a/src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py b/src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py index b1711e283191..460ff8eec7de 100644 --- a/src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py +++ b/src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py @@ -689,6 +689,11 @@ def prepare_latents( conditioning_mask = self._pack_latents( conditioning_mask, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size ).squeeze(-1) + if latents.ndim == 5: + # latents are of shape [B, C, F, H, W], need to be packed + latents = self._pack_latents( + latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size + ) if latents.ndim != 3 or latents.shape[:2] != conditioning_mask.shape: raise ValueError( f"Provided `latents` tensor has shape {latents.shape}, but the expected shape is {conditioning_mask.shape + (num_channels_latents,)}." @@ -737,29 +742,23 @@ def prepare_audio_latents( self, batch_size: int = 1, num_channels_latents: int = 8, + audio_latent_length: int = 1, # 1 is just a dummy value num_mel_bins: int = 64, - num_frames: int = 121, - frame_rate: float = 25.0, - sampling_rate: int = 16000, - hop_length: int = 160, dtype: Optional[torch.dtype] = None, device: Optional[torch.device] = None, generator: Optional[torch.Generator] = None, latents: Optional[torch.Tensor] = None, ) -> torch.Tensor: - duration_s = num_frames / frame_rate - latents_per_second = ( - float(sampling_rate) / float(hop_length) / float(self.audio_vae_temporal_compression_ratio) - ) - latent_length = round(duration_s * latents_per_second) - if latents is not None: - return latents.to(device=device, dtype=dtype), latent_length + if latents.ndim == 4: + # latents are of shape [B, C, L, M], need to be packed + latents = self._pack_audio_latents(latents) + return latents.to(device=device, dtype=dtype) # TODO: confirm whether this logic is correct latent_mel_bins = num_mel_bins // self.audio_vae_mel_compression_ratio - shape = (batch_size, num_channels_latents, latent_length, latent_mel_bins) + shape = (batch_size, num_channels_latents, audio_latent_length, latent_mel_bins) if isinstance(generator, list) and len(generator) != batch_size: raise ValueError( @@ -769,7 +768,7 @@ def prepare_audio_latents( latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) latents = self._pack_audio_latents(latents) - return latents, latent_length + return latents @property def guidance_scale(self): @@ -811,6 +810,7 @@ def __call__( num_frames: int = 121, frame_rate: float = 24.0, num_inference_steps: int = 40, + sigmas: Optional[List[float]] = None, timesteps: List[int] = None, guidance_scale: float = 4.0, guidance_rescale: float = 0.0, @@ -851,6 +851,10 @@ def __call__( num_inference_steps (`int`, *optional*, defaults to 40): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. + sigmas (`List[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. timesteps (`List[int]`, *optional*): Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is @@ -982,6 +986,19 @@ def __call__( ) # 4. Prepare latent variables + latent_num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1 + latent_height = height // self.vae_spatial_compression_ratio + latent_width = width // self.vae_spatial_compression_ratio + if latents is not None: + if latents.ndim == 5: + _, _, latent_num_frames, latent_height, latent_width = latents.shape # [B, C, F, H, W] + else: + logger.warning( + f"You have supplied packed `latents` of shape {latents.shape}, so the latent dims cannot be" + f" inferred. Make sure the supplied `height`, `width`, and `num_frames` are correct." + ) + video_sequence_length = latent_num_frames * latent_height * latent_width + if latents is None: image = self.video_processor.preprocess(image, height=height, width=width) image = image.to(device=device, dtype=prompt_embeds.dtype) @@ -1002,20 +1019,30 @@ def __call__( if self.do_classifier_free_guidance: conditioning_mask = torch.cat([conditioning_mask, conditioning_mask]) + duration_s = num_frames / frame_rate + audio_latents_per_second = ( + self.audio_sampling_rate / self.audio_hop_length / float(self.audio_vae_temporal_compression_ratio) + ) + audio_num_frames = round(duration_s * audio_latents_per_second) + if audio_latents is not None: + if audio_latents.ndim == 4: + _, _, audio_num_frames, _ = audio_latents.shape # [B, C, L, M] + else: + logger.warning( + f"You have supplied packed `audio_latents` of shape {audio_latents.shape}, so the latent dims" + f" cannot be inferred. Make sure the supplied `num_frames` is correct." + ) + num_mel_bins = self.audio_vae.config.mel_bins if getattr(self, "audio_vae", None) is not None else 64 latent_mel_bins = num_mel_bins // self.audio_vae_mel_compression_ratio - num_channels_latents_audio = ( self.audio_vae.config.latent_channels if getattr(self, "audio_vae", None) is not None else 8 ) - audio_latents, audio_num_frames = self.prepare_audio_latents( + audio_latents = self.prepare_audio_latents( batch_size * num_videos_per_prompt, num_channels_latents=num_channels_latents_audio, + audio_latent_length=audio_num_frames, num_mel_bins=num_mel_bins, - num_frames=num_frames, # Video frames, audio frames will be calculated from this - frame_rate=frame_rate, - sampling_rate=self.audio_sampling_rate, - hop_length=self.audio_hop_length, dtype=torch.float32, device=device, generator=generator, @@ -1023,12 +1050,7 @@ def __call__( ) # 5. Prepare timesteps - latent_num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1 - latent_height = height // self.vae_spatial_compression_ratio - latent_width = width // self.vae_spatial_compression_ratio - video_sequence_length = latent_num_frames * latent_height * latent_width - - sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas mu = calculate_shift( video_sequence_length, self.scheduler.config.get("base_image_seq_len", 1024), diff --git a/src/diffusers/pipelines/ltx2/utils.py b/src/diffusers/pipelines/ltx2/utils.py new file mode 100644 index 000000000000..f80469817fe6 --- /dev/null +++ b/src/diffusers/pipelines/ltx2/utils.py @@ -0,0 +1,6 @@ +# Pre-trained sigma values for distilled model are taken from +# https://github.com/Lightricks/LTX-2/blob/main/packages/ltx-pipelines/src/ltx_pipelines/utils/constants.py +DISTILLED_SIGMA_VALUES = [1.0, 0.99375, 0.9875, 0.98125, 0.975, 0.909375, 0.725, 0.421875] + +# Reduced schedule for super-resolution stage 2 (subset of distilled values) +STAGE_2_DISTILLED_SIGMA_VALUES = [0.909375, 0.725, 0.421875]