diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml
index 770093438ed5..55c0777885a5 100644
--- a/docs/source/en/_toctree.yml
+++ b/docs/source/en/_toctree.yml
@@ -321,6 +321,8 @@
title: Lumina2Transformer2DModel
- local: api/models/lumina_nextdit2d
title: LuminaNextDiT2DModel
+ - local: api/models/magi_transformer_3d
+ title: Magi1Transformer3DModel
- local: api/models/mochi_transformer3d
title: MochiTransformer3DModel
- local: api/models/omnigen_transformer
@@ -375,6 +377,8 @@
title: AutoencoderKLHunyuanVideo
- local: api/models/autoencoderkl_ltx_video
title: AutoencoderKLLTXVideo
+ - local: api/models/autoencoder_kl_magi1
+ title: AutoencoderKLMagi1
- local: api/models/autoencoderkl_magvit
title: AutoencoderKLMagvit
- local: api/models/autoencoderkl_mochi
@@ -497,6 +501,8 @@
title: Lumina 2.0
- local: api/pipelines/lumina
title: Lumina-T2X
+ - local: api/pipelines/magi1
+ title: MAGI-1
- local: api/pipelines/marigold
title: Marigold
- local: api/pipelines/mochi
diff --git a/docs/source/en/api/models/autoencoder_kl_magi1.md b/docs/source/en/api/models/autoencoder_kl_magi1.md
new file mode 100644
index 000000000000..2301d08b72da
--- /dev/null
+++ b/docs/source/en/api/models/autoencoder_kl_magi1.md
@@ -0,0 +1,34 @@
+
+
+# AutoencoderKLMagi1
+
+The 3D variational autoencoder (VAE) model with KL loss used in [MAGI-1: Autoregressive Video Generation at Scale](https://arxiv.org/abs/2505.13211) by Sand.ai.
+
+MAGI-1 uses a transformer-based VAE with 8x spatial and 4x temporal compression, providing fast average decoding time and highly competitive reconstruction quality.
+
+The model can be loaded with the following code snippet.
+
+```python
+from diffusers import AutoencoderKLMagi1
+
+vae = AutoencoderKLMagi1.from_pretrained("sand-ai/MAGI-1", subfolder="vae", torch_dtype=torch.float32)
+```
+
+## AutoencoderKLMagi1
+
+[[autodoc]] AutoencoderKLMagi1
+ - decode
+ - all
+
+## DecoderOutput
+
+[[autodoc]] models.autoencoders.vae.DecoderOutput
diff --git a/docs/source/en/api/models/magi1_transformer_3d.md b/docs/source/en/api/models/magi1_transformer_3d.md
new file mode 100644
index 000000000000..8fb369f16253
--- /dev/null
+++ b/docs/source/en/api/models/magi1_transformer_3d.md
@@ -0,0 +1,32 @@
+
+
+# Magi1Transformer3DModel
+
+A Diffusion Transformer model for 3D video-like data was introduced in [MAGI-1: Autoregressive Video Generation at Scale](https://arxiv.org/abs/2505.13211) by Sand.ai.
+
+MAGI-1 is an autoregressive denoising video generation model that generates videos chunk-by-chunk instead of as a whole. Each chunk (24 frames) is denoised holistically, and the generation of the next chunk begins as soon as the current one reaches a certain level of denoising.
+
+The model can be loaded with the following code snippet.
+
+```python
+from diffusers import Magi1Transformer3DModel
+
+transformer = Magi1Transformer3DModel.from_pretrained("sand-ai/MAGI-1", subfolder="transformer", torch_dtype=torch.bfloat16)
+```
+
+## Magi1Transformer3DModel
+
+[[autodoc]] Magi1Transformer3DModel
+
+## Transformer2DModelOutput
+
+[[autodoc]] models.modeling_outputs.Transformer2DModelOutput
\ No newline at end of file
diff --git a/docs/source/en/api/pipelines/magi1.md b/docs/source/en/api/pipelines/magi1.md
new file mode 100644
index 000000000000..73618f02ead1
--- /dev/null
+++ b/docs/source/en/api/pipelines/magi1.md
@@ -0,0 +1,309 @@
+
+
+
+
+# MAGI-1
+
+[MAGI-1: Autoregressive Video Generation at Scale](https://arxiv.org/abs/2505.13211) by Sand.ai.
+
+*MAGI-1 is an autoregressive video generation model that generates videos chunk-by-chunk instead of as a whole. Each chunk (24 frames) is denoised holistically, and the generation of the next chunk begins as soon as the current one reaches a certain level of denoising. This pipeline design enables concurrent processing of up to four chunks for efficient video generation. The model leverages a specialized architecture with a transformer-based VAE with 8x spatial and 4x temporal compression, and a diffusion transformer with several key innovations including Block-Causal Attention, Parallel Attention Block, QK-Norm and GQA, Sandwich Normalization in FFN, SwiGLU, and Softcap Modulation.*
+
+You can find the MAGI-1 checkpoints under the [sand-ai](https://huggingface.co/sand-ai) organization.
+
+The following MAGI models are supported in Diffusers:
+- [MAGI-1 24B](https://huggingface.co/sand-ai/MAGI-1)
+- [MAGI-1 4.5B](https://huggingface.co/sand-ai/MAGI-1-4.5B)
+
+> [!TIP]
+> Click on the MAGI-1 models in the right sidebar for more examples of video generation.
+
+### Text-to-Video Generation
+
+The example below demonstrates how to generate a video from text optimized for memory or inference speed.
+
+
+
+
+Refer to the [Reduce memory usage](../../optimization/memory) guide for more details about the various memory saving techniques.
+
+The MAGI-1 text-to-video model below requires ~13GB of VRAM.
+
+```py
+import torch
+import numpy as np
+from diffusers import AutoModel, Magi1Pipeline
+from diffusers.hooks.group_offloading import apply_group_offloading
+from diffusers.utils import export_to_video
+from transformers import T5EncoderModel
+
+text_encoder = T5EncoderModel.from_pretrained("sand-ai/MAGI-1", subfolder="text_encoder", torch_dtype=torch.bfloat16)
+vae = AutoModel.from_pretrained("sand-ai/MAGI-1", subfolder="vae", torch_dtype=torch.float32)
+transformer = AutoModel.from_pretrained("sand-ai/MAGI-1", subfolder="transformer", torch_dtype=torch.bfloat16)
+
+# group-offloading
+onload_device = torch.device("cuda")
+offload_device = torch.device("cpu")
+apply_group_offloading(text_encoder,
+ onload_device=onload_device,
+ offload_device=offload_device,
+ offload_type="block_level",
+ num_blocks_per_group=4
+)
+transformer.enable_group_offload(
+ onload_device=onload_device,
+ offload_device=offload_device,
+ offload_type="leaf_level",
+ use_stream=True
+)
+
+pipeline = Magi1Pipeline.from_pretrained(
+ "sand-ai/MAGI-1",
+ vae=vae,
+ transformer=transformer,
+ text_encoder=text_encoder,
+ torch_dtype=torch.bfloat16
+)
+pipeline.to("cuda")
+
+prompt = """
+A majestic eagle soaring over a mountain landscape. The eagle's wings are spread wide,
+catching the golden sunlight as it glides through the clear blue sky. Below, snow-capped
+mountains stretch to the horizon, with pine forests and a winding river visible in the valley.
+"""
+negative_prompt = """
+Poor quality, blurry, pixelated, low resolution, distorted proportions, unnatural colors,
+watermark, text overlay, incomplete rendering, glitches, artifacts, unrealistic lighting
+"""
+
+output = pipeline(
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ num_frames=24,
+ guidance_scale=7.0,
+).frames[0]
+export_to_video(output, "output.mp4", fps=8)
+```
+
+
+
+
+[Compilation](../../optimization/fp16#torchcompile) is slow the first time but subsequent calls to the pipeline are faster.
+
+```py
+import torch
+import numpy as np
+from diffusers import AutoModel, Magi1Pipeline
+from diffusers.utils import export_to_video
+from transformers import T5EncoderModel
+
+text_encoder = T5EncoderModel.from_pretrained("sand-ai/MAGI-1", subfolder="text_encoder", torch_dtype=torch.bfloat16)
+vae = AutoModel.from_pretrained("sand-ai/MAGI-1", subfolder="vae", torch_dtype=torch.float32)
+transformer = AutoModel.from_pretrained("sand-ai/MAGI-1", subfolder="transformer", torch_dtype=torch.bfloat16)
+
+pipeline = Magi1Pipeline.from_pretrained(
+ "sand-ai/MAGI-1",
+ vae=vae,
+ transformer=transformer,
+ text_encoder=text_encoder,
+ torch_dtype=torch.bfloat16
+)
+pipeline.to("cuda")
+
+# torch.compile
+pipeline.transformer.to(memory_format=torch.channels_last)
+pipeline.transformer = torch.compile(
+ pipeline.transformer, mode="max-autotune", fullgraph=True
+)
+
+prompt = """
+A majestic eagle soaring over a mountain landscape. The eagle's wings are spread wide,
+catching the golden sunlight as it glides through the clear blue sky. Below, snow-capped
+mountains stretch to the horizon, with pine forests and a winding river visible in the valley.
+"""
+negative_prompt = """
+Poor quality, blurry, pixelated, low resolution, distorted proportions, unnatural colors,
+watermark, text overlay, incomplete rendering, glitches, artifacts, unrealistic lighting
+"""
+
+output = pipeline(
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ num_frames=24,
+ guidance_scale=7.0,
+).frames[0]
+export_to_video(output, "output.mp4", fps=8)
+```
+
+
+
+
+### Image-to-Video Generation
+
+The example below demonstrates how to use the image-to-video pipeline to generate a video using a text description and a starting frame.
+
+
+
+
+```python
+import numpy as np
+import torch
+import torchvision.transforms.functional as TF
+from diffusers import AutoencoderKLMagi1, Magi1ImageToVideoPipeline
+from diffusers.utils import export_to_video, load_image
+from transformers import CLIPVisionModel
+
+model_id = "sand-ai/MAGI-1"
+image_encoder = CLIPVisionModel.from_pretrained(model_id, subfolder="image_encoder", torch_dtype=torch.float32)
+vae = AutoencoderKLMagi1.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
+pipe = Magi1ImageToVideoPipeline.from_pretrained(
+ model_id, vae=vae, image_encoder=image_encoder, torch_dtype=torch.bfloat16
+)
+pipe.to("cuda")
+
+image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image.png")
+
+def aspect_ratio_resize(image, pipe, max_area=720 * 1280):
+ aspect_ratio = image.height / image.width
+ mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
+ height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
+ width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
+ image = image.resize((width, height))
+ return image, height, width
+
+image, height, width = aspect_ratio_resize(image, pipe)
+
+prompt = "A beautiful landscape with mountains and a lake. The camera slowly pans from left to right, revealing more of the landscape."
+
+output = pipe(
+ image=image, prompt=prompt, height=height, width=width, guidance_scale=7.5, num_frames=24
+).frames[0]
+export_to_video(output, "output.mp4", fps=8)
+```
+
+
+
+
+### First-Last-Frame-to-Video Generation
+
+The example below demonstrates how to use the image-to-video pipeline to generate a video using a text description, a starting frame, and an ending frame.
+
+
+
+
+```python
+import numpy as np
+import torch
+import torchvision.transforms.functional as TF
+from diffusers import AutoencoderKLMagi1, Magi1ImageToVideoPipeline
+from diffusers.utils import export_to_video, load_image
+from transformers import CLIPVisionModel
+
+model_id = "sand-ai/MAGI-1"
+image_encoder = CLIPVisionModel.from_pretrained(model_id, subfolder="image_encoder", torch_dtype=torch.float32)
+vae = AutoencoderKLMagi1.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
+pipe = Magi1ImageToVideoPipeline.from_pretrained(
+ model_id, vae=vae, image_encoder=image_encoder, torch_dtype=torch.bfloat16
+)
+pipe.to("cuda")
+
+first_frame = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/first_frame.png")
+last_frame = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/last_frame.png")
+
+def aspect_ratio_resize(image, pipe, max_area=720 * 1280):
+ aspect_ratio = image.height / image.width
+ mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
+ height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
+ width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
+ image = image.resize((width, height))
+ return image, height, width
+
+def center_crop_resize(image, height, width):
+ # Calculate resize ratio to match first frame dimensions
+ resize_ratio = max(width / image.width, height / image.height)
+
+ # Resize the image
+ width = round(image.width * resize_ratio)
+ height = round(image.height * resize_ratio)
+ size = [width, height]
+ image = TF.center_crop(image, size)
+
+ return image, height, width
+
+first_frame, height, width = aspect_ratio_resize(first_frame, pipe)
+if last_frame.size != first_frame.size:
+ last_frame, _, _ = center_crop_resize(last_frame, height, width)
+
+prompt = "A car driving down a winding mountain road. The camera follows the car as it navigates the curves, revealing beautiful mountain scenery in the background."
+
+output = pipe(
+ image=first_frame, last_image=last_frame, prompt=prompt, height=height, width=width, guidance_scale=7.5, num_frames=24
+).frames[0]
+export_to_video(output, "output.mp4", fps=8)
+```
+
+
+
+
+### Video-to-Video Generation
+
+The example below demonstrates how to use the video-to-video pipeline to generate a video based on an existing video and text prompt.
+
+
+
+
+```python
+import torch
+import numpy as np
+from diffusers import AutoencoderKLMagi1, Magi1VideoToVideoPipeline
+from diffusers.utils import export_to_video, load_video
+from transformers import T5EncoderModel
+
+model_id = "sand-ai/MAGI-1"
+text_encoder = T5EncoderModel.from_pretrained(model_id, subfolder="text_encoder", torch_dtype=torch.bfloat16)
+vae = AutoencoderKLMagi1.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
+pipe = Magi1VideoToVideoPipeline.from_pretrained(
+ model_id, vae=vae, text_encoder=text_encoder, torch_dtype=torch.bfloat16
+)
+pipe.to("cuda")
+
+# Load input video
+video_path = "input_video.mp4"
+video = load_video(video_path)
+
+prompt = "Convert this video to an anime style with vibrant colors and exaggerated features"
+negative_prompt = "Poor quality, blurry, distorted, unrealistic lighting, bad composition"
+
+output = pipe(
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ video=video,
+ strength=0.7, # Controls how much to preserve from original video
+ guidance_scale=7.5,
+).frames[0]
+export_to_video(output, "output.mp4", fps=8)
+```
+
+
+
+
+## Notes
+
+- MAGI-1 supports LoRAs with [`~loaders.MagiLoraLoaderMixin.load_lora_weights`].
\ No newline at end of file
diff --git a/scripts/convert_magi1_to_diffusers.py b/scripts/convert_magi1_to_diffusers.py
new file mode 100644
index 000000000000..9e52e108e728
--- /dev/null
+++ b/scripts/convert_magi1_to_diffusers.py
@@ -0,0 +1,592 @@
+import argparse
+import json
+import os
+import shutil
+import tempfile
+
+import torch
+from huggingface_hub import hf_hub_download
+from safetensors import safe_open
+from safetensors.torch import load_file
+
+from diffusers import Magi1Pipeline, Magi1Transformer3DModel
+from diffusers.models.autoencoders import AutoencoderKLMagi1
+
+
+TRANSFORMER_KEYS_RENAME_DICT = {
+ "t_embedder.mlp.0": "condition_embedder.time_embedder.linear_1",
+ "t_embedder.mlp.2": "condition_embedder.time_embedder.linear_2",
+ "y_embedder.y_proj_adaln.0": "condition_embedder.text_embedder.linear_1",
+ "y_embedder.y_proj_adaln.2": "condition_embedder.text_embedder.linear_2",
+ "y_embedder.y_proj_xattn.0": "condition_embedder.text_proj",
+ "videodit_blocks.final_layernorm": "norm_out",
+ "final_linear.linear": "proj_out",
+ "x_embedder": "patch_embedding",
+}
+
+
+BLOCK_COMPONENT_MAPPINGS = {
+ "self_attention.linear_qkv.q": "attn1.to_q",
+ "self_attention.linear_qkv.k": "attn1.to_k",
+ "self_attention.linear_qkv.v": "attn1.to_v",
+ "self_attention.linear_proj": "attn1.to_out.0",
+ "self_attention.q_layernorm": "attn1.norm_q",
+ "self_attention.k_layernorm": "attn1.norm_k",
+ "self_attention.linear_qkv.layer_norm": "norm1",
+ "self_attention.linear_qkv.qx": "attn2.to_q",
+ "self_attention.q_layernorm_xattn": "attn2.norm_q",
+ "self_attention.k_layernorm_xattn": "attn2.norm_k",
+ "mlp.linear_fc1": "ff.net.0.proj",
+ "mlp.linear_fc2": "ff.net.2",
+ "mlp.layer_norm": "norm3",
+ "self_attn_post_norm": "norm2",
+ "mlp_post_norm": "norm4",
+ "ada_modulate_layer.proj.0": "scale_shift_table",
+}
+
+
+TRANSFORMER_SPECIAL_KEYS_REMAP = {}
+
+
+def convert_magi_transformer(model_type):
+ """
+ Convert MAGI-1 transformer for a specific model type.
+
+ Args:
+ model_type: The model type (e.g., "MAGI-1-T2V-4.5B-distill", "MAGI-1-T2V-24B-distill", etc.)
+
+ Returns:
+ The converted transformer model.
+ """
+
+ model_type_mapping = {
+ "MAGI-1-T2V-4.5B-distill": "4.5B_distill",
+ "MAGI-1-T2V-24B-distill": "24B_distill",
+ "MAGI-1-T2V-4.5B": "4.5B",
+ "MAGI-1-T2V-24B": "24B",
+ "4.5B_distill": "4.5B_distill",
+ "24B_distill": "24B_distill",
+ "4.5B": "4.5B",
+ "24B": "24B",
+ }
+
+ repo_path = model_type_mapping.get(model_type, model_type)
+
+ temp_dir = tempfile.mkdtemp()
+ transformer_ckpt_dir = os.path.join(temp_dir, "transformer_checkpoint")
+ os.makedirs(transformer_ckpt_dir, exist_ok=True)
+
+ checkpoint_files = []
+ shard_index = 1
+ while True:
+ try:
+ if shard_index == 1:
+ shard_filename = f"model-{shard_index:05d}-of-00002.safetensors"
+ shard_path = hf_hub_download(
+ "sand-ai/MAGI-1", f"ckpt/magi/{repo_path}/inference_weight.distill/{shard_filename}"
+ )
+ checkpoint_files.append(shard_path)
+ print(f"Downloaded {shard_filename}")
+ shard_index += 1
+ elif shard_index == 2:
+ shard_filename = f"model-{shard_index:05d}-of-00002.safetensors"
+ shard_path = hf_hub_download(
+ "sand-ai/MAGI-1", f"ckpt/magi/{repo_path}/inference_weight.distill/{shard_filename}"
+ )
+ checkpoint_files.append(shard_path)
+ print(f"Downloaded {shard_filename}")
+ break
+ else:
+ break
+ except Exception as e:
+ print(f"No more shards found or error downloading shard {shard_index}: {e}")
+ break
+
+ if not checkpoint_files:
+ raise ValueError(f"No checkpoint files found for model type: {model_type}")
+
+ for i, shard_path in enumerate(checkpoint_files):
+ dest_path = os.path.join(transformer_ckpt_dir, f"model-{i + 1:05d}-of-{len(checkpoint_files):05d}.safetensors")
+ shutil.copy2(shard_path, dest_path)
+
+ transformer = convert_magi_transformer_checkpoint(transformer_ckpt_dir)
+
+ return transformer
+
+
+def convert_magi_vae():
+ vae_ckpt_path = hf_hub_download("sand-ai/MAGI-1", "ckpt/vae/diffusion_pytorch_model.safetensors")
+ checkpoint = load_file(vae_ckpt_path)
+
+ config = {
+ "patch_size": (4, 8, 8),
+ "num_attention_heads": 16,
+ "attention_head_dim": 64,
+ "z_dim": 16,
+ "height": 256,
+ "width": 256,
+ "num_frames": 16,
+ "ffn_dim": 4 * 1024,
+ "num_layers": 24,
+ "eps": 1e-6,
+ }
+
+ vae = AutoencoderKLMagi1(
+ patch_size=config["patch_size"],
+ num_attention_heads=config["num_attention_heads"],
+ attention_head_dim=config["attention_head_dim"],
+ z_dim=config["z_dim"],
+ height=config["height"],
+ width=config["width"],
+ num_frames=config["num_frames"],
+ ffn_dim=config["ffn_dim"],
+ num_layers=config["num_layers"],
+ eps=config["eps"],
+ )
+
+ converted_state_dict = convert_vae_state_dict(checkpoint)
+
+ vae.load_state_dict(converted_state_dict, strict=True)
+
+ return vae
+
+
+def convert_vae_state_dict(checkpoint):
+ """
+ Convert MAGI-1 VAE state dict to diffusers format.
+
+ Maps the keys from the MAGI-1 VAE state dict to the diffusers VAE state dict.
+ """
+ state_dict = {}
+
+ state_dict["encoder.patch_embedding.weight"] = checkpoint["encoder.patch_embed.proj.weight"]
+ state_dict["encoder.patch_embedding.bias"] = checkpoint["encoder.patch_embed.proj.bias"]
+
+ state_dict["encoder.pos_embed"] = checkpoint["encoder.pos_embed"]
+
+ state_dict["encoder.cls_token"] = checkpoint["encoder.cls_token"]
+
+ for i in range(24):
+ qkv_weight = checkpoint[f"encoder.blocks.{i}.attn.qkv.weight"]
+ qkv_bias = checkpoint[f"encoder.blocks.{i}.attn.qkv.bias"]
+
+ q_weight, k_weight, v_weight = qkv_weight.chunk(3, dim=0)
+ q_bias, k_bias, v_bias = qkv_bias.chunk(3, dim=0)
+
+ state_dict[f"encoder.blocks.{i}.attn.to_q.weight"] = q_weight
+ state_dict[f"encoder.blocks.{i}.attn.to_q.bias"] = q_bias
+ state_dict[f"encoder.blocks.{i}.attn.to_k.weight"] = k_weight
+ state_dict[f"encoder.blocks.{i}.attn.to_k.bias"] = k_bias
+ state_dict[f"encoder.blocks.{i}.attn.to_v.weight"] = v_weight
+ state_dict[f"encoder.blocks.{i}.attn.to_v.bias"] = v_bias
+
+ state_dict[f"encoder.blocks.{i}.attn.to_out.0.weight"] = checkpoint[f"encoder.blocks.{i}.attn.proj.weight"]
+ state_dict[f"encoder.blocks.{i}.attn.to_out.0.bias"] = checkpoint[f"encoder.blocks.{i}.attn.proj.bias"]
+
+ state_dict[f"encoder.blocks.{i}.norm2.weight"] = checkpoint[f"encoder.blocks.{i}.norm2.weight"]
+ state_dict[f"encoder.blocks.{i}.norm2.bias"] = checkpoint[f"encoder.blocks.{i}.norm2.bias"]
+
+ state_dict[f"encoder.blocks.{i}.proj_out.net.0.proj.weight"] = checkpoint[f"encoder.blocks.{i}.mlp.fc1.weight"]
+ state_dict[f"encoder.blocks.{i}.proj_out.net.0.proj.bias"] = checkpoint[f"encoder.blocks.{i}.mlp.fc1.bias"]
+ state_dict[f"encoder.blocks.{i}.proj_out.net.2.weight"] = checkpoint[f"encoder.blocks.{i}.mlp.fc2.weight"]
+
+ state_dict[f"encoder.blocks.{i}.proj_out.net.2.bias"] = checkpoint[f"encoder.blocks.{i}.mlp.fc2.bias"]
+
+ state_dict["encoder.norm_out.weight"] = checkpoint["encoder.norm.weight"]
+ state_dict["encoder.norm_out.bias"] = checkpoint["encoder.norm.bias"]
+
+ state_dict["encoder.linear_out.weight"] = checkpoint["encoder.last_layer.weight"]
+ state_dict["encoder.linear_out.bias"] = checkpoint["encoder.last_layer.bias"]
+
+ state_dict["decoder.proj_in.weight"] = checkpoint["decoder.proj_in.weight"]
+ state_dict["decoder.proj_in.bias"] = checkpoint["decoder.proj_in.bias"]
+
+ state_dict["decoder.pos_embed"] = checkpoint["decoder.pos_embed"]
+
+ state_dict["decoder.cls_token"] = checkpoint["decoder.cls_token"]
+
+ for i in range(24):
+ qkv_weight = checkpoint[f"decoder.blocks.{i}.attn.qkv.weight"]
+ qkv_bias = checkpoint[f"decoder.blocks.{i}.attn.qkv.bias"]
+
+ q_weight, k_weight, v_weight = qkv_weight.chunk(3, dim=0)
+ q_bias, k_bias, v_bias = qkv_bias.chunk(3, dim=0)
+
+ state_dict[f"decoder.blocks.{i}.attn.to_q.weight"] = q_weight
+ state_dict[f"decoder.blocks.{i}.attn.to_q.bias"] = q_bias
+ state_dict[f"decoder.blocks.{i}.attn.to_k.weight"] = k_weight
+ state_dict[f"decoder.blocks.{i}.attn.to_k.bias"] = k_bias
+ state_dict[f"decoder.blocks.{i}.attn.to_v.weight"] = v_weight
+ state_dict[f"decoder.blocks.{i}.attn.to_v.bias"] = v_bias
+
+ state_dict[f"decoder.blocks.{i}.attn.to_out.0.weight"] = checkpoint[f"decoder.blocks.{i}.attn.proj.weight"]
+ state_dict[f"decoder.blocks.{i}.attn.to_out.0.bias"] = checkpoint[f"decoder.blocks.{i}.attn.proj.bias"]
+
+ state_dict[f"decoder.blocks.{i}.norm2.weight"] = checkpoint[f"decoder.blocks.{i}.norm2.weight"]
+ state_dict[f"decoder.blocks.{i}.norm2.bias"] = checkpoint[f"decoder.blocks.{i}.norm2.bias"]
+
+ state_dict[f"decoder.blocks.{i}.proj_out.net.0.proj.weight"] = checkpoint[f"decoder.blocks.{i}.mlp.fc1.weight"]
+ state_dict[f"decoder.blocks.{i}.proj_out.net.0.proj.bias"] = checkpoint[f"decoder.blocks.{i}.mlp.fc1.bias"]
+ state_dict[f"decoder.blocks.{i}.proj_out.net.2.weight"] = checkpoint[f"decoder.blocks.{i}.mlp.fc2.weight"]
+ state_dict[f"decoder.blocks.{i}.proj_out.net.2.bias"] = checkpoint[f"decoder.blocks.{i}.mlp.fc2.bias"]
+
+ state_dict["decoder.norm_out.weight"] = checkpoint["decoder.norm.weight"]
+ state_dict["decoder.norm_out.bias"] = checkpoint["decoder.norm.bias"]
+
+ state_dict["decoder.conv_out.weight"] = checkpoint["decoder.last_layer.weight"]
+ state_dict["decoder.conv_out.bias"] = checkpoint["decoder.last_layer.bias"]
+
+ return state_dict
+
+
+def load_magi_transformer_checkpoint(checkpoint_path):
+ """
+ Load a MAGI-1 transformer checkpoint.
+
+ Args:
+ checkpoint_path: Path to the MAGI-1 transformer checkpoint.
+
+ Returns:
+ The loaded checkpoint state dict.
+ """
+ if checkpoint_path.endswith(".safetensors"):
+ state_dict = load_file(checkpoint_path)
+ elif os.path.isdir(checkpoint_path):
+ safetensors_files = [f for f in os.listdir(checkpoint_path) if f.endswith(".safetensors")]
+ if safetensors_files:
+ state_dict = {}
+ for safetensors_file in sorted(safetensors_files):
+ file_path = os.path.join(checkpoint_path, safetensors_file)
+ with safe_open(file_path, framework="pt", device="cpu") as f:
+ for key in f.keys():
+ state_dict[key] = f.get_tensor(key)
+ else:
+ checkpoint_files = [f for f in os.listdir(checkpoint_path) if f.endswith(".pt") or f.endswith(".pth")]
+ if not checkpoint_files:
+ raise ValueError(f"No checkpoint files found in {checkpoint_path}")
+
+ checkpoint_file = os.path.join(checkpoint_path, checkpoint_files[0])
+ checkpoint_data = torch.load(checkpoint_file, map_location="cpu")
+
+ if isinstance(checkpoint_data, dict):
+ if "model" in checkpoint_data:
+ state_dict = checkpoint_data["model"]
+ elif "state_dict" in checkpoint_data:
+ state_dict = checkpoint_data["state_dict"]
+ else:
+ state_dict = checkpoint_data
+ else:
+ state_dict = checkpoint_data
+ else:
+ checkpoint_data = torch.load(checkpoint_path, map_location="cpu")
+
+ if isinstance(checkpoint_data, dict):
+ if "model" in checkpoint_data:
+ state_dict = checkpoint_data["model"]
+ elif "state_dict" in checkpoint_data:
+ state_dict = checkpoint_data["state_dict"]
+ else:
+ state_dict = checkpoint_data
+ else:
+ state_dict = checkpoint_data
+
+ return state_dict
+
+
+def convert_magi_transformer_checkpoint(checkpoint_path, transformer_config_file=None, dtype=None):
+ """
+ Convert a MAGI-1 transformer checkpoint to a diffusers Magi1Transformer3DModel.
+
+ Args:
+ checkpoint_path: Path to the MAGI-1 transformer checkpoint.
+ transformer_config_file: Optional path to a transformer config file.
+ dtype: Optional dtype for the model.
+
+ Returns:
+ A diffusers Magi1Transformer3DModel model.
+ """
+ if transformer_config_file is not None:
+ with open(transformer_config_file, "r") as f:
+ config = json.load(f)
+ else:
+ config = {
+ "in_channels": 16,
+ "out_channels": 16,
+ "num_layers": 34,
+ "num_attention_heads": 24,
+ "attention_head_dim": 128,
+ "cross_attention_dim": 4096,
+ "freq_dim": 256,
+ "ffn_dim": 12288,
+ "patch_size": (1, 2, 2),
+ "use_linear_projection": False,
+ "upcast_attention": False,
+ "cross_attn_norm": True,
+ "qk_norm": "rms_norm_across_heads",
+ "eps": 1e-6,
+ "rope_max_seq_len": 1024,
+ }
+
+ transformer = Magi1Transformer3DModel(
+ in_channels=config["in_channels"],
+ out_channels=config["out_channels"],
+ num_layers=config["num_layers"],
+ num_attention_heads=config["num_attention_heads"],
+ attention_head_dim=config["attention_head_dim"],
+ cross_attention_dim=config["cross_attention_dim"],
+ freq_dim=config["freq_dim"],
+ ffn_dim=config["ffn_dim"],
+ patch_size=config["patch_size"],
+ use_linear_projection=config["use_linear_projection"],
+ upcast_attention=config["upcast_attention"],
+ cross_attn_norm=config["cross_attn_norm"],
+ qk_norm=config["qk_norm"],
+ eps=config["eps"],
+ rope_max_seq_len=config["rope_max_seq_len"],
+ )
+
+ checkpoint = load_magi_transformer_checkpoint(checkpoint_path)
+
+ converted_state_dict = convert_transformer_state_dict(checkpoint)
+
+ missing_keys, unexpected_keys = transformer.load_state_dict(converted_state_dict, strict=False)
+
+ print(f"Missing keys ({len(missing_keys)}): {missing_keys}")
+ print(f"Unexpected keys ({len(unexpected_keys)}): {unexpected_keys}")
+
+ missing_keys, unexpected_keys = transformer.load_state_dict(converted_state_dict, strict=False)
+
+ if dtype is not None:
+ transformer = transformer.to(dtype=dtype)
+
+ return transformer
+
+
+def convert_transformer_state_dict(checkpoint):
+ """
+ Convert MAGI-1 transformer state dict to diffusers format.
+
+ Maps the original MAGI-1 parameter names to diffusers' standard transformer naming.
+ Handles all shape mismatches and key mappings based on actual checkpoint analysis.
+ """
+ converted_state_dict = {}
+
+ print("Converting MAGI-1 checkpoint keys...")
+
+ converted_state_dict["patch_embedding.weight"] = checkpoint["x_embedder.weight"]
+
+ converted_state_dict["condition_embedder.time_embedder.linear_1.weight"] = checkpoint["t_embedder.mlp.0.weight"]
+ converted_state_dict["condition_embedder.time_embedder.linear_1.bias"] = checkpoint["t_embedder.mlp.0.bias"]
+
+ converted_state_dict["condition_embedder.time_embedder.linear_2.weight"] = checkpoint["t_embedder.mlp.2.weight"]
+ converted_state_dict["condition_embedder.time_embedder.linear_2.bias"] = checkpoint["t_embedder.mlp.2.bias"]
+
+ converted_state_dict["condition_embedder.text_embedder.linear_1.weight"] = checkpoint[
+ "y_embedder.y_proj_adaln.0.weight"
+ ]
+ converted_state_dict["condition_embedder.text_embedder.linear_1.bias"] = checkpoint[
+ "y_embedder.y_proj_adaln.0.bias"
+ ]
+
+ converted_state_dict["condition_embedder.text_embedder.linear_2.weight"] = checkpoint[
+ "y_embedder.y_proj_adaln.2.weight"
+ ]
+ converted_state_dict["condition_embedder.text_embedder.linear_2.bias"] = checkpoint[
+ "y_embedder.y_proj_adaln.2.bias"
+ ]
+
+ converted_state_dict["condition_embedder.text_proj.weight"] = checkpoint["y_embedder.y_proj_xattn.0.weight"]
+ converted_state_dict["condition_embedder.text_proj.bias"] = checkpoint["y_embedder.y_proj_xattn.0.bias"]
+
+ converted_state_dict["condition_embedder.text_embedder.null_caption_embedding"] = checkpoint[
+ "y_embedder.null_caption_embedding"
+ ]
+
+ converted_state_dict["norm_out.weight"] = checkpoint["videodit_blocks.final_layernorm.weight"]
+ converted_state_dict["norm_out.bias"] = checkpoint["videodit_blocks.final_layernorm.bias"]
+
+ converted_state_dict["proj_out.weight"] = checkpoint["final_linear.linear.weight"]
+ converted_state_dict["proj_out.bias"] = checkpoint["final_linear.linear.bias"]
+
+ converted_state_dict["rope.freqs"] = checkpoint["rope.bands"]
+
+ for layer_idx in range(34):
+ layer_prefix = f"videodit_blocks.layers.{layer_idx}"
+ block_prefix = f"blocks.{layer_idx}"
+
+ converted_state_dict[f"{block_prefix}.norm1.weight"] = checkpoint[
+ f"{layer_prefix}.self_attention.linear_qkv.layer_norm.weight"
+ ]
+ converted_state_dict[f"{block_prefix}.norm1.bias"] = checkpoint[
+ f"{layer_prefix}.self_attention.linear_qkv.layer_norm.bias"
+ ]
+
+ converted_state_dict[f"{block_prefix}.attn1.to_q.weight"] = checkpoint[
+ f"{layer_prefix}.self_attention.linear_qkv.q.weight"
+ ]
+ converted_state_dict[f"{block_prefix}.attn1.to_q.bias"] = checkpoint[
+ f"{layer_prefix}.self_attention.linear_qkv.q.bias"
+ ]
+
+ converted_state_dict[f"{block_prefix}.attn1.to_k.weight"] = checkpoint[
+ f"{layer_prefix}.self_attention.linear_qkv.k.weight"
+ ]
+ converted_state_dict[f"{block_prefix}.attn1.to_k.bias"] = checkpoint[
+ f"{layer_prefix}.self_attention.linear_qkv.k.bias"
+ ]
+
+ converted_state_dict[f"{block_prefix}.attn1.to_v.weight"] = checkpoint[
+ f"{layer_prefix}.self_attention.linear_qkv.v.weight"
+ ]
+ converted_state_dict[f"{block_prefix}.attn1.to_v.bias"] = checkpoint[
+ f"{layer_prefix}.self_attention.linear_qkv.v.bias"
+ ]
+
+ converted_state_dict[f"{block_prefix}.attn1.to_out.0.weight"] = checkpoint[
+ f"{layer_prefix}.self_attention.linear_proj.weight"
+ ]
+ converted_state_dict[f"{block_prefix}.attn1.to_out.0.bias"] = checkpoint[
+ f"{layer_prefix}.self_attention.linear_proj.bias"
+ ]
+
+ converted_state_dict[f"{block_prefix}.attn1.norm_q.weight"] = checkpoint[
+ f"{layer_prefix}.self_attention.q_layernorm.weight"
+ ]
+ converted_state_dict[f"{block_prefix}.attn1.norm_q.bias"] = checkpoint[
+ f"{layer_prefix}.self_attention.q_layernorm.bias"
+ ]
+
+ converted_state_dict[f"{block_prefix}.attn1.norm_k.weight"] = checkpoint[
+ f"{layer_prefix}.self_attention.k_layernorm.weight"
+ ]
+ converted_state_dict[f"{block_prefix}.attn1.norm_k.bias"] = checkpoint[
+ f"{layer_prefix}.self_attention.k_layernorm.bias"
+ ]
+
+ converted_state_dict[f"{block_prefix}.attn2.to_q.weight"] = checkpoint[
+ f"{layer_prefix}.self_attention.linear_qkv.qx.weight"
+ ]
+ converted_state_dict[f"{block_prefix}.attn2.to_q.bias"] = checkpoint[
+ f"{layer_prefix}.self_attention.linear_qkv.qx.bias"
+ ]
+
+ kv_weight = checkpoint[f"{layer_prefix}.self_attention.linear_kv_xattn.weight"]
+ k_weight, v_weight = kv_weight.chunk(2, dim=0)
+ converted_state_dict[f"{block_prefix}.attn2.to_k.weight"] = k_weight
+ converted_state_dict[f"{block_prefix}.attn2.to_v.weight"] = v_weight
+
+ kv_bias = checkpoint[f"{layer_prefix}.self_attention.linear_kv_xattn.bias"]
+ k_bias, v_bias = kv_bias.chunk(2, dim=0)
+ converted_state_dict[f"{block_prefix}.attn2.to_k.bias"] = k_bias
+ converted_state_dict[f"{block_prefix}.attn2.to_v.bias"] = v_bias
+
+ converted_state_dict[f"{block_prefix}.attn2.to_out.0.weight"] = converted_state_dict[
+ f"{block_prefix}.attn1.to_out.0.weight"
+ ]
+ converted_state_dict[f"{block_prefix}.attn2.to_out.0.bias"] = converted_state_dict[
+ f"{block_prefix}.attn1.to_out.0.bias"
+ ]
+
+ converted_state_dict[f"{block_prefix}.attn2.norm_q.weight"] = checkpoint[
+ f"{layer_prefix}.self_attention.q_layernorm_xattn.weight"
+ ]
+ converted_state_dict[f"{block_prefix}.attn2.norm_q.bias"] = checkpoint[
+ f"{layer_prefix}.self_attention.q_layernorm_xattn.bias"
+ ]
+
+ converted_state_dict[f"{block_prefix}.attn2.norm_k.weight"] = checkpoint[
+ f"{layer_prefix}.self_attention.k_layernorm_xattn.weight"
+ ]
+ converted_state_dict[f"{block_prefix}.attn2.norm_k.bias"] = checkpoint[
+ f"{layer_prefix}.self_attention.k_layernorm_xattn.bias"
+ ]
+
+ converted_state_dict[f"{block_prefix}.norm2.weight"] = checkpoint[f"{layer_prefix}.self_attn_post_norm.weight"]
+ converted_state_dict[f"{block_prefix}.norm2.bias"] = checkpoint[f"{layer_prefix}.self_attn_post_norm.bias"]
+
+ converted_state_dict[f"{block_prefix}.norm3.weight"] = checkpoint[f"{layer_prefix}.mlp.layer_norm.weight"]
+ converted_state_dict[f"{block_prefix}.norm3.bias"] = checkpoint[f"{layer_prefix}.mlp.layer_norm.bias"]
+
+ converted_state_dict[f"{block_prefix}.ff.net.0.proj.weight"] = checkpoint[
+ f"{layer_prefix}.mlp.linear_fc1.weight"
+ ]
+ converted_state_dict[f"{block_prefix}.ff.net.0.proj.bias"] = checkpoint[f"{layer_prefix}.mlp.linear_fc1.bias"]
+
+ converted_state_dict[f"{block_prefix}.ff.net.2.weight"] = checkpoint[f"{layer_prefix}.mlp.linear_fc2.weight"]
+ converted_state_dict[f"{block_prefix}.ff.net.2.bias"] = checkpoint[f"{layer_prefix}.mlp.linear_fc2.bias"]
+
+ converted_state_dict[f"{block_prefix}.norm4.weight"] = checkpoint[f"{layer_prefix}.mlp_post_norm.weight"]
+ converted_state_dict[f"{block_prefix}.norm4.bias"] = checkpoint[f"{layer_prefix}.mlp_post_norm.bias"]
+
+ converted_state_dict[f"{block_prefix}.scale_shift_table.weight"] = checkpoint[
+ f"{layer_prefix}.ada_modulate_layer.proj.0.weight"
+ ]
+ converted_state_dict[f"{block_prefix}.scale_shift_table.bias"] = checkpoint[
+ f"{layer_prefix}.ada_modulate_layer.proj.0.bias"
+ ]
+
+ print(f"Converted {len(converted_state_dict)} parameters")
+ return converted_state_dict
+
+
+def get_args():
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--model_type", type=str, default=None)
+ parser.add_argument("--output_path", type=str, required=True)
+ parser.add_argument("--dtype", default="fp32", choices=["fp32", "fp16", "bf16", "none"])
+ return parser.parse_args()
+
+
+DTYPE_MAPPING = {
+ "fp32": torch.float32,
+ "fp16": torch.float16,
+ "bf16": torch.bfloat16,
+}
+
+if __name__ == "__main__":
+ args = get_args()
+
+ transformer = convert_magi_transformer(args.model_type)
+ # vae = convert_magi_vae()
+ # text_encoder = T5EncoderModel.from_pretrained("DeepFloyd/t5-v1_1-xxl")
+ # tokenizer = AutoTokenizer.from_pretrained("DeepFloyd/t5-v1_1-xxl")
+ # flow_shift = 16.0 if "FLF2V" in args.model_type else 3.0
+ # scheduler = UniPCMultistepScheduler(
+ # prediction_type="flow_prediction", use_flow_sigmas=True, num_train_timesteps=1000, flow_shift=flow_shift
+ # )
+
+ # If user has specified "none", we keep the original dtypes of the state dict without any conversion
+ if args.dtype != "none":
+ dtype = DTYPE_MAPPING[args.dtype]
+ transformer.to(dtype)
+
+ # if "I2V" in args.model_type or "FLF2V" in args.model_type:
+ # image_encoder = CLIPVisionModelWithProjection.from_pretrained(
+ # "laion/CLIP-ViT-H-14-laion2B-s32B-b79K", torch_dtype=torch.bfloat16
+ # )
+ # image_processor = AutoProcessor.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K")
+ # pipe = Magi1ImageToVideoPipeline(
+ # transformer=transformer,
+ # text_encoder=text_encoder,
+ # tokenizer=tokenizer,
+ # vae=vae,
+ # scheduler=scheduler,
+ # image_encoder=image_encoder,
+ # image_processor=image_processor,
+ # )
+ # else:
+ pipe = Magi1Pipeline(
+ transformer=transformer,
+ text_encoder=None, # text_encoder,
+ tokenizer=None, # tokenizer,
+ vae=None, # vae,
+ scheduler=None, # scheduler,
+ )
+
+ pipe.save_pretrained(
+ args.output_path,
+ safe_serialization=True,
+ max_shard_size="5GB",
+ push_to_hub=True,
+ repo_id=f"tolgacangoz/{args.model_type}-Diffusers",
+ )
diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py
index 4c383c817efe..a1ed56725b75 100644
--- a/src/diffusers/__init__.py
+++ b/src/diffusers/__init__.py
@@ -151,6 +151,7 @@
"AutoencoderKLCosmos",
"AutoencoderKLHunyuanVideo",
"AutoencoderKLLTXVideo",
+ "AutoencoderKLMagi1",
"AutoencoderKLMagvit",
"AutoencoderKLMochi",
"AutoencoderKLTemporalDecoder",
@@ -186,6 +187,7 @@
"LTXVideoTransformer3DModel",
"Lumina2Transformer2DModel",
"LuminaNextDiT2DModel",
+ "Magi1Transformer3DModel",
"MochiTransformer3DModel",
"ModelMixin",
"MotionAdapter",
@@ -434,6 +436,9 @@
"Lumina2Text2ImgPipeline",
"LuminaPipeline",
"LuminaText2ImgPipeline",
+ "Magi1ImageToVideoPipeline",
+ "Magi1Pipeline",
+ "Magi1VideoToVideoPipeline",
"MarigoldDepthPipeline",
"MarigoldIntrinsicsPipeline",
"MarigoldNormalsPipeline",
@@ -767,6 +772,7 @@
AutoencoderKLCosmos,
AutoencoderKLHunyuanVideo,
AutoencoderKLLTXVideo,
+ AutoencoderKLMagi1,
AutoencoderKLMagvit,
AutoencoderKLMochi,
AutoencoderKLTemporalDecoder,
@@ -802,6 +808,7 @@
LTXVideoTransformer3DModel,
Lumina2Transformer2DModel,
LuminaNextDiT2DModel,
+ Magi1Transformer3DModel,
MochiTransformer3DModel,
ModelMixin,
MotionAdapter,
@@ -1029,6 +1036,9 @@
Lumina2Text2ImgPipeline,
LuminaPipeline,
LuminaText2ImgPipeline,
+ Magi1ImageToVideoPipeline,
+ Magi1Pipeline,
+ Magi1VideoToVideoPipeline,
MarigoldDepthPipeline,
MarigoldIntrinsicsPipeline,
MarigoldNormalsPipeline,
diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py
index 73903a627415..3345489a5ff5 100755
--- a/src/diffusers/models/__init__.py
+++ b/src/diffusers/models/__init__.py
@@ -35,6 +35,7 @@
_import_structure["autoencoders.autoencoder_kl_cosmos"] = ["AutoencoderKLCosmos"]
_import_structure["autoencoders.autoencoder_kl_hunyuan_video"] = ["AutoencoderKLHunyuanVideo"]
_import_structure["autoencoders.autoencoder_kl_ltx"] = ["AutoencoderKLLTXVideo"]
+ _import_structure["autoencoders.autoencoder_kl_magi1"] = ["AutoencoderKLMagi1"]
_import_structure["autoencoders.autoencoder_kl_magvit"] = ["AutoencoderKLMagvit"]
_import_structure["autoencoders.autoencoder_kl_mochi"] = ["AutoencoderKLMochi"]
_import_structure["autoencoders.autoencoder_kl_temporal_decoder"] = ["AutoencoderKLTemporalDecoder"]
@@ -85,6 +86,7 @@
_import_structure["transformers.transformer_hunyuan_video_framepack"] = ["HunyuanVideoFramepackTransformer3DModel"]
_import_structure["transformers.transformer_ltx"] = ["LTXVideoTransformer3DModel"]
_import_structure["transformers.transformer_lumina2"] = ["Lumina2Transformer2DModel"]
+ _import_structure["transformers.transformer_magi1"] = ["Magi1Transformer3DModel"]
_import_structure["transformers.transformer_mochi"] = ["MochiTransformer3DModel"]
_import_structure["transformers.transformer_omnigen"] = ["OmniGenTransformer2DModel"]
_import_structure["transformers.transformer_sd3"] = ["SD3Transformer2DModel"]
@@ -121,6 +123,7 @@
AutoencoderKLCosmos,
AutoencoderKLHunyuanVideo,
AutoencoderKLLTXVideo,
+ AutoencoderKLMagi1,
AutoencoderKLMagvit,
AutoencoderKLMochi,
AutoencoderKLTemporalDecoder,
@@ -170,6 +173,7 @@
LTXVideoTransformer3DModel,
Lumina2Transformer2DModel,
LuminaNextDiT2DModel,
+ Magi1Transformer3DModel,
MochiTransformer3DModel,
OmniGenTransformer2DModel,
PixArtTransformer2DModel,
diff --git a/src/diffusers/models/autoencoders/__init__.py b/src/diffusers/models/autoencoders/__init__.py
index 742d747ae25e..7b93469961c1 100644
--- a/src/diffusers/models/autoencoders/__init__.py
+++ b/src/diffusers/models/autoencoders/__init__.py
@@ -6,6 +6,7 @@
from .autoencoder_kl_cosmos import AutoencoderKLCosmos
from .autoencoder_kl_hunyuan_video import AutoencoderKLHunyuanVideo
from .autoencoder_kl_ltx import AutoencoderKLLTXVideo
+from .autoencoder_kl_magi1 import AutoencoderKLMagi1
from .autoencoder_kl_magvit import AutoencoderKLMagvit
from .autoencoder_kl_mochi import AutoencoderKLMochi
from .autoencoder_kl_temporal_decoder import AutoencoderKLTemporalDecoder
diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_magi1.py b/src/diffusers/models/autoencoders/autoencoder_kl_magi1.py
new file mode 100644
index 000000000000..8f43e898fea9
--- /dev/null
+++ b/src/diffusers/models/autoencoders/autoencoder_kl_magi1.py
@@ -0,0 +1,768 @@
+# Copyright 2025 The Sand AI Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import List, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ...configuration_utils import ConfigMixin, register_to_config
+from ...loaders import FromOriginalModelMixin
+from ...utils import logging
+from ...utils.accelerate_utils import apply_forward_hook
+from ..attention import FeedForward
+from ..attention_processor import Attention
+from ..modeling_outputs import AutoencoderKLOutput
+from ..modeling_utils import ModelMixin
+from .vae import DecoderOutput, DiagonalGaussianDistribution
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+def resize_pos_embed(posemb, src_shape, target_shape):
+ posemb = posemb.reshape(1, src_shape[0], src_shape[1], src_shape[2], -1)
+ posemb = posemb.permute(0, 4, 1, 2, 3)
+ posemb = nn.functional.interpolate(posemb, size=target_shape, mode="trilinear", align_corners=False)
+ posemb = posemb.permute(0, 2, 3, 4, 1)
+ posemb = posemb.reshape(1, target_shape[0] * target_shape[1] * target_shape[2], -1)
+ return posemb
+
+
+class Magi1VAELayerNorm(nn.Module):
+ def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True):
+ super(Magi1VAELayerNorm, self).__init__()
+ self.normalized_shape = normalized_shape
+ self.eps = eps
+ self.elementwise_affine = elementwise_affine
+
+ def forward(self, x):
+ mean = x.mean(dim=-1, keepdim=True)
+ std = x.std(dim=-1, keepdim=True, unbiased=False)
+
+ x_normalized = (x - mean) / (std + self.eps)
+
+ return x_normalized
+
+
+class Magi1VAEAttnProcessor2_0:
+ def __init__(self, dim, num_heads=8):
+ if not hasattr(F, "scaled_dot_product_attention"):
+ raise ImportError("WanAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")
+
+ self.qkv_norm = Magi1VAELayerNorm(dim // num_heads, elementwise_affine=False)
+
+ def __call__(
+ self,
+ attn: Attention,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ rotary_emb: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ batch_size, time_height_width, channels = hidden_states.size()
+
+ # compute query, key, value
+ query = attn.to_q(hidden_states)
+ key = attn.to_k(hidden_states)
+ value = attn.to_v(hidden_states)
+
+ qkv = torch.cat((query, key, value), dim=2)
+ qkv = qkv.reshape(batch_size, time_height_width, 3, attn.heads, channels // attn.heads)
+ qkv = self.qkv_norm(qkv)
+ query, key, value = qkv.chunk(3, dim=2)
+
+ # Remove the extra dimension from chunking and transpose for scaled dot product attention
+ # Shape: (batch_size, num_heads, time_height_width, head_dim)
+ query = query.squeeze(2).transpose(1, 2)
+ key = key.squeeze(2).transpose(1, 2)
+ value = value.squeeze(2).transpose(1, 2)
+
+ hidden_states = F.scaled_dot_product_attention(
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+ ) # the output of sdpa = (batch_size, num_heads, seq_len, head_dim)
+ # Reshape hidden_states to (batch_size, time_height_width, channels)
+ hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
+
+ hidden_states = attn.to_out[0](hidden_states)
+ hidden_states = attn.to_out[1](hidden_states)
+
+ return hidden_states
+
+
+class Magi1VAETransformerBlock(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ num_heads: int = 8,
+ ffn_dim: int = 4 * 1024,
+ eps: float = 1e-6,
+ ):
+ super().__init__()
+ self.norm1 = nn.Identity()
+ self.attn = Attention(
+ query_dim=dim,
+ heads=num_heads,
+ kv_heads=num_heads,
+ dim_head=dim // num_heads,
+ eps=eps,
+ bias=True,
+ cross_attention_dim=None,
+ out_bias=True,
+ processor=Magi1VAEAttnProcessor2_0(dim, num_heads),
+ )
+
+ self.drop_path = nn.Identity()
+ self.norm2 = nn.LayerNorm(dim)
+
+ self.proj_out = FeedForward(dim, inner_dim=ffn_dim, activation_fn="gelu")
+
+ self.gradient_checkpointing = False
+
+ def forward(self, x):
+ x = x + self.drop_path(self.attn(self.norm1(x)))
+ x = x + self.drop_path(self.proj_out(self.norm2(x)))
+ return x
+
+
+class Magi1Encoder3d(nn.Module):
+ def __init__(
+ self,
+ inner_dim=128,
+ z_dim=4,
+ patch_size: Tuple[int] = (1, 2, 2),
+ num_frames: int = 16,
+ height: int = 256,
+ width: int = 256,
+ num_attention_heads: int = 40,
+ ffn_dim: int = 4 * 1024,
+ num_layers: int = 24,
+ eps: float = 1e-6,
+ ):
+ super().__init__()
+ self.z_dim = z_dim
+ self.height = height
+ self.width = width
+ self.num_frames = num_frames
+
+ # 1. Patch & position embedding
+ self.patch_embedding = nn.Conv3d(3, inner_dim, kernel_size=patch_size, stride=patch_size)
+ self.patch_size = patch_size
+
+ self.cls_token_nums = 1
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, inner_dim))
+ # `generator` as a parameter?
+ nn.init.trunc_normal_(self.cls_token, std=0.02)
+
+ p_t, p_h, p_w = patch_size
+ post_patch_num_frames = num_frames // p_t
+ post_patch_height = height // p_h
+ post_patch_width = width // p_w
+ num_patches = post_patch_num_frames * post_patch_height * post_patch_width
+
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.cls_token_nums, inner_dim))
+ self.pos_drop = nn.Dropout(p=0.0)
+
+ # 3. Transformer blocks
+ self.blocks = nn.ModuleList(
+ [
+ Magi1VAETransformerBlock(
+ inner_dim,
+ num_attention_heads,
+ ffn_dim,
+ eps,
+ )
+ for _ in range(num_layers)
+ ]
+ )
+
+ # output blocks
+ self.norm_out = nn.LayerNorm(inner_dim)
+ self.linear_out = nn.Linear(inner_dim, z_dim * 2)
+
+ # `generator` as a parameter?
+ nn.init.trunc_normal_(self.pos_embed, std=0.02)
+
+ self.gradient_checkpointing = False
+
+ def forward(self, x):
+ B = x.shape[0]
+ # B C T H W -> B C T/pT H/pH W//pW
+ x = self.patch_embedding(x)
+ latentT, latentH, latentW = x.shape[2], x.shape[3], x.shape[4]
+ # B C T/pT H/pH W//pW -> B (T/pT H/pH W//pW) C
+ x = x.flatten(2).transpose(1, 2)
+
+ cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
+ x = torch.cat((cls_tokens, x), dim=1)
+
+ if latentT != self.patch_size[0] or latentH != self.patch_size[1] or latentW != self.patch_size[2]:
+ pos_embed = resize_pos_embed(
+ self.pos_embed[:, 1:, :],
+ src_shape=(
+ self.num_frames // self.patch_size[0],
+ self.height // self.patch_size[1],
+ self.width // self.patch_size[2],
+ ),
+ target_shape=(latentT, latentH, latentW),
+ )
+ pos_embed = torch.cat((self.pos_embed[:, 0:1, :], pos_embed), dim=1)
+ else:
+ pos_embed = self.pos_embed
+
+ x = x + pos_embed
+ x = self.pos_drop(x)
+
+ ## transformer blocks
+ for block in self.blocks:
+ x = block(x)
+
+ ## head
+ x = self.norm_out(x)
+ x = x[:, 1:] # remove cls_token
+ x = self.linear_out(x)
+
+ # B L C - > B , lT, lH, lW, zC (where zC is now z_dim * 2)
+ x = x.reshape(B, latentT, latentH, latentW, self.z_dim * 2)
+
+ # B , lT, lH, lW, zC -> B, zC, lT, lH, lW
+ x = x.permute(0, 4, 1, 2, 3)
+
+ return x
+
+
+class Magi1Decoder3d(nn.Module):
+ def __init__(
+ self,
+ inner_dim=1024,
+ z_dim=16,
+ patch_size: Tuple[int] = (4, 8, 8),
+ num_frames: int = 16,
+ height: int = 256,
+ width: int = 256,
+ num_attention_heads: int = 16,
+ ffn_dim: int = 4 * 1024,
+ num_layers: int = 24,
+ eps: float = 1e-6,
+ ):
+ super().__init__()
+ self.z_dim = z_dim
+ self.patch_size = patch_size
+ self.height = height
+ self.width = width
+ self.num_frames = num_frames
+
+ # init block
+ self.proj_in = nn.Linear(z_dim, inner_dim)
+
+ self.cls_token_nums = 1
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, inner_dim))
+ # `generator` as a parameter?
+ nn.init.trunc_normal_(self.cls_token, std=0.02)
+
+ p_t, p_h, p_w = patch_size
+ post_patch_num_frames = num_frames // p_t
+ post_patch_height = height // p_h
+ post_patch_width = width // p_w
+ num_patches = post_patch_num_frames * post_patch_height * post_patch_width
+
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.cls_token_nums, inner_dim))
+ self.pos_drop = nn.Dropout(p=0.0)
+
+ # 3. Transformer blocks
+ self.blocks = nn.ModuleList(
+ [
+ Magi1VAETransformerBlock(
+ inner_dim,
+ num_attention_heads,
+ ffn_dim,
+ eps,
+ )
+ for _ in range(num_layers)
+ ]
+ )
+
+ # output blocks
+ self.norm_out = nn.LayerNorm(inner_dim)
+ self.unpatch_channels = inner_dim // (patch_size[0] * patch_size[1] * patch_size[2])
+ self.conv_out = nn.Conv3d(self.unpatch_channels, 3, 3, padding=1)
+
+ # `generator` as a parameter?
+ nn.init.trunc_normal_(self.pos_embed, std=0.02)
+
+ self.gradient_checkpointing = False
+
+ def forward(self, x):
+ B, C, latentT, latentH, latentW = x.shape
+ x = x.permute(0, 2, 3, 4, 1)
+
+ x = x.reshape(B, -1, C)
+
+ x = self.proj_in(x)
+
+ cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
+ x = torch.cat((cls_tokens, x), dim=1)
+
+ if latentT != self.patch_size[0] or latentH != self.patch_size[1] or latentW != self.patch_size[2]:
+ pos_embed = resize_pos_embed(
+ self.pos_embed[:, 1:, :],
+ src_shape=(
+ self.num_frames // self.patch_size[0],
+ self.height // self.patch_size[1],
+ self.width // self.patch_size[2],
+ ),
+ target_shape=(latentT, latentH, latentW),
+ )
+ pos_embed = torch.cat((self.pos_embed[:, 0:1, :], pos_embed), dim=1)
+ else:
+ pos_embed = self.pos_embed
+
+ x = x + pos_embed
+ x = self.pos_drop(x)
+
+ ## transformer blocks
+ for block in self.blocks:
+ x = block(x)
+
+ ## head
+ x = self.norm_out(x)
+ x = x[:, 1:] # remove cls_token
+
+ x = x.reshape(
+ B,
+ latentT,
+ latentH,
+ latentW,
+ self.patch_size[0],
+ self.patch_size[1],
+ self.patch_size[2],
+ self.unpatch_channels,
+ )
+ # Rearrange from (B, lT, lH, lW, pT, pH, pW, C) to (B, C, lT*pT, lH*pH, lW*pW)
+ x = x.permute(0, 7, 1, 4, 2, 5, 3, 6) # (B, C, lT, pT, lH, pH, lW, pW)
+ x = x.reshape(
+ B,
+ self.unpatch_channels,
+ latentT * self.patch_size[0],
+ latentH * self.patch_size[1],
+ latentW * self.patch_size[2],
+ )
+
+ x = self.conv_out(x)
+ return x
+
+
+class AutoencoderKLMagi1(ModelMixin, ConfigMixin, FromOriginalModelMixin):
+ r"""
+ A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos.
+ Introduced in [Magi1](https://arxiv.org/abs/2505.13211).
+
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
+ for all models (such as downloading or saving).
+ """
+
+ _supports_gradient_checkpointing = False
+ _skip_layerwise_casting_patterns = ["patch_embedding", "norm"]
+ _no_split_modules = ["Magi1VAETransformerBlock"]
+ # _keep_in_fp32_modules = ["qkv_norm", "norm1", "norm2"]
+ _keys_to_ignore_on_load_unexpected = ["norm_added_q"]
+
+ @register_to_config
+ def __init__(
+ self,
+ patch_size: Tuple[int] = (4, 8, 8),
+ num_attention_heads: int = 16,
+ attention_head_dim: int = 64,
+ z_dim: int = 16,
+ height: int = 256,
+ width: int = 256,
+ num_frames: int = 16,
+ ffn_dim: int = 4 * 1024,
+ num_layers: int = 24,
+ eps: float = 1e-6,
+ latents_mean: List[float] = [
+ -0.7571,
+ -0.7089,
+ -0.9113,
+ 0.1075,
+ -0.1745,
+ 0.9653,
+ -0.1517,
+ 1.5508,
+ 0.4134,
+ -0.0715,
+ 0.5517,
+ -0.3632,
+ -0.1922,
+ -0.9497,
+ 0.2503,
+ -0.2921,
+ ],
+ latents_std: List[float] = [
+ 2.8184,
+ 1.4541,
+ 2.3275,
+ 2.6558,
+ 1.2196,
+ 1.7708,
+ 2.6052,
+ 2.0743,
+ 3.2687,
+ 2.1526,
+ 2.8652,
+ 1.5579,
+ 1.6382,
+ 1.1253,
+ 2.8251,
+ 1.9160,
+ ],
+ ) -> None:
+ super().__init__()
+
+ inner_dim = num_attention_heads * attention_head_dim
+ self.z_dim = z_dim
+
+ self.encoder = Magi1Encoder3d(
+ inner_dim,
+ z_dim,
+ patch_size,
+ num_frames,
+ height,
+ width,
+ num_attention_heads,
+ ffn_dim,
+ num_layers,
+ eps,
+ )
+
+ self.decoder = Magi1Decoder3d(
+ inner_dim,
+ z_dim,
+ patch_size,
+ num_frames,
+ height,
+ width,
+ num_attention_heads,
+ ffn_dim,
+ num_layers,
+ eps,
+ )
+
+ self.spatial_compression_ratio = patch_size[1] or patch_size[2]
+ self.temporal_compression_ratio = patch_size[0]
+
+ # When decoding a batch of video latents at a time, one can save memory by slicing across the batch dimension
+ # to perform decoding of a single video latent at a time.
+ self.use_slicing = False
+
+ # When decoding spatially large video latents, the memory requirement is very high. By breaking the video latent
+ # frames spatially into smaller tiles and performing multiple forward passes for decoding, and then blending the
+ # intermediate tiles together, the memory requirement can be lowered.
+ self.use_tiling = False
+
+ # The minimal tile height and width for spatial tiling to be used
+ self.tile_sample_min_height = 256
+ self.tile_sample_min_width = 256
+
+ # The minimal distance between two spatial tiles
+ self.tile_sample_stride_height = 192
+ self.tile_sample_stride_width = 192
+
+ def enable_tiling(
+ self,
+ tile_sample_min_height: Optional[int] = None,
+ tile_sample_min_width: Optional[int] = None,
+ tile_sample_stride_height: Optional[float] = None,
+ tile_sample_stride_width: Optional[float] = None,
+ ) -> None:
+ r"""
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
+ processing larger images.
+
+ Args:
+ tile_sample_min_height (`int`, *optional*):
+ The minimum height required for a sample to be separated into tiles across the height dimension.
+ tile_sample_min_width (`int`, *optional*):
+ The minimum width required for a sample to be separated into tiles across the width dimension.
+ tile_sample_stride_height (`int`, *optional*):
+ The minimum amount of overlap between two consecutive vertical tiles. This is to ensure that there are
+ no tiling artifacts produced across the height dimension.
+ tile_sample_stride_width (`int`, *optional*):
+ The stride between two consecutive horizontal tiles. This is to ensure that there are no tiling
+ artifacts produced across the width dimension.
+ """
+ self.use_tiling = True
+ self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height
+ self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width
+ self.tile_sample_stride_height = tile_sample_stride_height or self.tile_sample_stride_height
+ self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width
+
+ def disable_tiling(self) -> None:
+ r"""
+ Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
+ decoding in one step.
+ """
+ self.use_tiling = False
+
+ def enable_slicing(self) -> None:
+ r"""
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
+ """
+ self.use_slicing = True
+
+ def disable_slicing(self) -> None:
+ r"""
+ Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
+ decoding in one step.
+ """
+ self.use_slicing = False
+
+ def _encode(self, x: torch.Tensor):
+ _, _, num_frame, height, width = x.shape
+
+ if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height):
+ return self.tiled_encode(x)
+
+ out = self.encoder(x)
+
+ return out
+
+ @apply_forward_hook
+ def encode(
+ self, x: torch.Tensor, return_dict: bool = True
+ ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
+ r"""
+ Encode a batch of images into latents.
+
+ Args:
+ x (`torch.Tensor`): Input batch of images.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
+
+ Returns:
+ The latent representations of the encoded videos. If `return_dict` is True, a
+ [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
+ """
+ if self.use_slicing and x.shape[0] > 1:
+ encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
+ h = torch.cat(encoded_slices)
+ else:
+ h = self._encode(x)
+ posterior = DiagonalGaussianDistribution(h)
+
+ if not return_dict:
+ return (posterior,)
+ return AutoencoderKLOutput(latent_dist=posterior)
+
+ def _decode(self, z: torch.Tensor, return_dict: bool = True):
+ _, _, num_frame, height, width = z.shape
+ tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
+ tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
+
+ if self.use_tiling and (width > tile_latent_min_width or height > tile_latent_min_height):
+ return self.tiled_decode(z, return_dict=return_dict)
+
+ out = self.decoder(z)
+
+ if not return_dict:
+ return (out,)
+
+ return DecoderOutput(sample=out)
+
+ @apply_forward_hook
+ def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
+ r"""
+ Decode a batch of images.
+
+ Args:
+ z (`torch.Tensor`): Input batch of latent vectors.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
+
+ Returns:
+ [`~models.vae.DecoderOutput`] or `tuple`:
+ If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
+ returned.
+ """
+ if self.use_slicing and z.shape[0] > 1:
+ decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
+ decoded = torch.cat(decoded_slices)
+ else:
+ decoded = self._decode(z).sample
+
+ if not return_dict:
+ return (decoded,)
+ return DecoderOutput(sample=decoded)
+
+ def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
+ blend_extent = min(a.shape[-2], b.shape[-2], blend_extent)
+ for y in range(blend_extent):
+ b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (
+ y / blend_extent
+ )
+ return b
+
+ def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
+ blend_extent = min(a.shape[-1], b.shape[-1], blend_extent)
+ for x in range(blend_extent):
+ b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (
+ x / blend_extent
+ )
+ return b
+
+ def tiled_encode(self, x: torch.Tensor) -> AutoencoderKLOutput:
+ r"""Encode a batch of images using a tiled encoder.
+
+ Args:
+ x (`torch.Tensor`): Input batch of videos.
+
+ Returns:
+ `torch.Tensor`:
+ The latent representation of the encoded videos.
+ """
+ _, _, num_frames, height, width = x.shape
+ latent_height = height // self.spatial_compression_ratio
+ latent_width = width // self.spatial_compression_ratio
+
+ tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
+ tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
+ tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio
+ tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio
+
+ blend_height = tile_latent_min_height - tile_latent_stride_height
+ blend_width = tile_latent_min_width - tile_latent_stride_width
+
+ # Split x into overlapping tiles and encode them separately.
+ # The tiles have an overlap to avoid seams between tiles.
+ rows = []
+ for i in range(0, height, self.tile_sample_stride_height):
+ row = []
+ for j in range(0, width, self.tile_sample_stride_width):
+ time = []
+ frame_range = 1 + (num_frames - 1) // 4
+ for k in range(frame_range):
+ if k == 0:
+ tile = x[:, :, :1, i : i + self.tile_sample_min_height, j : j + self.tile_sample_min_width]
+ else:
+ tile = x[
+ :,
+ :,
+ 1 + 4 * (k - 1) : 1 + 4 * k,
+ i : i + self.tile_sample_min_height,
+ j : j + self.tile_sample_min_width,
+ ]
+ tile = self.encoder(tile)
+ time.append(tile)
+ row.append(torch.cat(time, dim=2))
+ rows.append(row)
+
+ result_rows = []
+ for i, row in enumerate(rows):
+ result_row = []
+ for j, tile in enumerate(row):
+ # blend the above tile and the left tile
+ # to the current tile and add the current tile to the result row
+ if i > 0:
+ tile = self.blend_v(rows[i - 1][j], tile, blend_height)
+ if j > 0:
+ tile = self.blend_h(row[j - 1], tile, blend_width)
+ result_row.append(tile[:, :, :, :tile_latent_stride_height, :tile_latent_stride_width])
+ result_rows.append(torch.cat(result_row, dim=-1))
+
+ enc = torch.cat(result_rows, dim=3)[:, :, :, :latent_height, :latent_width]
+ return enc
+
+ def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
+ r"""
+ Decode a batch of images using a tiled decoder.
+
+ Args:
+ z (`torch.Tensor`): Input batch of latent vectors.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
+
+ Returns:
+ [`~models.vae.DecoderOutput`] or `tuple`:
+ If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
+ returned.
+ """
+ _, _, num_frames, height, width = z.shape
+ sample_height = height * self.spatial_compression_ratio
+ sample_width = width * self.spatial_compression_ratio
+
+ tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
+ tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
+ tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio
+ tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio
+
+ blend_height = self.tile_sample_min_height - self.tile_sample_stride_height
+ blend_width = self.tile_sample_min_width - self.tile_sample_stride_width
+
+ # Split z into overlapping tiles and decode them separately.
+ # The tiles have an overlap to avoid seams between tiles.
+ rows = []
+ for i in range(0, height, tile_latent_stride_height):
+ row = []
+ for j in range(0, width, tile_latent_stride_width):
+ time = []
+ for k in range(num_frames):
+ tile = z[:, :, k : k + 1, i : i + tile_latent_min_height, j : j + tile_latent_min_width]
+ decoded = self.decoder(tile)
+ time.append(decoded)
+ row.append(torch.cat(time, dim=2))
+ rows.append(row)
+
+ result_rows = []
+ for i, row in enumerate(rows):
+ result_row = []
+ for j, tile in enumerate(row):
+ # blend the above tile and the left tile
+ # to the current tile and add the current tile to the result row
+ if i > 0:
+ tile = self.blend_v(rows[i - 1][j], tile, blend_height)
+ if j > 0:
+ tile = self.blend_h(row[j - 1], tile, blend_width)
+ result_row.append(tile[:, :, :, : self.tile_sample_stride_height, : self.tile_sample_stride_width])
+ result_rows.append(torch.cat(result_row, dim=-1))
+
+ dec = torch.cat(result_rows, dim=3)[:, :, :, :sample_height, :sample_width]
+
+ if not return_dict:
+ return (dec,)
+ return DecoderOutput(sample=dec)
+
+ def forward(
+ self,
+ sample: torch.Tensor,
+ sample_posterior: bool = True,
+ return_dict: bool = True,
+ generator: Optional[torch.Generator] = None,
+ ) -> Union[DecoderOutput, torch.Tensor]:
+ """
+ Args:
+ sample (`torch.Tensor`): Input sample.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
+ """
+ x = sample
+ posterior = self.encode(x).latent_dist
+ if sample_posterior:
+ z = posterior.sample(generator=generator)
+ else:
+ z = posterior.mode()
+ dec = self.decode(z, return_dict=return_dict)
+ return dec
diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py
index 4f268bfa018f..5d6300f8dca6 100644
--- a/src/diffusers/models/embeddings.py
+++ b/src/diffusers/models/embeddings.py
@@ -825,6 +825,8 @@ def get_3d_rotary_pos_embed(
grid_type: str = "linspace",
max_size: Optional[Tuple[int, int]] = None,
device: Optional[torch.device] = None,
+ center_grid_hw_indices: bool = False,
+ equal_split_ratio: Optional[int] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
"""
RoPE for video tokens with 3D structure.
@@ -871,10 +873,19 @@ def get_3d_rotary_pos_embed(
else:
raise ValueError("Invalid value passed for `grid_type`.")
- # Compute dimensions for each axis
- dim_t = embed_dim // 4
- dim_h = embed_dim // 8 * 3
- dim_w = embed_dim // 8 * 3
+ if center_grid_hw_indices:
+ # Center the grid height and width indices around zero
+ grid_h = grid_h - grid_h.max() / 2
+ grid_w = grid_w - grid_w.max() / 2
+
+ if equal_split_ratio is None:
+ dim_t = embed_dim // 4
+ dim_h = embed_dim // 8 * 3
+ dim_w = embed_dim // 8 * 3
+ else:
+ dim_t = embed_dim // equal_split_ratio
+ dim_h = embed_dim // equal_split_ratio
+ dim_w = embed_dim // equal_split_ratio
# Temporal frequencies
freqs_t = get_1d_rotary_pos_embed(dim_t, grid_t, theta=theta, use_real=True)
diff --git a/src/diffusers/models/transformers/__init__.py b/src/diffusers/models/transformers/__init__.py
index cc03a0ccbcdf..1b6680fa4413 100755
--- a/src/diffusers/models/transformers/__init__.py
+++ b/src/diffusers/models/transformers/__init__.py
@@ -28,6 +28,7 @@
from .transformer_hunyuan_video_framepack import HunyuanVideoFramepackTransformer3DModel
from .transformer_ltx import LTXVideoTransformer3DModel
from .transformer_lumina2 import Lumina2Transformer2DModel
+ from .transformer_magi1 import Magi1Transformer3DModel
from .transformer_mochi import MochiTransformer3DModel
from .transformer_omnigen import OmniGenTransformer2DModel
from .transformer_sd3 import SD3Transformer2DModel
diff --git a/src/diffusers/models/transformers/transformer_magi1.py b/src/diffusers/models/transformers/transformer_magi1.py
new file mode 100644
index 000000000000..6b1ff939a26a
--- /dev/null
+++ b/src/diffusers/models/transformers/transformer_magi1.py
@@ -0,0 +1,672 @@
+# Copyright 2025 The MAGI Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from typing import Any, Dict, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ...configuration_utils import ConfigMixin, register_to_config
+from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
+from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
+from ..attention import FeedForward
+from ..attention_processor import Attention
+from ..cache_utils import CacheMixin
+from ..embeddings import PixArtAlphaTextProjection, TimestepEmbedding, Timesteps, get_1d_rotary_pos_embed
+from ..modeling_outputs import Transformer2DModelOutput
+from ..modeling_utils import ModelMixin
+from ..normalization import FP32LayerNorm
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+class Magi1AttnProcessor2_0:
+ r"""
+ Processor for implementing MAGI-1 attention mechanism.
+
+ This processor handles both self-attention and cross-attention for the MAGI-1 model, following diffusers' standard
+ attention processor interface. It supports image conditioning for image-to-video generation tasks.
+ """
+
+ def __init__(self):
+ if not hasattr(F, "scaled_dot_product_attention"):
+ raise ImportError("Magi1AttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")
+
+ def __call__(
+ self,
+ attn: Attention,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ rotary_emb: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ # Handle image conditioning if present for cross-attention
+ encoder_hidden_states_img = None
+ if attn.add_k_proj is not None and encoder_hidden_states is not None:
+ # Extract image conditioning from the concatenated encoder states
+ # The text encoder context length is typically 512 tokens
+ text_context_length = getattr(attn, "text_context_length", 512)
+ image_context_length = encoder_hidden_states.shape[1] - text_context_length
+ encoder_hidden_states_img = encoder_hidden_states[:, :image_context_length]
+ encoder_hidden_states = encoder_hidden_states[:, image_context_length:]
+
+ # For self-attention, use hidden_states as encoder_hidden_states
+ if encoder_hidden_states is None:
+ encoder_hidden_states = hidden_states
+
+ # Standard attention computation
+ query = attn.to_q(hidden_states)
+ key = attn.to_k(encoder_hidden_states)
+ value = attn.to_v(encoder_hidden_states)
+
+ # Apply normalization if available
+ if attn.norm_q is not None:
+ query = attn.norm_q(query)
+ if attn.norm_k is not None:
+ key = attn.norm_k(key)
+
+ # Reshape for multi-head attention
+ query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+ key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+ value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+
+ # Apply rotary embeddings if provided
+ if rotary_emb is not None:
+
+ def apply_rotary_emb(hidden_states: torch.Tensor, freqs: torch.Tensor):
+ dtype = torch.float32 if hidden_states.device.type == "mps" else torch.float64
+ x_rotated = torch.view_as_complex(hidden_states.to(dtype).unflatten(3, (-1, 2)))
+ x_out = torch.view_as_real(x_rotated * freqs).flatten(3, 4)
+ return x_out.type_as(hidden_states)
+
+ query = apply_rotary_emb(query, rotary_emb)
+ key = apply_rotary_emb(key, rotary_emb)
+
+ # Compute attention
+ hidden_states = F.scaled_dot_product_attention(
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+ )
+ hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
+
+ # Handle image conditioning (I2V task) for cross-attention
+ if encoder_hidden_states_img is not None:
+ key_img = attn.add_k_proj(encoder_hidden_states_img)
+ key_img = attn.norm_added_k(key_img)
+ value_img = attn.add_v_proj(encoder_hidden_states_img)
+
+ key_img = key_img.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+ value_img = value_img.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+
+ attn_output_img = F.scaled_dot_product_attention(
+ query, key_img, value_img, attn_mask=None, dropout_p=0.0, is_causal=False
+ )
+ attn_output_img = attn_output_img.transpose(1, 2).flatten(2, 3)
+ hidden_states = hidden_states + attn_output_img
+
+ # Apply output projection
+ hidden_states = attn.to_out[0](hidden_states)
+ hidden_states = attn.to_out[1](hidden_states)
+ return hidden_states
+
+
+class Magi1ImageEmbedding(torch.nn.Module):
+ """
+ Image embedding layer for the MAGI-1 model.
+
+ This module processes image conditioning features for image-to-video generation tasks. It applies layer
+ normalization, a feed-forward transformation, and optional positional embeddings to prepare image features for
+ cross-attention.
+
+ Args:
+ in_features (`int`): Input feature dimension.
+ out_features (`int`): Output feature dimension.
+ pos_embed_seq_len (`int`, optional): Sequence length for positional embeddings.
+ If provided, learnable positional embeddings will be added to the input.
+ """
+
+ def __init__(self, in_features: int, out_features: int, pos_embed_seq_len=None):
+ super().__init__()
+
+ self.norm1 = FP32LayerNorm(in_features)
+ self.ff = FeedForward(in_features, out_features, mult=1, activation_fn="gelu")
+ self.norm2 = FP32LayerNorm(out_features)
+ if pos_embed_seq_len is not None:
+ self.pos_embed = nn.Parameter(torch.zeros(1, pos_embed_seq_len, in_features))
+ else:
+ self.pos_embed = None
+
+ def forward(self, encoder_hidden_states_image: torch.Tensor) -> torch.Tensor:
+ if self.pos_embed is not None:
+ batch_size, seq_len, embed_dim = encoder_hidden_states_image.shape
+ encoder_hidden_states_image = encoder_hidden_states_image.view(-1, 2 * seq_len, embed_dim)
+ encoder_hidden_states_image = encoder_hidden_states_image + self.pos_embed
+
+ hidden_states = self.norm1(encoder_hidden_states_image)
+ hidden_states = self.ff(hidden_states)
+ hidden_states = self.norm2(hidden_states)
+ return hidden_states
+
+
+class CaptionEmbedder(nn.Module):
+ """
+ Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
+ """
+
+ def __init__(self, caption_channels: int, hidden_size: int, caption_max_length: int):
+ super().__init__()
+
+ self.y_proj_xattn = nn.Sequential(nn.Linear(caption_channels, hidden_size), nn.SiLU())
+ self.y_proj_adaln = nn.Linear(caption_channels, int(hidden_size * 0.25))
+ self.null_caption_embedding = nn.Parameter(torch.empty(caption_max_length, caption_channels))
+
+ def caption_drop(self, caption, caption_dropout_mask):
+ """
+ Drops labels to enable classifier-free guidance.
+ caption.shape = (N, 1, cap_len, C)
+ """
+ dropped_caption = torch.where(
+ caption_dropout_mask[:, None, None, None], # (N, 1, 1, 1)
+ self.null_caption_embedding[None, None, :], # (1, 1, cap_len, C)
+ caption, # (N, 1, cap_len, C)
+ )
+ return dropped_caption
+
+ def caption_drop_single_token(self, caption_dropout_mask):
+ dropped_caption = torch.where(
+ caption_dropout_mask[:, None, None], # (N, 1, 1)
+ self.null_caption_embedding[None, -1, :], # (1, 1, C)
+ self.null_caption_embedding[None, -2, :], # (1, 1, C)
+ )
+ return dropped_caption # (N, 1, C)
+
+ def forward(self, caption, train, caption_dropout_mask=None):
+ if train and caption_dropout_mask is not None:
+ caption = self.caption_drop(caption, caption_dropout_mask)
+ caption_xattn = self.y_proj_xattn(caption)
+ if caption_dropout_mask is not None:
+ caption = self.caption_drop_single_token(caption_dropout_mask)
+
+ caption_adaln = self.y_proj_adaln(caption)
+ return caption_xattn, caption_adaln
+
+
+class Magi1TimeTextCaptionEmbedding(nn.Module):
+ """
+ Combined time, text, and image embedding module for the MAGI-1 model.
+
+ This module handles the encoding of three types of conditioning inputs:
+ 1. Timestep embeddings for diffusion process control
+ 2. Text embeddings for text-to-video generation
+ 3. Optional image embeddings for image-to-video generation
+
+ Args:
+ dim (`int`): Hidden dimension of the transformer model.
+ time_freq_dim (`int`): Dimension for sinusoidal time embeddings.
+ text_embed_dim (`int`): Input dimension of text embeddings.
+ image_embed_dim (`int`, optional): Input dimension of image embeddings.
+ pos_embed_seq_len (`int`, optional): Sequence length for image positional embeddings.
+ """
+
+ def __init__(
+ self,
+ dim: int,
+ time_freq_dim: int,
+ text_embed_dim: int,
+ image_embed_dim: Optional[int] = None,
+ pos_embed_seq_len: Optional[int] = None,
+ caption_channels: Optional[int] = None,
+ caption_max_length: Optional[int] = None,
+ ):
+ super().__init__()
+
+ self.timesteps_proj = Timesteps(num_channels=time_freq_dim, flip_sin_to_cos=True, downscale_freq_shift=0)
+ self.time_embedder = TimestepEmbedding(in_channels=time_freq_dim, time_embed_dim=dim)
+ # self.act_fn = nn.SiLU()
+ # self.time_proj = nn.Linear(dim, time_proj_dim)
+ self.caption_embedder = CaptionEmbedder(
+ in_channels=caption_channels, hidden_size=dim, caption_max_length=caption_max_length
+ )
+ self.text_embedder = PixArtAlphaTextProjection(text_embed_dim, dim, act_fn="gelu_tanh")
+
+ self.image_embedder = None
+ if image_embed_dim is not None:
+ self.image_embedder = Magi1ImageEmbedding(image_embed_dim, dim, pos_embed_seq_len=pos_embed_seq_len)
+
+ def forward(
+ self,
+ timestep: torch.Tensor,
+ encoder_hidden_states: torch.Tensor,
+ encoder_hidden_states_image: Optional[torch.Tensor] = None,
+ ):
+ timestep = self.timesteps_proj(timestep)
+
+ time_embedder_dtype = next(iter(self.time_embedder.parameters())).dtype
+ if timestep.dtype != time_embedder_dtype and time_embedder_dtype != torch.int8:
+ timestep = timestep.to(time_embedder_dtype)
+ temb = self.time_embedder(timestep).type_as(encoder_hidden_states)
+ #timestep_proj = self.time_proj(self.act_fn(temb))
+ y_xattn, y_adaln = self.caption_embedder(encoder_hidden_states, self.training, caption_dropout_mask)
+
+ encoder_hidden_states = self.text_embedder(encoder_hidden_states)
+ if encoder_hidden_states_image is not None:
+ encoder_hidden_states_image = self.image_embedder(encoder_hidden_states_image)
+
+ return temb, None, encoder_hidden_states, encoder_hidden_states_image
+
+
+class Magi1RotaryPosEmbed(nn.Module):
+ def __init__(
+ self, attention_head_dim: int, patch_size: Tuple[int, int, int], max_seq_len: int, theta: float = 10000.0
+ ):
+ super().__init__()
+
+ self.attention_head_dim = attention_head_dim
+ self.patch_size = patch_size
+ self.max_seq_len = max_seq_len
+
+ h_dim = w_dim = 2 * (attention_head_dim // 6)
+ t_dim = attention_head_dim - h_dim - w_dim
+
+ freqs = []
+ freqs_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64
+ for dim in [t_dim, h_dim, w_dim]:
+ freq = get_1d_rotary_pos_embed(
+ dim, max_seq_len, theta, use_real=False, repeat_interleave_real=False, freqs_dtype=freqs_dtype
+ )
+ freqs.append(freq)
+ self.freqs = torch.cat(freqs, dim=1)
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ batch_size, num_channels, num_frames, height, width = hidden_states.shape
+ p_t, p_h, p_w = self.patch_size
+ ppf, pph, ppw = num_frames // p_t, height // p_h, width // p_w
+
+ freqs = self.freqs.to(hidden_states.device)
+ freqs = freqs.split_with_sizes(
+ [
+ self.attention_head_dim // 2 - 2 * (self.attention_head_dim // 6),
+ self.attention_head_dim // 6,
+ self.attention_head_dim // 6,
+ ],
+ dim=1,
+ )
+
+ freqs_f = freqs[0][:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)
+ freqs_h = freqs[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1)
+ freqs_w = freqs[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1)
+ freqs = torch.cat([freqs_f, freqs_h, freqs_w], dim=-1).reshape(1, 1, ppf * pph * ppw, -1)
+ return freqs
+
+
+class Magi1TransformerBlock(nn.Module):
+ """
+ A transformer block used in the MAGI-1 model.
+
+ This block follows diffusers' design philosophy with separate self-attention (attn1) and cross-attention (attn2)
+ modules, while faithfully implementing the original MAGI-1 logic through appropriate parameter mapping during
+ conversion.
+
+ Args:
+ dim (`int`): The number of channels in the input and output.
+ ffn_dim (`int`): The number of channels in the feed-forward layer.
+ num_heads (`int`): The number of attention heads.
+ qk_norm (`str`): The type of normalization to apply to query and key projections.
+ cross_attn_norm (`bool`): Whether to apply normalization in cross-attention.
+ eps (`float`): The epsilon value for layer normalization.
+ added_kv_proj_dim (`Optional[int]`): Additional key-value projection dimension for image conditioning.
+ """
+
+ def __init__(
+ self,
+ dim: int,
+ ffn_dim: int,
+ num_heads: int,
+ qk_norm: str = "rms_norm_across_heads",
+ cross_attn_norm: bool = False,
+ eps: float = 1e-6,
+ added_kv_proj_dim: Optional[int] = None,
+ ):
+ super().__init__()
+
+ # 1. Self-attention
+ self.norm1 = FP32LayerNorm(dim, eps, elementwise_affine=False)
+ self.attn1 = Attention(
+ query_dim=dim,
+ heads=num_heads,
+ kv_heads=num_heads,
+ dim_head=dim // num_heads,
+ qk_norm=qk_norm,
+ eps=eps,
+ bias=True,
+ cross_attention_dim=None,
+ out_bias=True,
+ processor=Magi1AttnProcessor2_0(),
+ )
+
+ # 2. Cross-attention
+ self.attn2 = Attention(
+ query_dim=dim,
+ heads=num_heads,
+ kv_heads=num_heads,
+ dim_head=dim // num_heads,
+ qk_norm=qk_norm,
+ eps=eps,
+ bias=True,
+ cross_attention_dim=None,
+ out_bias=True,
+ added_kv_proj_dim=added_kv_proj_dim,
+ added_proj_bias=True,
+ processor=Magi1AttnProcessor2_0(),
+ )
+ self.norm2 = FP32LayerNorm(dim, eps, elementwise_affine=True) if cross_attn_norm else nn.Identity()
+
+ # 3. Feed-forward
+ self.ffn = FeedForward(dim, inner_dim=ffn_dim, activation_fn="gelu-approximate")
+ self.norm3 = FP32LayerNorm(dim, eps, elementwise_affine=False)
+
+ # Scale and shift table for AdaLN - 6 components for gating
+ self.scale_shift_table = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor,
+ temb: torch.Tensor,
+ rotary_emb: torch.Tensor,
+ ) -> torch.Tensor:
+ shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (
+ self.scale_shift_table + temb.float()
+ ).chunk(6, dim=1)
+
+ # 1. Self-attention
+ norm_hidden_states = (self.norm1(hidden_states.float()) * (1 + scale_msa) + shift_msa).type_as(hidden_states)
+ attn_output = self.attn1(hidden_states=norm_hidden_states, rotary_emb=rotary_emb)
+ hidden_states = (hidden_states.float() + attn_output * gate_msa).type_as(hidden_states)
+
+ # 2. Cross-attention
+ norm_hidden_states = self.norm2(hidden_states.float()).type_as(hidden_states)
+ attn_output = self.attn2(hidden_states=norm_hidden_states, encoder_hidden_states=encoder_hidden_states)
+ hidden_states = hidden_states + attn_output
+
+ # 3. Feed-forward
+ norm_hidden_states = (self.norm3(hidden_states.float()) * (1 + c_scale_msa) + c_shift_msa).type_as(
+ hidden_states
+ )
+ ff_output = self.ffn(norm_hidden_states)
+ hidden_states = (hidden_states.float() + ff_output.float() * c_gate_msa).type_as(hidden_states)
+
+ return hidden_states
+
+
+class Magi1Transformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin):
+ r"""
+ A Transformer model for video-like data used in the Magi1 model.
+
+ This model implements a 3D transformer architecture for video generation with support for text conditioning and
+ optional image conditioning. The model uses rotary position embeddings and adaptive layer normalization for
+ temporal and spatial modeling.
+
+ Args:
+ patch_size (`Tuple[int]`, defaults to `(1, 2, 2)`):
+ 3D patch dimensions for video embedding (t_patch, h_patch, w_patch).
+ num_attention_heads (`int`, defaults to `16`):
+ The number of attention heads in each transformer block.
+ attention_head_dim (`int`, defaults to `64`):
+ The dimension of each attention head.
+ in_channels (`int`, defaults to `16`):
+ The number of input channels (from VAE latent space).
+ out_channels (`int`, defaults to `16`):
+ The number of output channels (to VAE latent space).
+ cross_attention_dim (`int`, defaults to `4096`):
+ The dimension of cross-attention (text encoder hidden size).
+ freq_dim (`int`, defaults to `256`):
+ Dimension for sinusoidal time embeddings.
+ ffn_dim (`int`, defaults to `4096`):
+ Intermediate dimension in feed-forward network.
+ num_layers (`int`, defaults to `34`):
+ The number of transformer layers to use.
+ cross_attn_norm (`bool`, defaults to `True`):
+ Enable cross-attention normalization.
+ qk_norm (`Optional[str]`, defaults to `"rms_norm_across_heads"`):
+ Type of query/key normalization to use.
+ eps (`float`, defaults to `1e-6`):
+ Epsilon value for normalization layers.
+ use_linear_projection (`bool`, defaults to `True`):
+ Whether to use linear projection for patch embedding.
+ upcast_attention (`bool`, defaults to `False`):
+ Whether to upcast attention computation to float32.
+ image_embed_dim (`Optional[int]`, defaults to `None`):
+ Dimension of image embeddings for image-to-video tasks.
+ rope_max_seq_len (`int`, defaults to `1024`):
+ Maximum sequence length for rotary position embeddings.
+ pos_embed_seq_len (`Optional[int]`, defaults to `None`):
+ Sequence length for positional embeddings in image conditioning.
+ """
+
+ _supports_gradient_checkpointing = True
+ _skip_layerwise_casting_patterns = ["patch_embedding", "condition_embedder", "rope"]
+ _no_split_modules = ["Magi1TransformerBlock"]
+ _keep_in_fp32_modules = ["condition_embedder", "scale_shift_table", "norm_out"]
+ _keys_to_ignore_on_load_unexpected = ["norm_added_q"]
+ _repeated_blocks = ["Magi1TransformerBlock"]
+
+ @register_to_config
+ def __init__(
+ self,
+ patch_size: Tuple[int] = (1, 2, 2),
+ num_attention_heads: int = 16,
+ attention_head_dim: int = 64,
+ in_channels: int = 16,
+ out_channels: int = 16,
+ cross_attention_dim: int = 4096,
+ freq_dim: int = 256,
+ ffn_dim: int = 12288,
+ num_layers: int = 34,
+ cross_attn_norm: bool = True,
+ qk_norm: Optional[str] = "rms_norm_across_heads",
+ eps: float = 1e-6,
+ use_linear_projection: bool = True,
+ upcast_attention: bool = False,
+ image_embed_dim: Optional[int] = None,
+ rope_max_seq_len: int = 1024,
+ pos_embed_seq_len: Optional[int] = None,
+ caption_channels: Optional[int] = None,
+ caption_max_length: Optional[int] = None,
+ ) -> None:
+ super().__init__()
+
+ inner_dim = num_attention_heads * attention_head_dim
+ self.inner_dim = inner_dim
+ out_channels = out_channels or in_channels
+
+ # Validate configuration
+ if inner_dim != num_attention_heads * attention_head_dim:
+ raise ValueError(
+ f"inner_dim ({inner_dim}) should be equal to num_attention_heads ({num_attention_heads}) * "
+ f"attention_head_dim ({attention_head_dim})"
+ )
+
+ if any(p <= 0 for p in patch_size):
+ raise ValueError(f"All patch_size values must be positive, got {patch_size}")
+
+ if num_layers <= 0:
+ raise ValueError(f"num_layers must be positive, got {num_layers}")
+
+ if freq_dim <= 0:
+ raise ValueError(f"freq_dim must be positive, got {freq_dim}")
+
+ if image_embed_dim is not None and image_embed_dim <= 0:
+ raise ValueError(f"image_embed_dim must be positive when provided, got {image_embed_dim}")
+
+ # 1. Patch & position embedding
+ self.rope = Magi1RotaryPosEmbed(attention_head_dim, patch_size, rope_max_seq_len)
+
+ if use_linear_projection:
+ self.patch_embedding = nn.Linear(in_channels * math.prod(patch_size), inner_dim)
+ else:
+ self.patch_embedding = nn.Conv3d(
+ in_channels, inner_dim, kernel_size=patch_size, stride=patch_size, bias=False
+ )
+
+ # 2. Condition embeddings
+ self.condition_embedder = Magi1TimeTextCaptionEmbedding(
+ dim=inner_dim,
+ time_freq_dim=freq_dim,
+ #time_proj_dim=inner_dim * 6,
+ text_embed_dim=cross_attention_dim,
+ image_embed_dim=image_embed_dim,
+ pos_embed_seq_len=pos_embed_seq_len,
+ caption_channels=caption_channels,
+ caption_max_length=caption_max_length,
+ )
+
+ # 3. Transformer blocks
+ # For image-to-video tasks, we may need additional projections
+ added_kv_proj_dim = image_embed_dim if image_embed_dim is not None else None
+
+ self.blocks = nn.ModuleList(
+ [
+ Magi1TransformerBlock(
+ inner_dim, ffn_dim, num_attention_heads, qk_norm, cross_attn_norm, eps, added_kv_proj_dim
+ )
+ for _ in range(num_layers)
+ ]
+ )
+
+ # 4. Output norm & projection
+ self.norm_out = FP32LayerNorm(inner_dim, eps, elementwise_affine=False)
+ self.proj_out = nn.Linear(inner_dim, out_channels * math.prod(patch_size), bias=False)
+ self.scale_shift_table = nn.Parameter(torch.randn(1, 2, inner_dim) / inner_dim**0.5)
+
+ self.gradient_checkpointing = False
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ timestep: torch.LongTensor,
+ encoder_hidden_states: torch.Tensor,
+ encoder_hidden_states_image: Optional[torch.Tensor] = None,
+ return_dict: bool = True,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+ ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
+ if attention_kwargs is not None:
+ attention_kwargs = attention_kwargs.copy()
+ lora_scale = attention_kwargs.pop("scale", 1.0)
+ else:
+ lora_scale = 1.0
+
+ if USE_PEFT_BACKEND:
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
+ scale_lora_layers(self, lora_scale)
+ else:
+ if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
+ logger.warning(
+ "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
+ )
+
+ batch_size, num_channels, num_frames, height, width = hidden_states.shape
+ p_t, p_h, p_w = self.config.patch_size
+ post_patch_num_frames = num_frames // p_t
+ post_patch_height = height // p_h
+ post_patch_width = width // p_w
+
+ rotary_emb = self.rope(hidden_states)
+
+ # Patch embedding - handle both conv3d and linear projection
+ if self.config.use_linear_projection:
+ # For linear projection, we need to patchify first
+ batch_size, num_channels, num_frames, height, width = hidden_states.shape
+ p_t, p_h, p_w = self.config.patch_size
+
+ # Patchify: (B, C, T, H, W) -> (B, T//p_t, H//p_h, W//p_w, C*p_t*p_h*p_w)
+ hidden_states = hidden_states.unfold(2, p_t, p_t).unfold(3, p_h, p_h).unfold(4, p_w, p_w)
+ hidden_states = hidden_states.contiguous().view(
+ batch_size, num_frames // p_t, height // p_h, width // p_w, num_channels * p_t * p_h * p_w
+ )
+ # Reshape to sequence: (B, T*H*W, C*p_t*p_h*p_w)
+ hidden_states = hidden_states.flatten(1, 3)
+ # Apply linear projection: (B, T*H*W, inner_dim)
+ hidden_states = self.patch_embedding(hidden_states)
+ else:
+ # For conv3d projection
+ hidden_states = self.patch_embedding(hidden_states)
+ hidden_states = hidden_states.flatten(2).transpose(1, 2)
+
+ temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.condition_embedder(
+ timestep, encoder_hidden_states, encoder_hidden_states_image
+ )
+ timestep_proj = timestep_proj.unflatten(1, (6, -1))
+
+ if encoder_hidden_states_image is not None:
+ encoder_hidden_states = torch.concat([encoder_hidden_states_image, encoder_hidden_states], dim=1)
+
+ # 4. Transformer blocks
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ for block in self.blocks:
+ hidden_states = self._gradient_checkpointing_func(
+ block, hidden_states, encoder_hidden_states, timestep_proj, rotary_emb
+ )
+ else:
+ for block in self.blocks:
+ hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb)
+
+ # 5. Output norm, projection & unpatchify
+ shift, scale = (self.scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1)
+
+ # Move the shift and scale tensors to the same device as hidden_states.
+ # When using multi-GPU inference via accelerate these will be on the
+ # first device rather than the last device, which hidden_states ends up on.
+ shift = shift.to(hidden_states.device)
+ scale = scale.to(hidden_states.device)
+
+ hidden_states = (self.norm_out(hidden_states.float()) * (1 + scale) + shift).type_as(hidden_states)
+ hidden_states = self.proj_out(hidden_states)
+
+ # Unpatchify: convert from sequence back to video format
+ hidden_states = hidden_states.reshape(
+ batch_size, post_patch_num_frames, post_patch_height, post_patch_width, -1
+ )
+
+ # Rearrange patches: (B, T//p_t, H//p_h, W//p_w, C*p_t*p_h*p_w) -> (B, C, T, H, W)
+ p_t, p_h, p_w = self.config.patch_size
+ hidden_states = hidden_states.view(
+ batch_size,
+ post_patch_num_frames,
+ post_patch_height,
+ post_patch_width,
+ self.config.out_channels,
+ p_t,
+ p_h,
+ p_w,
+ )
+ hidden_states = hidden_states.permute(0, 4, 1, 5, 2, 6, 3, 7)
+ output = hidden_states.contiguous().view(
+ batch_size,
+ self.config.out_channels,
+ post_patch_num_frames * p_t,
+ post_patch_height * p_h,
+ post_patch_width * p_w,
+ )
+
+ if USE_PEFT_BACKEND:
+ # remove `lora_scale` from each PEFT layer
+ unscale_lora_layers(self, lora_scale)
+
+ if not return_dict:
+ return (output,)
+
+ return Transformer2DModelOutput(sample=output)
diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py
index 1904c029997b..9e804628aa1b 100644
--- a/src/diffusers/pipelines/__init__.py
+++ b/src/diffusers/pipelines/__init__.py
@@ -28,6 +28,7 @@
"deprecated": [],
"latent_diffusion": [],
"ledits_pp": [],
+ "magi1": [],
"marigold": [],
"pag": [],
"stable_diffusion": [],
@@ -291,6 +292,7 @@
"MarigoldNormalsPipeline",
]
)
+ _import_structure["magi1"] = ["Magi1Pipeline", "Magi1ImageToVideoPipeline", "Magi1VideoToVideoPipeline"]
_import_structure["mochi"] = ["MochiPipeline"]
_import_structure["musicldm"] = ["MusicLDMPipeline"]
_import_structure["omnigen"] = ["OmniGenPipeline"]
@@ -666,6 +668,7 @@
from .ltx import LTXConditionPipeline, LTXImageToVideoPipeline, LTXLatentUpsamplePipeline, LTXPipeline
from .lumina import LuminaPipeline, LuminaText2ImgPipeline
from .lumina2 import Lumina2Pipeline, Lumina2Text2ImgPipeline
+ from .magi1 import Magi1ImageToVideoPipeline, Magi1Pipeline, Magi1VideoToVideoPipeline
from .marigold import (
MarigoldDepthPipeline,
MarigoldIntrinsicsPipeline,
diff --git a/src/diffusers/pipelines/magi1/__init__.py b/src/diffusers/pipelines/magi1/__init__.py
new file mode 100644
index 000000000000..5fc6735f6357
--- /dev/null
+++ b/src/diffusers/pipelines/magi1/__init__.py
@@ -0,0 +1,53 @@
+from typing import TYPE_CHECKING
+
+from ...utils import (
+ DIFFUSERS_SLOW_IMPORT,
+ OptionalDependencyNotAvailable,
+ _LazyModule,
+ get_objects_from_module,
+ is_torch_available,
+ is_transformers_available,
+)
+
+
+_dummy_objects = {}
+_import_structure = {}
+
+
+try:
+ if not (is_transformers_available() and is_torch_available()):
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ from ...utils import dummy_torch_and_transformers_objects # noqa F403
+
+ _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
+else:
+ _import_structure["pipeline_magi1"] = ["Magi1Pipeline"]
+ _import_structure["pipeline_magi1_i2v"] = ["Magi1ImageToVideoPipeline"]
+ _import_structure["pipeline_magi1_v2v"] = ["Magi1VideoToVideoPipeline"]
+ _import_structure["pipeline_output"] = ["Magi1PipelineOutput"]
+if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
+ try:
+ if not (is_transformers_available() and is_torch_available()):
+ raise OptionalDependencyNotAvailable()
+
+ except OptionalDependencyNotAvailable:
+ from ...utils.dummy_torch_and_transformers_objects import *
+ else:
+ from .pipeline_magi1 import Magi1Pipeline
+ from .pipeline_magi1_i2v import Magi1ImageToVideoPipeline
+ from .pipeline_magi1_v2v import Magi1VideoToVideoPipeline
+ from .pipeline_output import Magi1PipelineOutput
+
+else:
+ import sys
+
+ sys.modules[__name__] = _LazyModule(
+ __name__,
+ globals()["__file__"],
+ _import_structure,
+ module_spec=__spec__,
+ )
+
+ for name, value in _dummy_objects.items():
+ setattr(sys.modules[__name__], name, value)
diff --git a/src/diffusers/pipelines/magi1/pipeline_magi1.py b/src/diffusers/pipelines/magi1/pipeline_magi1.py
new file mode 100644
index 000000000000..a7ac772a02c1
--- /dev/null
+++ b/src/diffusers/pipelines/magi1/pipeline_magi1.py
@@ -0,0 +1,598 @@
+# Copyright 2025 The SandAI Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import html
+import re
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import ftfy
+import torch
+from transformers import AutoTokenizer, UMT5EncoderModel
+
+from ...callbacks import MultiPipelineCallbacks, PipelineCallback
+
+# from ...loaders import Magi1LoraLoaderMixin
+from ...models import AutoencoderKLMagi1, Magi1Transformer3DModel
+from ...schedulers import FlowMatchEulerDiscreteScheduler
+from ...utils import is_ftfy_available, is_torch_xla_available, logging, replace_example_docstring
+from ...utils.torch_utils import randn_tensor
+from ...video_processor import VideoProcessor
+from ..pipeline_utils import DiffusionPipeline
+from .pipeline_output import Magi1PipelineOutput
+
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+if is_ftfy_available():
+ import ftfy
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```python
+ >>> import torch
+ >>> from diffusers.utils import export_to_video
+ >>> from diffusers import AutoencoderKLMagi1, Magi1Pipeline
+ >>> from diffusers.schedulers.scheduling_unipc_multistep import UniPCMultistepScheduler
+
+ >>> model_id = "SandAI/Magi1-T2V-14B-480P-Diffusers"
+ >>> vae = AutoencoderKLMagi1.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
+ >>> pipe = Magi1Pipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.bfloat16)
+ >>> flow_shift = 5.0 # 5.0 for 720P, 3.0 for 480P
+ >>> pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=flow_shift)
+ >>> pipe.to("cuda")
+
+ >>> prompt = "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window."
+ >>> negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
+
+ >>> output = pipe(
+ ... prompt=prompt,
+ ... negative_prompt=negative_prompt,
+ ... height=720,
+ ... width=1280,
+ ... num_frames=81,
+ ... guidance_scale=5.0,
+ ... ).frames[0]
+ >>> export_to_video(output, "output.mp4", fps=16)
+ ```
+"""
+
+
+def basic_clean(text):
+ text = ftfy.fix_text(text)
+ text = html.unescape(html.unescape(text))
+ return text.strip()
+
+
+def whitespace_clean(text):
+ text = re.sub(r"\s+", " ", text)
+ text = text.strip()
+ return text
+
+
+def prompt_clean(text):
+ text = whitespace_clean(basic_clean(text))
+ return text
+
+
+class Magi1Pipeline(DiffusionPipeline): # , Magi1LoraLoaderMixin):
+ r"""
+ Pipeline for text-to-video generation using Magi1.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
+ implemented for all pipelines (downloading, saving, running on a particular device, etc.).
+
+ Args:
+ tokenizer ([`T5Tokenizer`]):
+ Tokenizer from [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5Tokenizer),
+ specifically the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant.
+ text_encoder ([`T5EncoderModel`]):
+ [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
+ the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant.
+ transformer ([`Magi1Transformer3DModel`]):
+ Conditional Transformer to denoise the input latents.
+ scheduler ([`UniPCMultistepScheduler`]):
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
+ vae ([`AutoencoderKLMagi1`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
+ """
+
+ model_cpu_offload_seq = "text_encoder->transformer->vae"
+ _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
+
+ def __init__(
+ self,
+ tokenizer: AutoTokenizer,
+ text_encoder: UMT5EncoderModel,
+ transformer: Magi1Transformer3DModel,
+ vae: AutoencoderKLMagi1,
+ scheduler: FlowMatchEulerDiscreteScheduler,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ transformer=transformer,
+ scheduler=scheduler,
+ )
+
+ self.vae_scale_factor_temporal = self.vae.temporal_compression_ratio if getattr(self, "vae", None) else 4
+ self.vae_scale_factor_spatial = self.vae.spatial_compression_ratio if getattr(self, "vae", None) else 8
+ self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
+
+ def _get_t5_prompt_embeds(
+ self,
+ prompt: Union[str, List[str]] = None,
+ num_videos_per_prompt: int = 1,
+ max_sequence_length: int = 226,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ device = device or self._execution_device
+ dtype = dtype or self.text_encoder.dtype
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ prompt = [prompt_clean(u) for u in prompt]
+ batch_size = len(prompt)
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=max_sequence_length,
+ truncation=True,
+ add_special_tokens=True,
+ return_attention_mask=True,
+ return_tensors="pt",
+ )
+ text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask
+ seq_lens = mask.gt(0).sum(dim=1).long()
+
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), mask.to(device)).last_hidden_state
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+ prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)]
+ prompt_embeds = torch.stack(
+ [torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0
+ )
+
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ _, seq_len, _ = prompt_embeds.shape
+ prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
+
+ return prompt_embeds
+
+ def encode_prompt(
+ self,
+ prompt: Union[str, List[str]],
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ do_classifier_free_guidance: bool = True,
+ num_videos_per_prompt: int = 1,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ max_sequence_length: int = 226,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
+ Whether to use classifier free guidance or not.
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
+ Number of videos that should be generated per prompt. torch device to place the resulting embeddings on
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ device: (`torch.device`, *optional*):
+ torch device
+ dtype: (`torch.dtype`, *optional*):
+ torch dtype
+ """
+ device = device or self._execution_device
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ if prompt is not None:
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ prompt_embeds = self._get_t5_prompt_embeds(
+ prompt=prompt,
+ num_videos_per_prompt=num_videos_per_prompt,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ dtype=dtype,
+ )
+
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ negative_prompt = negative_prompt or ""
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
+
+ if prompt is not None and type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+
+ negative_prompt_embeds = self._get_t5_prompt_embeds(
+ prompt=negative_prompt,
+ num_videos_per_prompt=num_videos_per_prompt,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ dtype=dtype,
+ )
+
+ return prompt_embeds, negative_prompt_embeds
+
+ def check_inputs(
+ self,
+ prompt,
+ negative_prompt,
+ height,
+ width,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ callback_on_step_end_tensor_inputs=None,
+ ):
+ if height % 16 != 0 or width % 16 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`: {negative_prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+ elif negative_prompt is not None and (
+ not isinstance(negative_prompt, str) and not isinstance(negative_prompt, list)
+ ):
+ raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}")
+
+ def prepare_latents(
+ self,
+ batch_size: int,
+ num_channels_latents: int = 16,
+ height: int = 480,
+ width: int = 832,
+ num_frames: int = 81,
+ dtype: Optional[torch.dtype] = None,
+ device: Optional[torch.device] = None,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ if latents is not None:
+ return latents.to(device=device, dtype=dtype)
+
+ num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
+ shape = (
+ batch_size,
+ num_channels_latents,
+ num_latent_frames,
+ int(height) // self.vae_scale_factor_spatial,
+ int(width) // self.vae_scale_factor_spatial,
+ )
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ return latents
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ def do_classifier_free_guidance(self):
+ return self._guidance_scale > 1.0
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @property
+ def current_timestep(self):
+ return self._current_timestep
+
+ @property
+ def interrupt(self):
+ return self._interrupt
+
+ @property
+ def attention_kwargs(self):
+ return self._attention_kwargs
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ negative_prompt: Union[str, List[str]] = None,
+ height: int = 480,
+ width: int = 832,
+ num_frames: int = 81,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 5.0,
+ num_videos_per_prompt: Optional[int] = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ output_type: Optional[str] = "np",
+ return_dict: bool = True,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+ callback_on_step_end: Optional[
+ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
+ ] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ max_sequence_length: int = 512,
+ ):
+ r"""
+ The call function to the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, pass `prompt_embeds` instead.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to avoid during image generation. If not defined, pass `negative_prompt_embeds`
+ instead. Ignored when not using guidance (`guidance_scale` < `1`).
+ height (`int`, defaults to `480`):
+ The height in pixels of the generated image.
+ width (`int`, defaults to `832`):
+ The width in pixels of the generated image.
+ num_frames (`int`, defaults to `81`):
+ The number of frames in the generated video.
+ num_inference_steps (`int`, defaults to `50`):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, defaults to `5.0`):
+ Guidance scale as defined in [Classifier-Free Diffusion
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
+ `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
+ the text `prompt`, usually at the expense of lower image quality.
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
+ generation deterministic.
+ latents (`torch.Tensor`, *optional*):
+ Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor is generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
+ provided, text embeddings are generated from the `prompt` input argument.
+ output_type (`str`, *optional*, defaults to `"np"`):
+ The output format of the generated image. Choose between `PIL.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`WanPipelineOutput`] instead of a plain tuple.
+ attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
+ A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
+ each denoising step during the inference. with the following arguments: `callback_on_step_end(self:
+ DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
+ list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+ max_sequence_length (`int`, defaults to `512`):
+ The maximum sequence length of the text encoder. If the prompt is longer than this, it will be
+ truncated. If the prompt is shorter, it will be padded to this length.
+
+ Examples:
+
+ Returns:
+ [`~Magi1PipelineOutput`] or `tuple`:
+ If `return_dict` is `True`, [`Magi1PipelineOutput`] is returned, otherwise a `tuple` is returned where
+ the first element is a list with the generated images and the second element is a list of `bool`s
+ indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content.
+ """
+
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ negative_prompt,
+ height,
+ width,
+ prompt_embeds,
+ negative_prompt_embeds,
+ callback_on_step_end_tensor_inputs,
+ )
+
+ if num_frames % self.vae_scale_factor_temporal != 1:
+ logger.warning(
+ f"`num_frames - 1` has to be divisible by {self.vae_scale_factor_temporal}. Rounding to the nearest number."
+ )
+ num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1
+ num_frames = max(num_frames, 1)
+
+ self._guidance_scale = guidance_scale
+ self._attention_kwargs = attention_kwargs
+ self._current_timestep = None
+ self._interrupt = False
+
+ device = self._execution_device
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ # 3. Encode input prompt
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
+ num_videos_per_prompt=num_videos_per_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ )
+
+ transformer_dtype = self.transformer.dtype
+ prompt_embeds = prompt_embeds.to(transformer_dtype)
+ if negative_prompt_embeds is not None:
+ negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype)
+
+ # 4. Prepare timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps = self.scheduler.timesteps
+
+ # 5. Prepare latent variables
+ num_channels_latents = self.transformer.config.in_channels
+ latents = self.prepare_latents(
+ batch_size * num_videos_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ num_frames,
+ torch.float32,
+ device,
+ generator,
+ latents,
+ )
+
+ # 6. Denoising loop
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+ self._num_timesteps = len(timesteps)
+
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ self._current_timestep = t
+ latent_model_input = latents.to(transformer_dtype)
+ timestep = t.expand(latents.shape[0])
+
+ noise_pred = self.transformer(
+ hidden_states=latent_model_input,
+ timestep=timestep,
+ encoder_hidden_states=prompt_embeds,
+ attention_kwargs=attention_kwargs,
+ return_dict=False,
+ )[0]
+
+ if self.do_classifier_free_guidance:
+ noise_uncond = self.transformer(
+ hidden_states=latent_model_input,
+ timestep=timestep,
+ encoder_hidden_states=negative_prompt_embeds,
+ attention_kwargs=attention_kwargs,
+ return_dict=False,
+ )[0]
+ noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
+ self._current_timestep = None
+
+ if not output_type == "latent":
+ latents = latents.to(self.vae.dtype)
+ latents_mean = (
+ torch.tensor(self.vae.config.latents_mean)
+ .view(1, self.vae.config.z_dim, 1, 1, 1)
+ .to(latents.device, latents.dtype)
+ )
+ latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
+ latents.device, latents.dtype
+ )
+ latents = latents / latents_std + latents_mean
+ video = self.vae.decode(latents, return_dict=False)[0]
+ video = self.video_processor.postprocess_video(video, output_type=output_type)
+ else:
+ video = latents
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (video,)
+
+ return Magi1PipelineOutput(frames=video)
diff --git a/src/diffusers/pipelines/magi1/pipeline_magi1_i2v.py b/src/diffusers/pipelines/magi1/pipeline_magi1_i2v.py
new file mode 100644
index 000000000000..667e9467fda4
--- /dev/null
+++ b/src/diffusers/pipelines/magi1/pipeline_magi1_i2v.py
@@ -0,0 +1,743 @@
+# Copyright 2025 The SandAI Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import html
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import PIL
+import regex as re
+import torch
+from transformers import AutoTokenizer, CLIPImageProcessor, CLIPVisionModel, UMT5EncoderModel
+
+from ...callbacks import MultiPipelineCallbacks, PipelineCallback
+from ...image_processor import PipelineImageInput
+from ...loaders import Magi1LoraLoaderMixin
+from ...models import AutoencoderKLMagi1, Magi1Transformer3DModel
+from ...schedulers import FlowMatchEulerDiscreteScheduler
+from ...utils import is_ftfy_available, is_torch_xla_available, logging, replace_example_docstring
+from ...utils.torch_utils import randn_tensor
+from ...video_processor import VideoProcessor
+from ..pipeline_utils import DiffusionPipeline
+from .pipeline_output import Magi1PipelineOutput
+
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+if is_ftfy_available():
+ import ftfy
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```python
+ >>> import torch
+ >>> import numpy as np
+ >>> from diffusers import AutoencoderKLMagi1, Magi1ImageToVideoPipeline
+ >>> from diffusers.utils import export_to_video, load_image
+ >>> from transformers import CLIPVisionModel
+
+ >>> model_id = "SandAI/Magi1-I2V-14B-480P-Diffusers"
+ >>> image_encoder = CLIPVisionModel.from_pretrained(
+ ... model_id, subfolder="image_encoder", torch_dtype=torch.float32
+ ... )
+ >>> vae = AutoencoderKLMagi1.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
+ >>> pipe = Magi1ImageToVideoPipeline.from_pretrained(
+ ... model_id, vae=vae, image_encoder=image_encoder, torch_dtype=torch.bfloat16
+ ... )
+ >>> pipe.to("cuda")
+
+ >>> image = load_image(
+ ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg"
+ ... )
+ >>> max_area = 480 * 832
+ >>> aspect_ratio = image.height / image.width
+ >>> mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
+ >>> height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
+ >>> width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
+ >>> image = image.resize((width, height))
+ >>> prompt = (
+ ... "An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in "
+ ... "the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot."
+ ... )
+ >>> negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
+
+ >>> output = pipe(
+ ... image=image,
+ ... prompt=prompt,
+ ... negative_prompt=negative_prompt,
+ ... height=height,
+ ... width=width,
+ ... num_frames=81,
+ ... guidance_scale=5.0,
+ ... ).frames[0]
+ >>> export_to_video(output, "output.mp4", fps=16)
+ ```
+"""
+
+
+def basic_clean(text):
+ text = ftfy.fix_text(text)
+ text = html.unescape(html.unescape(text))
+ return text.strip()
+
+
+def whitespace_clean(text):
+ text = re.sub(r"\s+", " ", text)
+ text = text.strip()
+ return text
+
+
+def prompt_clean(text):
+ text = whitespace_clean(basic_clean(text))
+ return text
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
+def retrieve_latents(
+ encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
+ return encoder_output.latent_dist.sample(generator)
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+ return encoder_output.latent_dist.mode()
+ elif hasattr(encoder_output, "latents"):
+ return encoder_output.latents
+ else:
+ raise AttributeError("Could not access latents of provided encoder_output")
+
+
+class Magi1ImageToVideoPipeline(DiffusionPipeline, Magi1LoraLoaderMixin):
+ r"""
+ Pipeline for image-to-video generation using Magi1.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
+ implemented for all pipelines (downloading, saving, running on a particular device, etc.).
+
+ Args:
+ tokenizer ([`T5Tokenizer`]):
+ Tokenizer from [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5Tokenizer),
+ specifically the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant.
+ text_encoder ([`T5EncoderModel`]):
+ [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
+ the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant.
+ image_encoder ([`CLIPVisionModel`]):
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPVisionModel), specifically
+ the
+ [clip-vit-huge-patch14](https://github.com/mlfoundations/open_clip/blob/main/docs/PRETRAINED.md#vit-h14-xlm-roberta-large)
+ variant.
+ transformer ([`Magi1Transformer3DModel`]):
+ Conditional Transformer to denoise the input latents.
+ scheduler ([`UniPCMultistepScheduler`]):
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
+ vae ([`AutoencoderKLMagi1`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
+ """
+
+ model_cpu_offload_seq = "text_encoder->image_encoder->transformer->vae"
+ _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
+
+ def __init__(
+ self,
+ tokenizer: AutoTokenizer,
+ text_encoder: UMT5EncoderModel,
+ image_encoder: CLIPVisionModel,
+ image_processor: CLIPImageProcessor,
+ transformer: Magi1Transformer3DModel,
+ vae: AutoencoderKLMagi1,
+ scheduler: FlowMatchEulerDiscreteScheduler,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ image_encoder=image_encoder,
+ transformer=transformer,
+ scheduler=scheduler,
+ image_processor=image_processor,
+ )
+
+ self.vae_scale_factor_temporal = 2 ** sum(self.vae.temperal_downsample) if getattr(self, "vae", None) else 4
+ self.vae_scale_factor_spatial = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8
+ self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
+ self.image_processor = image_processor
+
+ def _get_t5_prompt_embeds(
+ self,
+ prompt: Union[str, List[str]] = None,
+ num_videos_per_prompt: int = 1,
+ max_sequence_length: int = 512,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ device = device or self._execution_device
+ dtype = dtype or self.text_encoder.dtype
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ prompt = [prompt_clean(u) for u in prompt]
+ batch_size = len(prompt)
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=max_sequence_length,
+ truncation=True,
+ add_special_tokens=True,
+ return_attention_mask=True,
+ return_tensors="pt",
+ )
+ text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask
+ seq_lens = mask.gt(0).sum(dim=1).long()
+
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), mask.to(device)).last_hidden_state
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+ prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)]
+ prompt_embeds = torch.stack(
+ [torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0
+ )
+
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ _, seq_len, _ = prompt_embeds.shape
+ prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
+
+ return prompt_embeds
+
+ def encode_image(
+ self,
+ image: PipelineImageInput,
+ device: Optional[torch.device] = None,
+ ):
+ device = device or self._execution_device
+ image = self.image_processor(images=image, return_tensors="pt").to(device)
+ image_embeds = self.image_encoder(**image, output_hidden_states=True)
+ return image_embeds.hidden_states[-2]
+
+ # Copied from diffusers.pipelines.wan.pipeline_wan.WanPipeline.encode_prompt
+ def encode_prompt(
+ self,
+ prompt: Union[str, List[str]],
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ do_classifier_free_guidance: bool = True,
+ num_videos_per_prompt: int = 1,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ max_sequence_length: int = 226,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
+ Whether to use classifier free guidance or not.
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
+ Number of videos that should be generated per prompt. torch device to place the resulting embeddings on
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ device: (`torch.device`, *optional*):
+ torch device
+ dtype: (`torch.dtype`, *optional*):
+ torch dtype
+ """
+ device = device or self._execution_device
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ if prompt is not None:
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ prompt_embeds = self._get_t5_prompt_embeds(
+ prompt=prompt,
+ num_videos_per_prompt=num_videos_per_prompt,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ dtype=dtype,
+ )
+
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ negative_prompt = negative_prompt or ""
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
+
+ if prompt is not None and type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+
+ negative_prompt_embeds = self._get_t5_prompt_embeds(
+ prompt=negative_prompt,
+ num_videos_per_prompt=num_videos_per_prompt,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ dtype=dtype,
+ )
+
+ return prompt_embeds, negative_prompt_embeds
+
+ def check_inputs(
+ self,
+ prompt,
+ negative_prompt,
+ image,
+ height,
+ width,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ image_embeds=None,
+ callback_on_step_end_tensor_inputs=None,
+ ):
+ if image is not None and image_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `image`: {image} and `image_embeds`: {image_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ if image is None and image_embeds is None:
+ raise ValueError(
+ "Provide either `image` or `prompt_embeds`. Cannot leave both `image` and `image_embeds` undefined."
+ )
+ if image is not None and not isinstance(image, torch.Tensor) and not isinstance(image, PIL.Image.Image):
+ raise ValueError(f"`image` has to be of type `torch.Tensor` or `PIL.Image.Image` but is {type(image)}")
+ if height % 16 != 0 or width % 16 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`: {negative_prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+ elif negative_prompt is not None and (
+ not isinstance(negative_prompt, str) and not isinstance(negative_prompt, list)
+ ):
+ raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}")
+
+ def prepare_latents(
+ self,
+ image: PipelineImageInput,
+ batch_size: int,
+ num_channels_latents: int = 16,
+ height: int = 480,
+ width: int = 832,
+ num_frames: int = 81,
+ dtype: Optional[torch.dtype] = None,
+ device: Optional[torch.device] = None,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ last_image: Optional[torch.Tensor] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
+ latent_height = height // self.vae_scale_factor_spatial
+ latent_width = width // self.vae_scale_factor_spatial
+
+ shape = (batch_size, num_channels_latents, num_latent_frames, latent_height, latent_width)
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ if latents is None:
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ latents = latents.to(device=device, dtype=dtype)
+
+ image = image.unsqueeze(2)
+ if last_image is None:
+ video_condition = torch.cat(
+ [image, image.new_zeros(image.shape[0], image.shape[1], num_frames - 1, height, width)], dim=2
+ )
+ else:
+ last_image = last_image.unsqueeze(2)
+ video_condition = torch.cat(
+ [image, image.new_zeros(image.shape[0], image.shape[1], num_frames - 2, height, width), last_image],
+ dim=2,
+ )
+ video_condition = video_condition.to(device=device, dtype=self.vae.dtype)
+
+ latents_mean = (
+ torch.tensor(self.vae.config.latents_mean)
+ .view(1, self.vae.config.z_dim, 1, 1, 1)
+ .to(latents.device, latents.dtype)
+ )
+ latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
+ latents.device, latents.dtype
+ )
+
+ if isinstance(generator, list):
+ latent_condition = [
+ retrieve_latents(self.vae.encode(video_condition), sample_mode="argmax") for _ in generator
+ ]
+ latent_condition = torch.cat(latent_condition)
+ else:
+ latent_condition = retrieve_latents(self.vae.encode(video_condition), sample_mode="argmax")
+ latent_condition = latent_condition.repeat(batch_size, 1, 1, 1, 1)
+
+ latent_condition = latent_condition.to(dtype)
+ latent_condition = (latent_condition - latents_mean) * latents_std
+
+ mask_lat_size = torch.ones(batch_size, 1, num_frames, latent_height, latent_width)
+
+ if last_image is None:
+ mask_lat_size[:, :, list(range(1, num_frames))] = 0
+ else:
+ mask_lat_size[:, :, list(range(1, num_frames - 1))] = 0
+ first_frame_mask = mask_lat_size[:, :, 0:1]
+ first_frame_mask = torch.repeat_interleave(first_frame_mask, dim=2, repeats=self.vae_scale_factor_temporal)
+ mask_lat_size = torch.concat([first_frame_mask, mask_lat_size[:, :, 1:, :]], dim=2)
+ mask_lat_size = mask_lat_size.view(batch_size, -1, self.vae_scale_factor_temporal, latent_height, latent_width)
+ mask_lat_size = mask_lat_size.transpose(1, 2)
+ mask_lat_size = mask_lat_size.to(latent_condition.device)
+
+ return latents, torch.concat([mask_lat_size, latent_condition], dim=1)
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ def do_classifier_free_guidance(self):
+ return self._guidance_scale > 1
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @property
+ def current_timestep(self):
+ return self._current_timestep
+
+ @property
+ def interrupt(self):
+ return self._interrupt
+
+ @property
+ def attention_kwargs(self):
+ return self._attention_kwargs
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ image: PipelineImageInput,
+ prompt: Union[str, List[str]] = None,
+ negative_prompt: Union[str, List[str]] = None,
+ height: int = 480,
+ width: int = 832,
+ num_frames: int = 81,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 5.0,
+ num_videos_per_prompt: Optional[int] = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ image_embeds: Optional[torch.Tensor] = None,
+ last_image: Optional[torch.Tensor] = None,
+ output_type: Optional[str] = "np",
+ return_dict: bool = True,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+ callback_on_step_end: Optional[
+ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
+ ] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ max_sequence_length: int = 512,
+ ):
+ r"""
+ The call function to the pipeline for generation.
+
+ Args:
+ image (`PipelineImageInput`):
+ The input image to condition the generation on. Must be an image, a list of images or a `torch.Tensor`.
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
+ instead.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ height (`int`, defaults to `480`):
+ The height of the generated video.
+ width (`int`, defaults to `832`):
+ The width of the generated video.
+ num_frames (`int`, defaults to `81`):
+ The number of frames in the generated video.
+ num_inference_steps (`int`, defaults to `50`):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, defaults to `5.0`):
+ Guidance scale as defined in [Classifier-Free Diffusion
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
+ `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
+ the text `prompt`, usually at the expense of lower image quality.
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
+ generation deterministic.
+ latents (`torch.Tensor`, *optional*):
+ Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor is generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
+ provided, text embeddings are generated from the `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
+ provided, text embeddings are generated from the `negative_prompt` input argument.
+ image_embeds (`torch.Tensor`, *optional*):
+ Pre-generated image embeddings. Can be used to easily tweak image inputs (weighting). If not provided,
+ image embeddings are generated from the `image` input argument.
+ output_type (`str`, *optional*, defaults to `"np"`):
+ The output format of the generated image. Choose between `PIL.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`WanPipelineOutput`] instead of a plain tuple.
+ attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
+ A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
+ each denoising step during the inference. with the following arguments: `callback_on_step_end(self:
+ DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
+ list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+ max_sequence_length (`int`, defaults to `512`):
+ The maximum sequence length of the text encoder. If the prompt is longer than this, it will be
+ truncated. If the prompt is shorter, it will be padded to this length.
+
+ Examples:
+
+ Returns:
+ [`~Magi1PipelineOutput`] or `tuple`:
+ If `return_dict` is `True`, [`Magi1PipelineOutput`] is returned, otherwise a `tuple` is returned where
+ the first element is a list with the generated images and the second element is a list of `bool`s
+ indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content.
+ """
+
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ negative_prompt,
+ image,
+ height,
+ width,
+ prompt_embeds,
+ negative_prompt_embeds,
+ image_embeds,
+ callback_on_step_end_tensor_inputs,
+ )
+
+ if num_frames % self.vae_scale_factor_temporal != 1:
+ logger.warning(
+ f"`num_frames - 1` has to be divisible by {self.vae_scale_factor_temporal}. Rounding to the nearest number."
+ )
+ num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1
+ num_frames = max(num_frames, 1)
+
+ self._guidance_scale = guidance_scale
+ self._attention_kwargs = attention_kwargs
+ self._current_timestep = None
+ self._interrupt = False
+
+ device = self._execution_device
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ # 3. Encode input prompt
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
+ num_videos_per_prompt=num_videos_per_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ )
+
+ # Encode image embedding
+ transformer_dtype = self.transformer.dtype
+ prompt_embeds = prompt_embeds.to(transformer_dtype)
+ if negative_prompt_embeds is not None:
+ negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype)
+
+ if image_embeds is None:
+ if last_image is None:
+ image_embeds = self.encode_image(image, device)
+ else:
+ image_embeds = self.encode_image([image, last_image], device)
+ image_embeds = image_embeds.repeat(batch_size, 1, 1)
+ image_embeds = image_embeds.to(transformer_dtype)
+
+ # 4. Prepare timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps = self.scheduler.timesteps
+
+ # 5. Prepare latent variables
+ num_channels_latents = self.vae.config.z_dim
+ image = self.video_processor.preprocess(image, height=height, width=width).to(device, dtype=torch.float32)
+ if last_image is not None:
+ last_image = self.video_processor.preprocess(last_image, height=height, width=width).to(
+ device, dtype=torch.float32
+ )
+ latents, condition = self.prepare_latents(
+ image,
+ batch_size * num_videos_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ num_frames,
+ torch.float32,
+ device,
+ generator,
+ latents,
+ last_image,
+ )
+
+ # 6. Denoising loop
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+ self._num_timesteps = len(timesteps)
+
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ self._current_timestep = t
+ latent_model_input = torch.cat([latents, condition], dim=1).to(transformer_dtype)
+ timestep = t.expand(latents.shape[0])
+
+ noise_pred = self.transformer(
+ hidden_states=latent_model_input,
+ timestep=timestep,
+ encoder_hidden_states=prompt_embeds,
+ encoder_hidden_states_image=image_embeds,
+ attention_kwargs=attention_kwargs,
+ return_dict=False,
+ )[0]
+
+ if self.do_classifier_free_guidance:
+ noise_uncond = self.transformer(
+ hidden_states=latent_model_input,
+ timestep=timestep,
+ encoder_hidden_states=negative_prompt_embeds,
+ encoder_hidden_states_image=image_embeds,
+ attention_kwargs=attention_kwargs,
+ return_dict=False,
+ )[0]
+ noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
+ self._current_timestep = None
+
+ if not output_type == "latent":
+ latents = latents.to(self.vae.dtype)
+ latents_mean = (
+ torch.tensor(self.vae.config.latents_mean)
+ .view(1, self.vae.config.z_dim, 1, 1, 1)
+ .to(latents.device, latents.dtype)
+ )
+ latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
+ latents.device, latents.dtype
+ )
+ latents = latents / latents_std + latents_mean
+ video = self.vae.decode(latents, return_dict=False)[0]
+ video = self.video_processor.postprocess_video(video, output_type=output_type)
+ else:
+ video = latents
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (video,)
+
+ return Magi1PipelineOutput(frames=video)
diff --git a/src/diffusers/pipelines/magi1/pipeline_magi1_v2v.py b/src/diffusers/pipelines/magi1/pipeline_magi1_v2v.py
new file mode 100644
index 000000000000..65efb1757a83
--- /dev/null
+++ b/src/diffusers/pipelines/magi1/pipeline_magi1_v2v.py
@@ -0,0 +1,725 @@
+# Copyright 2025 The SandAI Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import html
+import inspect
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import regex as re
+import torch
+from PIL import Image
+from transformers import AutoTokenizer, UMT5EncoderModel
+
+from ...callbacks import MultiPipelineCallbacks, PipelineCallback
+from ...loaders import Magi1LoraLoaderMixin
+from ...models import AutoencoderKLMagi1, Magi1Transformer3DModel
+from ...schedulers import FlowMatchEulerDiscreteScheduler
+from ...utils import is_ftfy_available, is_torch_xla_available, logging, replace_example_docstring
+from ...utils.torch_utils import randn_tensor
+from ...video_processor import VideoProcessor
+from ..pipeline_utils import DiffusionPipeline
+from .pipeline_output import Magi1PipelineOutput
+
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+if is_ftfy_available():
+ import ftfy
+
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```python
+ >>> import torch
+ >>> from diffusers.utils import export_to_video
+ >>> from diffusers import AutoencoderKLMagi1, Magi1VideoToVideoPipeline
+ >>> from diffusers.schedulers.scheduling_unipc_multistep import UniPCMultistepScheduler
+
+ >>> # Available models: SandAI/Magi1-T2V-14B-480P-Diffusers, SandAI/Magi1-T2V-1.3B-480P-Diffusers
+ >>> model_id = "SandAI/Magi1-T2V-1.3B-480P-Diffusers"
+ >>> vae = AutoencoderKLMagi1.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
+ >>> pipe = Magi1VideoToVideoPipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.bfloat16)
+ >>> flow_shift = 3.0 # 5.0 for 720P, 3.0 for 480P
+ >>> pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=flow_shift)
+ >>> pipe.to("cuda")
+
+ >>> prompt = "A robot standing on a mountain top. The sun is setting in the background"
+ >>> negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
+ >>> video = load_video(
+ ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/hiker.mp4"
+ ... )
+ >>> output = pipe(
+ ... video=video,
+ ... prompt=prompt,
+ ... negative_prompt=negative_prompt,
+ ... height=480,
+ ... width=720,
+ ... guidance_scale=5.0,
+ ... strength=0.7,
+ ... ).frames[0]
+ >>> export_to_video(output, "output.mp4", fps=16)
+ ```
+"""
+
+
+def basic_clean(text):
+ text = ftfy.fix_text(text)
+ text = html.unescape(html.unescape(text))
+ return text.strip()
+
+
+def whitespace_clean(text):
+ text = re.sub(r"\s+", " ", text)
+ text = text.strip()
+ return text
+
+
+def prompt_clean(text):
+ text = whitespace_clean(basic_clean(text))
+ return text
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+ scheduler,
+ num_inference_steps: Optional[int] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ timesteps: Optional[List[int]] = None,
+ sigmas: Optional[List[float]] = None,
+ **kwargs,
+):
+ r"""
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+ Args:
+ scheduler (`SchedulerMixin`):
+ The scheduler to get timesteps from.
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+ must be `None`.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+ `num_inference_steps` and `sigmas` must be `None`.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+ `num_inference_steps` and `timesteps` must be `None`.
+
+ Returns:
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+ second element is the number of inference steps.
+ """
+ if timesteps is not None and sigmas is not None:
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
+ if timesteps is not None:
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accepts_timesteps:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" timestep schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ elif sigmas is not None:
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accept_sigmas:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ return timesteps, num_inference_steps
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
+def retrieve_latents(
+ encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
+ return encoder_output.latent_dist.sample(generator)
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+ return encoder_output.latent_dist.mode()
+ elif hasattr(encoder_output, "latents"):
+ return encoder_output.latents
+ else:
+ raise AttributeError("Could not access latents of provided encoder_output")
+
+
+class Magi1VideoToVideoPipeline(DiffusionPipeline, Magi1LoraLoaderMixin):
+ r"""
+ Pipeline for video-to-video generation using Magi1.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
+ implemented for all pipelines (downloading, saving, running on a particular device, etc.).
+
+ Args:
+ tokenizer ([`T5Tokenizer`]):
+ Tokenizer from [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5Tokenizer),
+ specifically the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant.
+ text_encoder ([`T5EncoderModel`]):
+ [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
+ the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant.
+ transformer ([`Magi1Transformer3DModel`]):
+ Conditional Transformer to denoise the input latents.
+ scheduler ([`UniPCMultistepScheduler`]):
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
+ vae ([`AutoencoderKLMagi1`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
+ """
+
+ model_cpu_offload_seq = "text_encoder->transformer->vae"
+ _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
+
+ def __init__(
+ self,
+ tokenizer: AutoTokenizer,
+ text_encoder: UMT5EncoderModel,
+ transformer: Magi1Transformer3DModel,
+ vae: AutoencoderKLMagi1,
+ scheduler: FlowMatchEulerDiscreteScheduler,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ transformer=transformer,
+ scheduler=scheduler,
+ )
+
+ self.vae_scale_factor_temporal = 2 ** sum(self.vae.temperal_downsample) if getattr(self, "vae", None) else 4
+ self.vae_scale_factor_spatial = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8
+ self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
+
+ # Copied from diffusers.pipelines.wan.pipeline_wan.WanPipeline._get_t5_prompt_embeds
+ def _get_t5_prompt_embeds(
+ self,
+ prompt: Union[str, List[str]] = None,
+ num_videos_per_prompt: int = 1,
+ max_sequence_length: int = 226,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ device = device or self._execution_device
+ dtype = dtype or self.text_encoder.dtype
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ prompt = [prompt_clean(u) for u in prompt]
+ batch_size = len(prompt)
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=max_sequence_length,
+ truncation=True,
+ add_special_tokens=True,
+ return_attention_mask=True,
+ return_tensors="pt",
+ )
+ text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask
+ seq_lens = mask.gt(0).sum(dim=1).long()
+
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), mask.to(device)).last_hidden_state
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+ prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)]
+ prompt_embeds = torch.stack(
+ [torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0
+ )
+
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ _, seq_len, _ = prompt_embeds.shape
+ prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
+
+ return prompt_embeds
+
+ # Copied from diffusers.pipelines.wan.pipeline_wan.WanPipeline.encode_prompt
+ def encode_prompt(
+ self,
+ prompt: Union[str, List[str]],
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ do_classifier_free_guidance: bool = True,
+ num_videos_per_prompt: int = 1,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ max_sequence_length: int = 226,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
+ Whether to use classifier free guidance or not.
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
+ Number of videos that should be generated per prompt. torch device to place the resulting embeddings on
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ device: (`torch.device`, *optional*):
+ torch device
+ dtype: (`torch.dtype`, *optional*):
+ torch dtype
+ """
+ device = device or self._execution_device
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ if prompt is not None:
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ prompt_embeds = self._get_t5_prompt_embeds(
+ prompt=prompt,
+ num_videos_per_prompt=num_videos_per_prompt,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ dtype=dtype,
+ )
+
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ negative_prompt = negative_prompt or ""
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
+
+ if prompt is not None and type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+
+ negative_prompt_embeds = self._get_t5_prompt_embeds(
+ prompt=negative_prompt,
+ num_videos_per_prompt=num_videos_per_prompt,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ dtype=dtype,
+ )
+
+ return prompt_embeds, negative_prompt_embeds
+
+ def check_inputs(
+ self,
+ prompt,
+ negative_prompt,
+ height,
+ width,
+ video=None,
+ latents=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ callback_on_step_end_tensor_inputs=None,
+ ):
+ if height % 16 != 0 or width % 16 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`: {negative_prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+ elif negative_prompt is not None and (
+ not isinstance(negative_prompt, str) and not isinstance(negative_prompt, list)
+ ):
+ raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}")
+
+ if video is not None and latents is not None:
+ raise ValueError("Only one of `video` or `latents` should be provided")
+
+ def prepare_latents(
+ self,
+ video: Optional[torch.Tensor] = None,
+ batch_size: int = 1,
+ num_channels_latents: int = 16,
+ height: int = 480,
+ width: int = 832,
+ dtype: Optional[torch.dtype] = None,
+ device: Optional[torch.device] = None,
+ generator: Optional[torch.Generator] = None,
+ latents: Optional[torch.Tensor] = None,
+ timestep: Optional[torch.Tensor] = None,
+ ):
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ num_latent_frames = (
+ (video.size(2) - 1) // self.vae_scale_factor_temporal + 1 if latents is None else latents.size(1)
+ )
+ shape = (
+ batch_size,
+ num_channels_latents,
+ num_latent_frames,
+ height // self.vae_scale_factor_spatial,
+ width // self.vae_scale_factor_spatial,
+ )
+
+ if latents is None:
+ init_latents = [retrieve_latents(self.vae.encode(vid.unsqueeze(0)), sample_mode="argmax") for vid in video]
+
+ init_latents = torch.cat(init_latents, dim=0).to(dtype)
+
+ latents_mean = (
+ torch.tensor(self.vae.config.latents_mean).view(1, self.vae.config.z_dim, 1, 1, 1).to(device, dtype)
+ )
+ latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
+ device, dtype
+ )
+
+ init_latents = (init_latents - latents_mean) * latents_std
+
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ if hasattr(self.scheduler, "add_noise"):
+ latents = self.scheduler.add_noise(init_latents, noise, timestep)
+ else:
+ latents = self.scheduler.scale_noise(init_latents, timestep, noise)
+ else:
+ latents = latents.to(device)
+
+ return latents
+
+ # Copied from diffusers.pipelines.animatediff.pipeline_animatediff_video2video.AnimateDiffVideoToVideoPipeline.get_timesteps
+ def get_timesteps(self, num_inference_steps, timesteps, strength, device):
+ # get the original timestep using init_timestep
+ init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
+
+ t_start = max(num_inference_steps - init_timestep, 0)
+ timesteps = timesteps[t_start * self.scheduler.order :]
+
+ return timesteps, num_inference_steps - t_start
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ def do_classifier_free_guidance(self):
+ return self._guidance_scale > 1.0
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @property
+ def current_timestep(self):
+ return self._current_timestep
+
+ @property
+ def interrupt(self):
+ return self._interrupt
+
+ @property
+ def attention_kwargs(self):
+ return self._attention_kwargs
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ video: List[Image.Image] = None,
+ prompt: Union[str, List[str]] = None,
+ negative_prompt: Union[str, List[str]] = None,
+ height: int = 480,
+ width: int = 832,
+ num_inference_steps: int = 50,
+ timesteps: Optional[List[int]] = None,
+ guidance_scale: float = 5.0,
+ strength: float = 0.8,
+ num_videos_per_prompt: Optional[int] = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ output_type: Optional[str] = "np",
+ return_dict: bool = True,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+ callback_on_step_end: Optional[
+ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
+ ] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ max_sequence_length: int = 512,
+ ):
+ r"""
+ The call function to the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
+ instead.
+ height (`int`, defaults to `480`):
+ The height in pixels of the generated image.
+ width (`int`, defaults to `832`):
+ The width in pixels of the generated image.
+ num_frames (`int`, defaults to `81`):
+ The number of frames in the generated video.
+ num_inference_steps (`int`, defaults to `50`):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, defaults to `5.0`):
+ Guidance scale as defined in [Classifier-Free Diffusion
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
+ `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
+ the text `prompt`, usually at the expense of lower image quality.
+ strength (`float`, defaults to `0.8`):
+ Higher strength leads to more differences between original image and generated video.
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
+ generation deterministic.
+ latents (`torch.Tensor`, *optional*):
+ Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor is generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
+ provided, text embeddings are generated from the `prompt` input argument.
+ output_type (`str`, *optional*, defaults to `"np"`):
+ The output format of the generated image. Choose between `PIL.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`WanPipelineOutput`] instead of a plain tuple.
+ attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
+ A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
+ each denoising step during the inference. with the following arguments: `callback_on_step_end(self:
+ DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
+ list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+ max_sequence_length (`int`, defaults to `512`):
+ The maximum sequence length of the text encoder. If the prompt is longer than this, it will be
+ truncated. If the prompt is shorter, it will be padded to this length.
+
+ Examples:
+
+ Returns:
+ [`~Magi1PipelineOutput`] or `tuple`:
+ If `return_dict` is `True`, [`Magi1PipelineOutput`] is returned, otherwise a `tuple` is returned where
+ the first element is a list with the generated images and the second element is a list of `bool`s
+ indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content.
+ """
+
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
+
+ height = height or self.transformer.config.sample_height * self.vae_scale_factor_spatial
+ width = width or self.transformer.config.sample_width * self.vae_scale_factor_spatial
+ num_videos_per_prompt = 1
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ negative_prompt,
+ height,
+ width,
+ video,
+ latents,
+ prompt_embeds,
+ negative_prompt_embeds,
+ callback_on_step_end_tensor_inputs,
+ )
+
+ self._guidance_scale = guidance_scale
+ self._attention_kwargs = attention_kwargs
+ self._current_timestep = None
+ self._interrupt = False
+
+ device = self._execution_device
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ # 3. Encode input prompt
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
+ num_videos_per_prompt=num_videos_per_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ )
+
+ transformer_dtype = self.transformer.dtype
+ prompt_embeds = prompt_embeds.to(transformer_dtype)
+ if negative_prompt_embeds is not None:
+ negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype)
+
+ # 4. Prepare timesteps
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
+ timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, timesteps, strength, device)
+ latent_timestep = timesteps[:1].repeat(batch_size * num_videos_per_prompt)
+ self._num_timesteps = len(timesteps)
+
+ if latents is None:
+ video = self.video_processor.preprocess_video(video, height=height, width=width).to(
+ device, dtype=torch.float32
+ )
+
+ # 5. Prepare latent variables
+ num_channels_latents = self.transformer.config.in_channels
+ latents = self.prepare_latents(
+ video,
+ batch_size * num_videos_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ torch.float32,
+ device,
+ generator,
+ latents,
+ latent_timestep,
+ )
+
+ # 6. Denoising loop
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+ self._num_timesteps = len(timesteps)
+
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ self._current_timestep = t
+ latent_model_input = latents.to(transformer_dtype)
+ timestep = t.expand(latents.shape[0])
+
+ noise_pred = self.transformer(
+ hidden_states=latent_model_input,
+ timestep=timestep,
+ encoder_hidden_states=prompt_embeds,
+ attention_kwargs=attention_kwargs,
+ return_dict=False,
+ )[0]
+
+ if self.do_classifier_free_guidance:
+ noise_uncond = self.transformer(
+ hidden_states=latent_model_input,
+ timestep=timestep,
+ encoder_hidden_states=negative_prompt_embeds,
+ attention_kwargs=attention_kwargs,
+ return_dict=False,
+ )[0]
+ noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
+ self._current_timestep = None
+
+ if not output_type == "latent":
+ latents = latents.to(self.vae.dtype)
+ latents_mean = (
+ torch.tensor(self.vae.config.latents_mean)
+ .view(1, self.vae.config.z_dim, 1, 1, 1)
+ .to(latents.device, latents.dtype)
+ )
+ latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
+ latents.device, latents.dtype
+ )
+ latents = latents / latents_std + latents_mean
+ video = self.vae.decode(latents, return_dict=False)[0]
+ video = self.video_processor.postprocess_video(video, output_type=output_type)
+ else:
+ video = latents
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (video,)
+
+ return Magi1PipelineOutput(frames=video)
diff --git a/src/diffusers/pipelines/magi1/pipeline_output.py b/src/diffusers/pipelines/magi1/pipeline_output.py
new file mode 100644
index 000000000000..200156cffac9
--- /dev/null
+++ b/src/diffusers/pipelines/magi1/pipeline_output.py
@@ -0,0 +1,36 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from dataclasses import dataclass
+from typing import List, Union
+
+import numpy as np
+import torch
+
+from ...utils import BaseOutput
+
+
+@dataclass
+class Magi1PipelineOutput(BaseOutput):
+ """
+ Output class for MAGI-1 pipeline.
+
+ Args:
+ frames (`torch.Tensor` or `np.ndarray`):
+ List of denoised frames from the diffusion process, as a NumPy array of shape `(batch_size, num_frames,
+ height, width, num_channels)` or a PyTorch tensor of shape `(batch_size, num_channels, num_frames, height,
+ width)`.
+ """
+
+ frames: Union[torch.Tensor, np.ndarray, List[List[np.ndarray]]]
diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py
index 2981f3a420d6..46ee0e97a5b6 100644
--- a/src/diffusers/utils/dummy_pt_objects.py
+++ b/src/diffusers/utils/dummy_pt_objects.py
@@ -205,6 +205,21 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
+class AutoencoderKLMagi1(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
class AutoencoderKLMagvit(metaclass=DummyObject):
_backends = ["torch"]
@@ -730,6 +745,21 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
+class Magi1Transformer3DModel(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
class MochiTransformer3DModel(metaclass=DummyObject):
_backends = ["torch"]
diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py
index 9cb869c67a3e..ba3944c1f44b 100644
--- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py
+++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py
@@ -1487,6 +1487,51 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
+class Magi1ImageToVideoPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class Magi1Pipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class Magi1VideoToVideoPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
class MarigoldDepthPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
diff --git a/tests/models/autoencoders/test_models_autoencoder_kl_magi.py b/tests/models/autoencoders/test_models_autoencoder_kl_magi.py
new file mode 100644
index 000000000000..5c11799e878c
--- /dev/null
+++ b/tests/models/autoencoders/test_models_autoencoder_kl_magi.py
@@ -0,0 +1,155 @@
+# coding=utf-8
+# Copyright 2024 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import torch
+
+from diffusers import AutoencoderKLMagi1
+from diffusers.utils.testing_utils import enable_full_determinism, floats_tensor, torch_device
+
+from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
+
+
+enable_full_determinism()
+
+
+class AutoencoderKLMagiTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
+ model_class = AutoencoderKLMagi1
+ main_input_name = "sample"
+ base_precision = 1e-2
+
+ def get_autoencoder_kl_magi1_config(self):
+ return {
+ "base_dim": 3,
+ "z_dim": 16,
+ "dim_mult": [1, 1, 1, 1],
+ "num_res_blocks": 1,
+ "temperal_downsample": [False, True, True],
+ }
+
+ @property
+ def dummy_input(self):
+ batch_size = 2
+ num_frames = 9
+ num_channels = 3
+ sizes = (16, 16)
+ image = floats_tensor((batch_size, num_channels, num_frames) + sizes).to(torch_device)
+ return {"sample": image}
+
+ @property
+ def dummy_input_tiling(self):
+ batch_size = 2
+ num_frames = 9
+ num_channels = 3
+ sizes = (128, 128)
+ image = floats_tensor((batch_size, num_channels, num_frames) + sizes).to(torch_device)
+ return {"sample": image}
+
+ @property
+ def input_shape(self):
+ return (3, 9, 16, 16)
+
+ @property
+ def output_shape(self):
+ return (3, 9, 16, 16)
+
+ def prepare_init_args_and_inputs_for_common(self):
+ init_dict = self.get_autoencoder_kl_magi1_config()
+ inputs_dict = self.dummy_input
+ return init_dict, inputs_dict
+
+ def prepare_init_args_and_inputs_for_tiling(self):
+ init_dict = self.get_autoencoder_kl_magi1_config()
+ inputs_dict = self.dummy_input_tiling
+ return init_dict, inputs_dict
+
+ def test_enable_disable_tiling(self):
+ init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_tiling()
+
+ torch.manual_seed(0)
+ model = self.model_class(**init_dict).to(torch_device)
+
+ inputs_dict.update({"return_dict": False})
+
+ torch.manual_seed(0)
+ output_without_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0]
+
+ torch.manual_seed(0)
+ model.enable_tiling(96, 96, 64, 64)
+ output_with_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0]
+
+ self.assertLess(
+ (output_without_tiling.detach().cpu().numpy() - output_with_tiling.detach().cpu().numpy()).max(),
+ 0.5,
+ "VAE tiling should not affect the inference results",
+ )
+
+ torch.manual_seed(0)
+ model.disable_tiling()
+ output_without_tiling_2 = model(**inputs_dict, generator=torch.manual_seed(0))[0]
+
+ self.assertEqual(
+ output_without_tiling.detach().cpu().numpy().all(),
+ output_without_tiling_2.detach().cpu().numpy().all(),
+ "Without tiling outputs should match with the outputs when tiling is manually disabled.",
+ )
+
+ def test_enable_disable_slicing(self):
+ init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+
+ torch.manual_seed(0)
+ model = self.model_class(**init_dict).to(torch_device)
+
+ inputs_dict.update({"return_dict": False})
+
+ torch.manual_seed(0)
+ output_without_slicing = model(**inputs_dict, generator=torch.manual_seed(0))[0]
+
+ torch.manual_seed(0)
+ model.enable_slicing()
+ output_with_slicing = model(**inputs_dict, generator=torch.manual_seed(0))[0]
+
+ self.assertLess(
+ (output_without_slicing.detach().cpu().numpy() - output_with_slicing.detach().cpu().numpy()).max(),
+ 0.05,
+ "VAE slicing should not affect the inference results",
+ )
+
+ torch.manual_seed(0)
+ model.disable_slicing()
+ output_without_slicing_2 = model(**inputs_dict, generator=torch.manual_seed(0))[0]
+
+ self.assertEqual(
+ output_without_slicing.detach().cpu().numpy().all(),
+ output_without_slicing_2.detach().cpu().numpy().all(),
+ "Without slicing outputs should match with the outputs when slicing is manually disabled.",
+ )
+
+ @unittest.skip("Gradient checkpointing has not been implemented yet")
+ def test_gradient_checkpointing_is_applied(self):
+ pass
+
+ @unittest.skip("Test not supported")
+ def test_forward_with_norm_groups(self):
+ pass
+
+ @unittest.skip("RuntimeError: fill_out not implemented for 'Float8_e4m3fn'")
+ def test_layerwise_casting_inference(self):
+ pass
+
+ @unittest.skip("RuntimeError: fill_out not implemented for 'Float8_e4m3fn'")
+ def test_layerwise_casting_training(self):
+ pass
diff --git a/tests/models/transformers/test_models_transformer_magi1.py b/tests/models/transformers/test_models_transformer_magi1.py
new file mode 100644
index 000000000000..ed8d775e6058
--- /dev/null
+++ b/tests/models/transformers/test_models_transformer_magi1.py
@@ -0,0 +1,91 @@
+# Copyright 2024 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import torch
+
+from diffusers import Magi1Transformer3DModel
+from diffusers.utils.testing_utils import (
+ enable_full_determinism,
+ torch_device,
+)
+
+from ..test_modeling_common import ModelTesterMixin, TorchCompileTesterMixin
+
+
+enable_full_determinism()
+
+
+class Magi1Transformer3DTests(ModelTesterMixin, unittest.TestCase):
+ model_class = Magi1Transformer3DModel
+ main_input_name = "hidden_states"
+ uses_custom_attn_processor = True
+
+ @property
+ def dummy_input(self):
+ batch_size = 1
+ num_channels = 4
+ num_frames = 2
+ height = 16
+ width = 16
+ text_encoder_embedding_dim = 16
+ sequence_length = 12
+
+ hidden_states = torch.randn((batch_size, num_channels, num_frames, height, width)).to(torch_device)
+ timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
+ encoder_hidden_states = torch.randn((batch_size, sequence_length, text_encoder_embedding_dim)).to(torch_device)
+
+ return {
+ "hidden_states": hidden_states,
+ "encoder_hidden_states": encoder_hidden_states,
+ "timestep": timestep,
+ }
+
+ @property
+ def input_shape(self):
+ return (4, 1, 16, 16)
+
+ @property
+ def output_shape(self):
+ return (4, 1, 16, 16)
+
+ def prepare_init_args_and_inputs_for_common(self):
+ init_dict = {
+ "patch_size": (1, 2, 2),
+ "num_attention_heads": 2,
+ "attention_head_dim": 12,
+ "in_channels": 4,
+ "out_channels": 4,
+ "cross_attention_dim": 16,
+ "freq_dim": 256,
+ "ffn_dim": 32,
+ "num_layers": 2,
+ "cross_attn_norm": True,
+ "qk_norm": "rms_norm_across_heads",
+ "rope_max_seq_len": 32,
+ }
+ inputs_dict = self.dummy_input
+ return init_dict, inputs_dict
+
+ def test_gradient_checkpointing_is_applied(self):
+ expected_set = {"Magi1Transformer3DModel"}
+ super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
+
+
+class Magi1TransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
+ model_class = Magi1Transformer3DModel
+
+ def prepare_init_args_and_inputs_for_common(self):
+ return Magi1Transformer3DTests().prepare_init_args_and_inputs_for_common()
diff --git a/tests/pipelines/magi1/test_magi1.py b/tests/pipelines/magi1/test_magi1.py
new file mode 100644
index 000000000000..3695bcbe5be7
--- /dev/null
+++ b/tests/pipelines/magi1/test_magi1.py
@@ -0,0 +1,158 @@
+# Copyright 2024 The HuggingFace Team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import gc
+import unittest
+
+import numpy as np
+import torch
+from transformers import AutoTokenizer, T5EncoderModel
+
+from diffusers import AutoencoderKLMagi1, FlowMatchEulerDiscreteScheduler, Magi1Pipeline, Magi1Transformer3DModel
+from diffusers.utils.testing_utils import (
+ backend_empty_cache,
+ enable_full_determinism,
+ require_torch_accelerator,
+ slow,
+ torch_device,
+)
+
+from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
+from ..test_pipelines_common import (
+ PipelineTesterMixin,
+)
+
+
+enable_full_determinism()
+
+
+class Magi1PipelineFastTests(PipelineTesterMixin, unittest.TestCase):
+ pipeline_class = Magi1Pipeline
+ params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
+ batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
+ image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
+ image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
+ required_optional_params = frozenset(
+ [
+ "num_inference_steps",
+ "generator",
+ "latents",
+ "return_dict",
+ "callback_on_step_end",
+ "callback_on_step_end_tensor_inputs",
+ ]
+ )
+ test_xformers_attention = False
+ supports_dduf = False
+
+ def get_dummy_components(self):
+ torch.manual_seed(0)
+ vae = AutoencoderKLMagi1(
+ base_dim=3,
+ z_dim=16,
+ dim_mult=[1, 1, 1, 1],
+ num_res_blocks=1,
+ temperal_downsample=[False, True, True],
+ )
+
+ torch.manual_seed(0)
+ # TODO: impl FlowDPMSolverMultistepScheduler
+ scheduler = FlowMatchEulerDiscreteScheduler(shift=7.0)
+ text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
+ tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
+
+ torch.manual_seed(0)
+ transformer = Magi1Transformer3DModel(
+ patch_size=(1, 2, 2),
+ num_attention_heads=2,
+ attention_head_dim=12,
+ in_channels=16,
+ out_channels=16,
+ text_dim=32,
+ freq_dim=256,
+ ffn_dim=32,
+ num_layers=2,
+ cross_attn_norm=True,
+ qk_norm="rms_norm_across_heads",
+ rope_max_seq_len=32,
+ )
+
+ components = {
+ "transformer": transformer,
+ "vae": vae,
+ "scheduler": scheduler,
+ "text_encoder": text_encoder,
+ "tokenizer": tokenizer,
+ }
+ return components
+
+ def get_dummy_inputs(self, device, seed=0):
+ if str(device).startswith("mps"):
+ generator = torch.manual_seed(seed)
+ else:
+ generator = torch.Generator(device=device).manual_seed(seed)
+ inputs = {
+ "prompt": "dance monkey",
+ "negative_prompt": "negative", # TODO
+ "generator": generator,
+ "num_inference_steps": 2,
+ "guidance_scale": 6.0,
+ "height": 16,
+ "width": 16,
+ "num_frames": 9,
+ "max_sequence_length": 16,
+ "output_type": "pt",
+ }
+ return inputs
+
+ def test_inference(self):
+ device = "cpu"
+
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe.to(device)
+ pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(device)
+ video = pipe(**inputs).frames
+ generated_video = video[0]
+
+ self.assertEqual(generated_video.shape, (9, 3, 16, 16))
+ expected_video = torch.randn(9, 3, 16, 16)
+ max_diff = np.abs(generated_video - expected_video).max()
+ self.assertLessEqual(max_diff, 1e10)
+
+ @unittest.skip("Test not supported")
+ def test_attention_slicing_forward_pass(self):
+ pass
+
+
+@slow
+@require_torch_accelerator
+class Magi1PipelineIntegrationTests(unittest.TestCase):
+ prompt = "A painting of a squirrel eating a burger."
+
+ def setUp(self):
+ super().setUp()
+ gc.collect()
+ backend_empty_cache(torch_device)
+
+ def tearDown(self):
+ super().tearDown()
+ gc.collect()
+ backend_empty_cache(torch_device)
+
+ @unittest.skip("TODO: test needs to be implemented")
+ def test_Magi1(self):
+ pass
diff --git a/tests/pipelines/magi1/test_magi1_image_to_video.py b/tests/pipelines/magi1/test_magi1_image_to_video.py
new file mode 100644
index 000000000000..f841f60490cd
--- /dev/null
+++ b/tests/pipelines/magi1/test_magi1_image_to_video.py
@@ -0,0 +1,212 @@
+# Copyright 2024 The HuggingFace Team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import numpy as np
+import PIL
+import torch
+from transformers import AutoTokenizer, CLIPVisionModel, T5EncoderModel
+
+from diffusers import (
+ AutoencoderKLMagi1,
+ FlowMatchEulerDiscreteScheduler,
+ Magi1ImageToVideoPipeline,
+ Magi1Transformer3DModel,
+)
+from diffusers.utils.testing_utils import (
+ enable_full_determinism,
+)
+
+from ..pipeline_params import (
+ TEXT_TO_IMAGE_BATCH_PARAMS,
+ TEXT_TO_IMAGE_IMAGE_PARAMS,
+ TEXT_TO_IMAGE_PARAMS,
+)
+from ..test_pipelines_common import (
+ PipelineTesterMixin,
+)
+
+
+enable_full_determinism()
+
+
+class Magi1ImageToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
+ pipeline_class = Magi1ImageToVideoPipeline
+ params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
+ batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
+ image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
+ image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
+ required_optional_params = frozenset(
+ [
+ "num_inference_steps",
+ "generator",
+ "latents",
+ "return_dict",
+ "callback_on_step_end",
+ "callback_on_step_end_tensor_inputs",
+ ]
+ )
+ test_xformers_attention = False
+ supports_dduf = False
+
+ def get_dummy_components(self):
+ torch.manual_seed(0)
+ vae = AutoencoderKLMagi1(
+ base_dim=3,
+ z_dim=16,
+ dim_mult=[1, 1, 1, 1],
+ num_res_blocks=1,
+ temperal_downsample=[False, True, True],
+ )
+
+ torch.manual_seed(0)
+ # TODO: impl FlowDPMSolverMultistepScheduler
+ scheduler = FlowMatchEulerDiscreteScheduler(shift=7.0)
+ text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
+ tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
+ image_encoder = CLIPVisionModel.from_pretrained("hf-internal-testing/tiny-random-clip")
+
+ torch.manual_seed(0)
+ transformer = Magi1Transformer3DModel(
+ patch_size=(1, 2, 2),
+ num_attention_heads=2,
+ attention_head_dim=12,
+ in_channels=16,
+ out_channels=16,
+ text_dim=32,
+ freq_dim=256,
+ ffn_dim=32,
+ num_layers=2,
+ cross_attn_norm=True,
+ qk_norm="rms_norm_across_heads",
+ rope_max_seq_len=32,
+ )
+
+ components = {
+ "transformer": transformer,
+ "vae": vae,
+ "scheduler": scheduler,
+ "text_encoder": text_encoder,
+ "tokenizer": tokenizer,
+ "image_encoder": image_encoder,
+ }
+ return components
+
+ def get_dummy_inputs(self, device, seed=0):
+ if str(device).startswith("mps"):
+ generator = torch.manual_seed(seed)
+ else:
+ generator = torch.Generator(device=device).manual_seed(seed)
+ inputs = {
+ "prompt": "dance monkey",
+ "negative_prompt": "negative", # TODO
+ "generator": generator,
+ "num_inference_steps": 2,
+ "guidance_scale": 6.0,
+ "height": 16,
+ "width": 16,
+ "num_frames": 9,
+ "max_sequence_length": 16,
+ "output_type": "pt",
+ "image": PIL.Image.new("RGB", (16, 16)),
+ }
+ return inputs
+
+ def test_inference(self):
+ device = "cpu"
+
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe.to(device)
+ pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(device)
+ video = pipe(**inputs).frames
+ generated_video = video[0]
+
+ self.assertEqual(generated_video.shape, (9, 3, 16, 16))
+ expected_video = torch.randn(9, 3, 16, 16)
+ max_diff = np.abs(generated_video - expected_video).max()
+ self.assertLessEqual(max_diff, 1e10)
+
+ @unittest.skip("Test not supported")
+ def test_attention_slicing_forward_pass(self):
+ pass
+
+
+class MagiFLFToVideoPipelineFastTests(Magi1ImageToVideoPipelineFastTests):
+ def get_dummy_components(self):
+ torch.manual_seed(0)
+ vae = AutoencoderKLMagi1(
+ base_dim=3,
+ z_dim=16,
+ dim_mult=[1, 1, 1, 1],
+ num_res_blocks=1,
+ temperal_downsample=[False, True, True],
+ )
+
+ torch.manual_seed(0)
+ # TODO: impl FlowDPMSolverMultistepScheduler
+ scheduler = FlowMatchEulerDiscreteScheduler(shift=7.0)
+ text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
+ tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
+ image_encoder = CLIPVisionModel.from_pretrained("hf-internal-testing/tiny-random-clip")
+
+ torch.manual_seed(0)
+ transformer = Magi1Transformer3DModel(
+ patch_size=(1, 2, 2),
+ num_attention_heads=2,
+ attention_head_dim=12,
+ in_channels=16,
+ out_channels=16,
+ text_dim=32,
+ freq_dim=256,
+ ffn_dim=32,
+ num_layers=2,
+ cross_attn_norm=True,
+ qk_norm="rms_norm_across_heads",
+ rope_max_seq_len=32,
+ )
+
+ components = {
+ "transformer": transformer,
+ "vae": vae,
+ "scheduler": scheduler,
+ "text_encoder": text_encoder,
+ "tokenizer": tokenizer,
+ "image_encoder": image_encoder,
+ }
+ return components
+
+ def get_dummy_inputs(self, device, seed=0):
+ if str(device).startswith("mps"):
+ generator = torch.manual_seed(seed)
+ else:
+ generator = torch.Generator(device=device).manual_seed(seed)
+ inputs = {
+ "prompt": "dance monkey",
+ "negative_prompt": "negative", # TODO
+ "generator": generator,
+ "num_inference_steps": 2,
+ "guidance_scale": 6.0,
+ "height": 16,
+ "width": 16,
+ "num_frames": 9,
+ "max_sequence_length": 16,
+ "output_type": "pt",
+ "image": PIL.Image.new("RGB", (16, 16)),
+ "last_image": PIL.Image.new("RGB", (16, 16)),
+ }
+ return inputs
diff --git a/tests/pipelines/magi1/test_magi1_video_to_video.py b/tests/pipelines/magi1/test_magi1_video_to_video.py
new file mode 100644
index 000000000000..9b6607143eaa
--- /dev/null
+++ b/tests/pipelines/magi1/test_magi1_video_to_video.py
@@ -0,0 +1,147 @@
+# Copyright 2024 The HuggingFace Team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import numpy as np
+import torch
+from transformers import AutoTokenizer, T5EncoderModel
+
+from diffusers import AutoencoderKLMagi1, Magi1Transformer3DModel, Magi1VideoToVideoPipeline, UniPCMultistepScheduler
+from diffusers.utils.testing_utils import (
+ enable_full_determinism,
+)
+
+from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
+from ..test_pipelines_common import (
+ PipelineTesterMixin,
+)
+
+
+enable_full_determinism()
+
+
+class Magi1VideoToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
+ pipeline_class = Magi1VideoToVideoPipeline
+ params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
+ batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
+ image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
+ image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
+ required_optional_params = frozenset(
+ [
+ "num_inference_steps",
+ "generator",
+ "latents",
+ "return_dict",
+ "callback_on_step_end",
+ "callback_on_step_end_tensor_inputs",
+ ]
+ )
+ test_xformers_attention = False
+ supports_dduf = False
+
+ def get_dummy_components(self):
+ torch.manual_seed(0)
+ vae = AutoencoderKLMagi1(
+ base_dim=3,
+ z_dim=16,
+ dim_mult=[1, 1, 1, 1],
+ num_res_blocks=1,
+ temperal_downsample=[False, True, True],
+ )
+
+ torch.manual_seed(0)
+ scheduler = UniPCMultistepScheduler(
+ beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000
+ )
+ text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
+ tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
+
+ torch.manual_seed(0)
+ transformer = Magi1Transformer3DModel(
+ patch_size=(1, 2, 2),
+ num_attention_heads=2,
+ attention_head_dim=12,
+ in_channels=16,
+ out_channels=16,
+ text_dim=32,
+ freq_dim=256,
+ ffn_dim=32,
+ num_layers=2,
+ cross_attn_norm=True,
+ qk_norm="rms_norm_across_heads",
+ rope_max_seq_len=32,
+ )
+
+ components = {
+ "transformer": transformer,
+ "vae": vae,
+ "scheduler": scheduler,
+ "text_encoder": text_encoder,
+ "tokenizer": tokenizer,
+ }
+ return components
+
+ def get_dummy_inputs(self, device, seed=0):
+ if str(device).startswith("mps"):
+ generator = torch.manual_seed(seed)
+ else:
+ generator = torch.Generator(device=device).manual_seed(seed)
+ inputs = {
+ "prompt": "dance monkey",
+ "negative_prompt": "negative", # TODO
+ "generator": generator,
+ "num_inference_steps": 2,
+ "guidance_scale": 6.0,
+ "height": 16,
+ "width": 16,
+ "num_frames": 9,
+ "max_sequence_length": 16,
+ "output_type": "pt",
+ "video": torch.randn((1, 3, 9, 16, 16)),
+ }
+ return inputs
+
+ def test_inference(self):
+ device = "cpu"
+
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe.to(device)
+ pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(device)
+ video = pipe(**inputs).frames
+ generated_video = video[0]
+
+ self.assertEqual(generated_video.shape, (9, 3, 16, 16))
+ expected_video = torch.randn(9, 3, 16, 16)
+ max_diff = np.abs(generated_video - expected_video).max()
+ self.assertLessEqual(max_diff, 1e10)
+
+ @unittest.skip("Test not supported")
+ def test_attention_slicing_forward_pass(self):
+ pass
+
+ @unittest.skip(
+ "Magi1VideoToVideoPipeline has to run in mixed precision. Casting the entire pipeline will result in errors"
+ )
+ def test_model_cpu_offload_forward_pass(self):
+ pass
+
+ @unittest.skip(
+ "Magi1VideoToVideoPipeline has to run in mixed precision. Save/Load the entire pipeline in FP16 will result in errors"
+ )
+ def test_save_load_float16(self):
+ pass
diff --git a/tests/single_file/test_model_magi_autoencoder_single_file.py b/tests/single_file/test_model_magi_autoencoder_single_file.py
new file mode 100644
index 000000000000..8721d884fa25
--- /dev/null
+++ b/tests/single_file/test_model_magi_autoencoder_single_file.py
@@ -0,0 +1,61 @@
+# coding=utf-8
+# Copyright 2024 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import torch
+
+from diffusers import AutoencoderKLMagi1
+from diffusers.utils.testing_utils import (
+ require_torch_gpu,
+ slow,
+ torch_device,
+)
+
+
+class AutoencoderKLMagiSingleFileTests(unittest.TestCase):
+ model_class = AutoencoderKLMagi1
+ ckpt_path = "https://huggingface.co/sand-ai/MAGI-1/blob/main/vae/diffusion_pytorch_model.safetensors"
+ repo_id = "sand-ai/MAGI-1"
+
+ @slow
+ @require_torch_gpu
+ def test_single_file_components(self):
+ model = self.model_class.from_single_file(self.ckpt_path)
+ model.to(torch_device)
+
+ batch_size = 1
+ num_frames = 2
+ num_channels = 3
+ sizes = (16, 16)
+ image = torch.randn((batch_size, num_channels, num_frames) + sizes).to(torch_device)
+
+ with torch.no_grad():
+ model(image, return_dict=False)
+
+ @slow
+ @require_torch_gpu
+ def test_single_file_components_from_hub(self):
+ model = self.model_class.from_pretrained(self.repo_id, subfolder="vae")
+ model.to(torch_device)
+
+ batch_size = 1
+ num_frames = 2
+ num_channels = 3
+ sizes = (16, 16)
+ image = torch.randn((batch_size, num_channels, num_frames) + sizes).to(torch_device)
+
+ with torch.no_grad():
+ model(image, return_dict=False)
diff --git a/tests/single_file/test_model_magi_transformer3d_single_file.py b/tests/single_file/test_model_magi_transformer3d_single_file.py
new file mode 100644
index 000000000000..fb6b0ae04622
--- /dev/null
+++ b/tests/single_file/test_model_magi_transformer3d_single_file.py
@@ -0,0 +1,83 @@
+# coding=utf-8
+# Copyright 2024 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import torch
+
+from diffusers import Magi1Transformer3DModel
+from diffusers.utils.testing_utils import (
+ require_torch_gpu,
+ slow,
+ torch_device,
+)
+
+
+class Magi1Transformer3DModelText2VideoSingleFileTest(unittest.TestCase):
+ model_class = Magi1Transformer3DModel
+ ckpt_path = "https://huggingface.co/sand-ai/MAGI-1/blob/main/transformer/diffusion_pytorch_model.safetensors"
+ repo_id = "sand-ai/MAGI-1"
+
+ @slow
+ @require_torch_gpu
+ def test_single_file_components(self):
+ model = self.model_class.from_single_file(self.ckpt_path)
+ model.to(torch_device)
+
+ batch_size = 1
+ num_channels = 4
+ num_frames = 2
+ height = 16
+ width = 16
+ text_encoder_embedding_dim = 16
+ sequence_length = 12
+
+ hidden_states = torch.randn((batch_size, num_channels, num_frames, height, width)).to(torch_device)
+ timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
+ encoder_hidden_states = torch.randn((batch_size, sequence_length, text_encoder_embedding_dim)).to(torch_device)
+
+ with torch.no_grad():
+ model(
+ hidden_states=hidden_states,
+ timestep=timestep,
+ encoder_hidden_states=encoder_hidden_states,
+ return_dict=False,
+ )
+
+ @slow
+ @require_torch_gpu
+ def test_single_file_components_from_hub(self):
+ model = self.model_class.from_pretrained(self.repo_id, subfolder="transformer")
+ model.to(torch_device)
+
+ batch_size = 1
+ num_channels = 4
+ num_frames = 2
+ height = 16
+ width = 16
+ text_encoder_embedding_dim = 16
+ sequence_length = 12
+
+ hidden_states = torch.randn((batch_size, num_channels, num_frames, height, width)).to(torch_device)
+ timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
+ encoder_hidden_states = torch.randn((batch_size, sequence_length, text_encoder_embedding_dim)).to(torch_device)
+
+ with torch.no_grad():
+ model(
+ hidden_states=hidden_states,
+ timestep=timestep,
+ encoder_hidden_states=encoder_hidden_states,
+ return_dict=False,
+ )