diff --git a/src/diffusers/modular_pipelines/mellon_node_utils.py b/src/diffusers/modular_pipelines/mellon_node_utils.py index c931055c794f..aae0427196f1 100644 --- a/src/diffusers/modular_pipelines/mellon_node_utils.py +++ b/src/diffusers/modular_pipelines/mellon_node_utils.py @@ -68,6 +68,10 @@ def latents(cls, display: str = "input") -> "MellonParam": def image_latents(cls, display: str = "input") -> "MellonParam": return cls(name="image_latents", label="Image Latents", type="latents", display=display) + @classmethod + def first_frame_latents(cls, display: str = "input") -> "MellonParam": + return cls(name="first_frame_latents", label="First Frame Latents", type="latents", display=display) + @classmethod def image_latents_with_strength(cls) -> "MellonParam": return cls( @@ -89,6 +93,10 @@ def latents_preview(cls) -> "MellonParam": def embeddings(cls, display: str = "output") -> "MellonParam": return cls(name="embeddings", label="Text Embeddings", type="embeddings", display=display) + @classmethod + def image_embeds(cls, display: str = "output") -> "MellonParam": + return cls(name="image_embeds", label="Image Embeddings", type="image_embeds", display=display) + @classmethod def controlnet_conditioning_scale(cls, default: float = 0.5) -> "MellonParam": return cls( @@ -172,6 +180,10 @@ def num_inference_steps(cls, default: int = 25) -> "MellonParam": def num_frames(cls, default: int = 81) -> "MellonParam": return cls(name="num_frames", label="Frames", type="int", default=default, min=1, max=480, display="slider") + @classmethod + def layers(cls, default: int = 4) -> "MellonParam": + return cls(name="layers", label="Layers", type="int", default=default, min=1, max=10, display="slider") + @classmethod def videos(cls) -> "MellonParam": return cls(name="videos", label="Videos", type="video", display="output") @@ -186,6 +198,16 @@ def vae(cls) -> "MellonParam": """ return cls(name="vae", label="VAE", type="diffusers_auto_model", display="input") + @classmethod + def image_encoder(cls) -> "MellonParam": + """ + Image Encoder model info dict. + + Contains keys like 'model_id', 'repo_id', 'execution_device' etc. Use components.get_one(model_id) to retrieve + the actual model. + """ + return cls(name="image_encoder", label="Image Encoder", type="diffusers_auto_model", display="input") + @classmethod def unet(cls) -> "MellonParam": """ diff --git a/src/diffusers/modular_pipelines/wan/modular_blocks.py b/src/diffusers/modular_pipelines/wan/modular_blocks.py index b3b70b2f9be1..905111bcf42d 100644 --- a/src/diffusers/modular_pipelines/wan/modular_blocks.py +++ b/src/diffusers/modular_pipelines/wan/modular_blocks.py @@ -84,7 +84,7 @@ def description(self): class WanImage2VideoVaeImageEncoderStep(SequentialPipelineBlocks): model_name = "wan" block_classes = [WanImageResizeStep, WanVaeImageEncoderStep] - block_names = ["image_resize", "vae_image_encoder"] + block_names = ["image_resize", "vae_encoder"] @property def description(self): @@ -142,7 +142,7 @@ def description(self): class WanFLF2VVaeImageEncoderStep(SequentialPipelineBlocks): model_name = "wan" block_classes = [WanImageResizeStep, WanImageCropResizeStep, WanFirstLastFrameVaeImageEncoderStep] - block_names = ["image_resize", "last_image_resize", "vae_image_encoder"] + block_names = ["image_resize", "last_image_resize", "vae_encoder"] @property def description(self): @@ -203,7 +203,7 @@ def description(self): ## vae encoder class WanAutoVaeImageEncoderStep(AutoPipelineBlocks): block_classes = [WanFLF2VVaeImageEncoderStep, WanImage2VideoVaeImageEncoderStep] - block_names = ["flf2v_vae_image_encoder", "image2video_vae_image_encoder"] + block_names = ["flf2v_vae_encoder", "image2video_vae_encoder"] block_trigger_inputs = ["last_image", "image"] @property @@ -251,7 +251,7 @@ class WanAutoBlocks(SequentialPipelineBlocks): block_names = [ "text_encoder", "image_encoder", - "vae_image_encoder", + "vae_encoder", "denoise", "decode", ] @@ -353,7 +353,7 @@ class Wan22AutoBlocks(SequentialPipelineBlocks): ] block_names = [ "text_encoder", - "vae_image_encoder", + "vae_encoder", "denoise", "decode", ] @@ -384,7 +384,7 @@ def description(self): [ ("image_resize", WanImageResizeStep), ("image_encoder", WanImage2VideoImageEncoderStep), - ("vae_image_encoder", WanImage2VideoVaeImageEncoderStep), + ("vae_encoder", WanImage2VideoVaeImageEncoderStep), ("input", WanTextInputStep), ("additional_inputs", WanAdditionalInputsStep(image_latent_inputs=["first_frame_latents"])), ("set_timesteps", WanSetTimestepsStep), @@ -401,7 +401,7 @@ def description(self): ("image_resize", WanImageResizeStep), ("last_image_resize", WanImageCropResizeStep), ("image_encoder", WanFLF2VImageEncoderStep), - ("vae_image_encoder", WanFLF2VVaeImageEncoderStep), + ("vae_encoder", WanFLF2VVaeImageEncoderStep), ("input", WanTextInputStep), ("additional_inputs", WanAdditionalInputsStep(image_latent_inputs=["first_last_frame_latents"])), ("set_timesteps", WanSetTimestepsStep), @@ -416,7 +416,7 @@ def description(self): [ ("text_encoder", WanTextEncoderStep), ("image_encoder", WanAutoImageEncoderStep), - ("vae_image_encoder", WanAutoVaeImageEncoderStep), + ("vae_encoder", WanAutoVaeImageEncoderStep), ("denoise", WanAutoDenoiseStep), ("decode", WanImageVaeDecoderStep), ] @@ -438,7 +438,7 @@ def description(self): IMAGE2VIDEO_BLOCKS_WAN22 = InsertableDict( [ ("image_resize", WanImageResizeStep), - ("vae_image_encoder", WanImage2VideoVaeImageEncoderStep), + ("vae_encoder", WanImage2VideoVaeImageEncoderStep), ("input", WanTextInputStep), ("set_timesteps", WanSetTimestepsStep), ("prepare_latents", WanPrepareLatentsStep), @@ -450,7 +450,7 @@ def description(self): AUTO_BLOCKS_WAN22 = InsertableDict( [ ("text_encoder", WanTextEncoderStep), - ("vae_image_encoder", WanAutoVaeImageEncoderStep), + ("vae_encoder", WanAutoVaeImageEncoderStep), ("denoise", Wan22AutoDenoiseStep), ("decode", WanImageVaeDecoderStep), ]