diff --git a/src/diffusers/loaders/single_file_model.py b/src/diffusers/loaders/single_file_model.py
index b7d61b3e8ff4..f72a0dd369f2 100644
--- a/src/diffusers/loaders/single_file_model.py
+++ b/src/diffusers/loaders/single_file_model.py
@@ -37,6 +37,7 @@
     convert_ltx_vae_checkpoint_to_diffusers,
     convert_lumina2_to_diffusers,
     convert_mochi_transformer_checkpoint_to_diffusers,
+    convert_sana_transformer_to_diffusers,
     convert_sd3_transformer_checkpoint_to_diffusers,
     convert_stable_cascade_unet_single_file_to_diffusers,
     convert_wan_transformer_to_diffusers,
@@ -119,6 +120,10 @@
         "checkpoint_mapping_fn": convert_lumina2_to_diffusers,
         "default_subfolder": "transformer",
     },
+    "SanaTransformer2DModel": {
+        "checkpoint_mapping_fn": convert_sana_transformer_to_diffusers,
+        "default_subfolder": "transformer",
+    },
     "WanTransformer3DModel": {
         "checkpoint_mapping_fn": convert_wan_transformer_to_diffusers,
         "default_subfolder": "transformer",
diff --git a/src/diffusers/loaders/single_file_utils.py b/src/diffusers/loaders/single_file_utils.py
index 8ee7e14cb101..42aee4a84822 100644
--- a/src/diffusers/loaders/single_file_utils.py
+++ b/src/diffusers/loaders/single_file_utils.py
@@ -117,6 +117,12 @@
     "hunyuan-video": "txt_in.individual_token_refiner.blocks.0.adaLN_modulation.1.bias",
     "instruct-pix2pix": "model.diffusion_model.input_blocks.0.0.weight",
     "lumina2": ["model.diffusion_model.cap_embedder.0.weight", "cap_embedder.0.weight"],
+    "sana": [
+        "blocks.0.cross_attn.q_linear.weight",
+        "blocks.0.cross_attn.q_linear.bias",
+        "blocks.0.cross_attn.kv_linear.weight",
+        "blocks.0.cross_attn.kv_linear.bias",
+    ],
     "wan": ["model.diffusion_model.head.modulation", "head.modulation"],
     "wan_vae": "decoder.middle.0.residual.0.gamma",
 }
@@ -178,6 +184,7 @@
     "hunyuan-video": {"pretrained_model_name_or_path": "hunyuanvideo-community/HunyuanVideo"},
     "instruct-pix2pix": {"pretrained_model_name_or_path": "timbrooks/instruct-pix2pix"},
     "lumina2": {"pretrained_model_name_or_path": "Alpha-VLLM/Lumina-Image-2.0"},
+    "sana": {"pretrained_model_name_or_path": "Efficient-Large-Model/Sana_1600M_1024px_diffusers"},
     "wan-t2v-1.3B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"},
     "wan-t2v-14B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-T2V-14B-Diffusers"},
     "wan-i2v-14B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers"},
@@ -669,6 +676,9 @@ def infer_diffusers_model_type(checkpoint):
     elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["lumina2"]):
         model_type = "lumina2"
 
+    elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["sana"]):
+        model_type = "sana"
+
     elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["wan"]):
         if "model.diffusion_model.patch_embedding.weight" in checkpoint:
             target_key = "model.diffusion_model.patch_embedding.weight"
@@ -2897,6 +2907,111 @@ def convert_lumina_attn_to_diffusers(tensor, diffusers_key):
     return converted_state_dict
 
 
+def convert_sana_transformer_to_diffusers(checkpoint, **kwargs):
+    converted_state_dict = {}
+    keys = list(checkpoint.keys())
+    for k in keys:
+        if "model.diffusion_model." in k:
+            checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k)
+
+    num_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint if "blocks" in k))[-1] + 1  # noqa: C401
+
+    # Positional and patch embeddings.
+ checkpoint.pop("pos_embed") + converted_state_dict["patch_embed.proj.weight"] = checkpoint.pop("x_embedder.proj.weight") + converted_state_dict["patch_embed.proj.bias"] = checkpoint.pop("x_embedder.proj.bias") + + # Timestep embeddings. + converted_state_dict["time_embed.emb.timestep_embedder.linear_1.weight"] = checkpoint.pop( + "t_embedder.mlp.0.weight" + ) + converted_state_dict["time_embed.emb.timestep_embedder.linear_1.bias"] = checkpoint.pop("t_embedder.mlp.0.bias") + converted_state_dict["time_embed.emb.timestep_embedder.linear_2.weight"] = checkpoint.pop( + "t_embedder.mlp.2.weight" + ) + converted_state_dict["time_embed.emb.timestep_embedder.linear_2.bias"] = checkpoint.pop("t_embedder.mlp.2.bias") + converted_state_dict["time_embed.linear.weight"] = checkpoint.pop("t_block.1.weight") + converted_state_dict["time_embed.linear.bias"] = checkpoint.pop("t_block.1.bias") + + # Caption Projection. + checkpoint.pop("y_embedder.y_embedding") + converted_state_dict["caption_projection.linear_1.weight"] = checkpoint.pop("y_embedder.y_proj.fc1.weight") + converted_state_dict["caption_projection.linear_1.bias"] = checkpoint.pop("y_embedder.y_proj.fc1.bias") + converted_state_dict["caption_projection.linear_2.weight"] = checkpoint.pop("y_embedder.y_proj.fc2.weight") + converted_state_dict["caption_projection.linear_2.bias"] = checkpoint.pop("y_embedder.y_proj.fc2.bias") + converted_state_dict["caption_norm.weight"] = checkpoint.pop("attention_y_norm.weight") + + for i in range(num_layers): + converted_state_dict[f"transformer_blocks.{i}.scale_shift_table"] = checkpoint.pop( + f"blocks.{i}.scale_shift_table" + ) + + # Self-Attention + sample_q, sample_k, sample_v = torch.chunk(checkpoint.pop(f"blocks.{i}.attn.qkv.weight"), 3, dim=0) + converted_state_dict[f"transformer_blocks.{i}.attn1.to_q.weight"] = torch.cat([sample_q]) + converted_state_dict[f"transformer_blocks.{i}.attn1.to_k.weight"] = torch.cat([sample_k]) + converted_state_dict[f"transformer_blocks.{i}.attn1.to_v.weight"] = torch.cat([sample_v]) + + # Output Projections + converted_state_dict[f"transformer_blocks.{i}.attn1.to_out.0.weight"] = checkpoint.pop( + f"blocks.{i}.attn.proj.weight" + ) + converted_state_dict[f"transformer_blocks.{i}.attn1.to_out.0.bias"] = checkpoint.pop( + f"blocks.{i}.attn.proj.bias" + ) + + # Cross-Attention + converted_state_dict[f"transformer_blocks.{i}.attn2.to_q.weight"] = checkpoint.pop( + f"blocks.{i}.cross_attn.q_linear.weight" + ) + converted_state_dict[f"transformer_blocks.{i}.attn2.to_q.bias"] = checkpoint.pop( + f"blocks.{i}.cross_attn.q_linear.bias" + ) + + linear_sample_k, linear_sample_v = torch.chunk( + checkpoint.pop(f"blocks.{i}.cross_attn.kv_linear.weight"), 2, dim=0 + ) + linear_sample_k_bias, linear_sample_v_bias = torch.chunk( + checkpoint.pop(f"blocks.{i}.cross_attn.kv_linear.bias"), 2, dim=0 + ) + converted_state_dict[f"transformer_blocks.{i}.attn2.to_k.weight"] = linear_sample_k + converted_state_dict[f"transformer_blocks.{i}.attn2.to_v.weight"] = linear_sample_v + converted_state_dict[f"transformer_blocks.{i}.attn2.to_k.bias"] = linear_sample_k_bias + converted_state_dict[f"transformer_blocks.{i}.attn2.to_v.bias"] = linear_sample_v_bias + + # Output Projections + converted_state_dict[f"transformer_blocks.{i}.attn2.to_out.0.weight"] = checkpoint.pop( + f"blocks.{i}.cross_attn.proj.weight" + ) + converted_state_dict[f"transformer_blocks.{i}.attn2.to_out.0.bias"] = checkpoint.pop( + f"blocks.{i}.cross_attn.proj.bias" + ) + + # MLP + 
converted_state_dict[f"transformer_blocks.{i}.ff.conv_inverted.weight"] = checkpoint.pop( + f"blocks.{i}.mlp.inverted_conv.conv.weight" + ) + converted_state_dict[f"transformer_blocks.{i}.ff.conv_inverted.bias"] = checkpoint.pop( + f"blocks.{i}.mlp.inverted_conv.conv.bias" + ) + converted_state_dict[f"transformer_blocks.{i}.ff.conv_depth.weight"] = checkpoint.pop( + f"blocks.{i}.mlp.depth_conv.conv.weight" + ) + converted_state_dict[f"transformer_blocks.{i}.ff.conv_depth.bias"] = checkpoint.pop( + f"blocks.{i}.mlp.depth_conv.conv.bias" + ) + converted_state_dict[f"transformer_blocks.{i}.ff.conv_point.weight"] = checkpoint.pop( + f"blocks.{i}.mlp.point_conv.conv.weight" + ) + + # Final layer + converted_state_dict["proj_out.weight"] = checkpoint.pop("final_layer.linear.weight") + converted_state_dict["proj_out.bias"] = checkpoint.pop("final_layer.linear.bias") + converted_state_dict["scale_shift_table"] = checkpoint.pop("final_layer.scale_shift_table") + + return converted_state_dict + + def convert_wan_transformer_to_diffusers(checkpoint, **kwargs): converted_state_dict = {} diff --git a/src/diffusers/models/transformers/sana_transformer.py b/src/diffusers/models/transformers/sana_transformer.py index cface676b409..b8cc96d3532c 100644 --- a/src/diffusers/models/transformers/sana_transformer.py +++ b/src/diffusers/models/transformers/sana_transformer.py @@ -18,7 +18,7 @@ from torch import nn from ...configuration_utils import ConfigMixin, register_to_config -from ...loaders import PeftAdapterMixin +from ...loaders import FromOriginalModelMixin, PeftAdapterMixin from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers from ..attention_processor import ( Attention, @@ -195,7 +195,7 @@ def forward( return hidden_states -class SanaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin): +class SanaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin): r""" A 2D Transformer model introduced in [Sana](https://huggingface.co/papers/2410.10629) family of models. 
diff --git a/tests/single_file/test_sana_transformer.py b/tests/single_file/test_sana_transformer.py
new file mode 100644
index 000000000000..7695e1577711
--- /dev/null
+++ b/tests/single_file/test_sana_transformer.py
@@ -0,0 +1,61 @@
+import gc
+import unittest
+
+import torch
+
+from diffusers import (
+    SanaTransformer2DModel,
+)
+from diffusers.utils.testing_utils import (
+    backend_empty_cache,
+    enable_full_determinism,
+    require_torch_accelerator,
+    torch_device,
+)
+
+
+enable_full_determinism()
+
+
+@require_torch_accelerator
+class SanaTransformer2DModelSingleFileTests(unittest.TestCase):
+    model_class = SanaTransformer2DModel
+    ckpt_path = (
+        "https://huggingface.co/Efficient-Large-Model/Sana_1600M_1024px/blob/main/checkpoints/Sana_1600M_1024px.pth"
+    )
+    alternate_keys_ckpt_paths = [
+        "https://huggingface.co/Efficient-Large-Model/Sana_1600M_1024px/blob/main/checkpoints/Sana_1600M_1024px.pth"
+    ]
+
+    repo_id = "Efficient-Large-Model/Sana_1600M_1024px_diffusers"
+
+    def setUp(self):
+        super().setUp()
+        gc.collect()
+        backend_empty_cache(torch_device)
+
+    def tearDown(self):
+        super().tearDown()
+        gc.collect()
+        backend_empty_cache(torch_device)
+
+    def test_single_file_components(self):
+        model = self.model_class.from_pretrained(self.repo_id, subfolder="transformer")
+        model_single_file = self.model_class.from_single_file(self.ckpt_path)
+
+        PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values", "_diffusers_version"]
+        for param_name, param_value in model_single_file.config.items():
+            if param_name in PARAMS_TO_IGNORE:
+                continue
+            assert (
+                model.config[param_name] == param_value
+            ), f"{param_name} differs between single file loading and pretrained loading"
+
+    def test_checkpoint_loading(self):
+        for ckpt_path in self.alternate_keys_ckpt_paths:
+            torch.cuda.empty_cache()
+            model = self.model_class.from_single_file(ckpt_path)
+
+            del model
+            gc.collect()
+            torch.cuda.empty_cache()
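The only non-obvious remapping done by convert_sana_transformer_to_diffusers is splitting the fused attention projections into the separate to_q / to_k / to_v parameters that diffusers expects. The toy sketch below shows just that step on random tensors; the hidden size is made up for illustration and the real shapes come from the checkpoint.

import torch

inner_dim = 64  # made-up hidden size, for illustration only

# Self-attention: one fused qkv matrix is chunked into separate to_q / to_k / to_v weights.
qkv_weight = torch.randn(3 * inner_dim, inner_dim)
to_q_weight, to_k_weight, to_v_weight = torch.chunk(qkv_weight, 3, dim=0)

# Cross-attention: queries are already separate; only the fused kv projection is split in two.
kv_weight = torch.randn(2 * inner_dim, inner_dim)
kv_bias = torch.randn(2 * inner_dim)
to_k_weight, to_v_weight = torch.chunk(kv_weight, 2, dim=0)
to_k_bias, to_v_bias = torch.chunk(kv_bias, 2, dim=0)

assert to_k_weight.shape == (inner_dim, inner_dim)
assert to_k_bias.shape == (inner_dim,)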