From 7264cd342f2654be9e1b8de5441b6a4c2fd6cda6 Mon Sep 17 00:00:00 2001 From: ishan-modi Date: Mon, 3 Mar 2025 15:15:47 +0530 Subject: [PATCH 1/7] added support for from_single_file --- src/diffusers/models/transformers/sana_transformer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/transformers/sana_transformer.py b/src/diffusers/models/transformers/sana_transformer.py index cface676b409..4cb3c78fca87 100644 --- a/src/diffusers/models/transformers/sana_transformer.py +++ b/src/diffusers/models/transformers/sana_transformer.py @@ -18,7 +18,7 @@ from torch import nn from ...configuration_utils import ConfigMixin, register_to_config -from ...loaders import PeftAdapterMixin +from ...loaders import PeftAdapterMixin, FromOriginalModelMixin from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers from ..attention_processor import ( Attention, @@ -195,7 +195,7 @@ def forward( return hidden_states -class SanaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin): +class SanaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin): r""" A 2D Transformer model introduced in [Sana](https://huggingface.co/papers/2410.10629) family of models. From 4829c9e0ce1b1fc31187afae757d30ca026bd5ae Mon Sep 17 00:00:00 2001 From: ishan-modi Date: Mon, 3 Mar 2025 21:12:18 +0530 Subject: [PATCH 2/7] added diffusers mapping script --- src/diffusers/loaders/single_file_model.py | 5 ++ src/diffusers/loaders/single_file_utils.py | 83 ++++++++++++++++++++++ 2 files changed, 88 insertions(+) diff --git a/src/diffusers/loaders/single_file_model.py b/src/diffusers/loaders/single_file_model.py index e6b050833485..f85e6c30036c 100644 --- a/src/diffusers/loaders/single_file_model.py +++ b/src/diffusers/loaders/single_file_model.py @@ -36,6 +36,7 @@ convert_ltx_transformer_checkpoint_to_diffusers, convert_ltx_vae_checkpoint_to_diffusers, convert_lumina2_to_diffusers, + convert_sana_transformer_to_diffusers, convert_mochi_transformer_checkpoint_to_diffusers, convert_sd3_transformer_checkpoint_to_diffusers, convert_stable_cascade_unet_single_file_to_diffusers, @@ -117,6 +118,10 @@ "checkpoint_mapping_fn": convert_lumina2_to_diffusers, "default_subfolder": "transformer", }, + "SanaTransformer2DModel": { + "checkpoint_mapping_fn": convert_sana_transformer_to_diffusers, + "default_subfolder": "transformer", + } } diff --git a/src/diffusers/loaders/single_file_utils.py b/src/diffusers/loaders/single_file_utils.py index 59060efade8b..9fe2d4726a36 100644 --- a/src/diffusers/loaders/single_file_utils.py +++ b/src/diffusers/loaders/single_file_utils.py @@ -117,6 +117,12 @@ "hunyuan-video": "txt_in.individual_token_refiner.blocks.0.adaLN_modulation.1.bias", "instruct-pix2pix": "model.diffusion_model.input_blocks.0.0.weight", "lumina2": ["model.diffusion_model.cap_embedder.0.weight", "cap_embedder.0.weight"], + "sana": [ + "blocks.0.cross_attn.q_linear.weight", + "blocks.0.cross_attn.q_linear.bias", + "blocks.0.cross_attn.kv_linear.weight", + "blocks.0.cross_attn.kv_linear.bias" + ], } DIFFUSERS_DEFAULT_PIPELINE_PATHS = { @@ -176,6 +182,7 @@ "hunyuan-video": {"pretrained_model_name_or_path": "hunyuanvideo-community/HunyuanVideo"}, "instruct-pix2pix": {"pretrained_model_name_or_path": "timbrooks/instruct-pix2pix"}, "lumina2": {"pretrained_model_name_or_path": "Alpha-VLLM/Lumina-Image-2.0"}, + "sana": {"pretrained_model_name_or_path": "Efficient-Large-Model/Sana_1600M_1024px"}, } # Use to configure model sample size when original 
config is provided @@ -662,6 +669,9 @@ def infer_diffusers_model_type(checkpoint): elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["lumina2"]): model_type = "lumina2" + elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["sana"]): + model_type = "sana" + else: model_type = "v1" @@ -2857,3 +2867,76 @@ def convert_lumina_attn_to_diffusers(tensor, diffusers_key): converted_state_dict[diffusers_key] = checkpoint.pop(key) return converted_state_dict + + +def convert_sana_transformer_to_diffusers(checkpoint, **kwargs): + converted_state_dict = {} + keys = list(checkpoint.keys()) + for k in keys: + if "model.diffusion_model." in k: + checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k) + + num_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint if "blocks" in k))[-1] + 1 # noqa: C401 + + # Positional and patch embeddings. + checkpoint.pop("pos_embed") + converted_state_dict["patch_embed.proj.weight"] = checkpoint.pop("x_embedder.proj.weight") + converted_state_dict["patch_embed.proj.bias"] = checkpoint.pop("x_embedder.proj.bias") + + # Timestep embeddings. + converted_state_dict["time_embed.emb.timestep_embedder.linear_1.weight"] = checkpoint.pop("t_embedder.mlp.0.weight") + converted_state_dict["time_embed.emb.timestep_embedder.linear_1.bias"] = checkpoint.pop("t_embedder.mlp.0.bias") + converted_state_dict["time_embed.emb.timestep_embedder.linear_2.weight"] = checkpoint.pop("t_embedder.mlp.2.weight") + converted_state_dict["time_embed.emb.timestep_embedder.linear_2.bias"] = checkpoint.pop("t_embedder.mlp.2.bias") + converted_state_dict["time_embed.linear.weight"] = checkpoint.pop("t_block.1.weight") + converted_state_dict["time_embed.linear.bias"] = checkpoint.pop("t_block.1.bias") + + # Caption Projection. 
+ converted_state_dict["caption_proj.linear_1.weight"] = checkpoint.pop("y_embedder.y_proj.fc1.weight") + converted_state_dict["caption_proj.linear_1.bias"] = checkpoint.pop("y_embedder.y_proj.fc1.bias") + converted_state_dict["caption_proj.linear_2.weight"] = checkpoint.pop("y_embedder.y_proj.fc2.weight") + converted_state_dict["caption_proj.linear_2.bias"] = checkpoint.pop("y_embedder.y_proj.fc2.bias") + converted_state_dict["caption_norm.weight"] = checkpoint.pop("attention_y_norm.weight") + + + for i in range(num_layers): + converted_state_dict[f"transformer_blocks.{i}.scale_shift_table"] = checkpoint.pop(f"blocks.{i}.scale_shift_table") + + # Self-Attention + sample_q, sample_k, sample_v = torch.chunk(checkpoint.pop(f"blocks.{i}.attn.qkv.weight"), 3, dim=0) + converted_state_dict[f"transformer_blocks.{i}.attn1.to_q.weight"] = torch.cat([sample_q]) + converted_state_dict[f"transformer_blocks.{i}.attn1.to_k.weight"] = torch.cat([sample_k]) + converted_state_dict[f"transformer_blocks.{i}.attn1.to_v.weight"] = torch.cat([sample_v]) + + # Output Projections + converted_state_dict[f"transformer_blocks.{i}.attn1.to_out.0.weight"] = checkpoint.pop(f"blocks.{i}.attn.proj.weight") + converted_state_dict[f"transformer_blocks.{i}.attn1.to_out.0.bias"] = checkpoint.pop(f"blocks.{i}.attn.proj.bias") + + # Cross-Attention + converted_state_dict[f"transformer_blocks.{i}.attn2.to_q.weight"] = checkpoint.pop(f"blocks.{i}.cross_attn.q_linear.weight") + converted_state_dict[f"transformer_blocks.{i}.attn2.to_q.bias"] = checkpoint.pop(f"blocks.{i}.cross_attn.q_linear.bias") + + linear_sample_k, linear_sample_v = torch.chunk(checkpoint.pop(f"blocks.{i}.cross_attn.kv_linear.weight"), 2, dim=0) + linear_sample_k_bias, linear_sample_v_bias = torch.chunk(checkpoint.pop(f"blocks.{i}.cross_attn.kv_linear.bias"), 2, dim=0) + converted_state_dict[f"transformer_blocks.{i}.attn2.to_k.weight"] = linear_sample_k + converted_state_dict[f"transformer_blocks.{i}.attn2.to_v.weight"] = linear_sample_v + converted_state_dict[f"transformer_blocks.{i}.attn2.to_k.bias"] = linear_sample_k_bias + converted_state_dict[f"transformer_blocks.{i}.attn2.to_v.bias"] = linear_sample_v_bias + + # Output Projections + converted_state_dict[f"transformer_blocks.{i}.attn2.to_out.0.weight"] = checkpoint.pop(f"blocks.{i}.cross_attn.proj.weight") + converted_state_dict[f"transformer_blocks.{i}.attn2.to_out.0.bias"] = checkpoint.pop(f"blocks.{i}.cross_attn.proj.bias") + + # MLP + converted_state_dict[f"transformer_blocks.{i}.ff.conv_inverted.weight"] = checkpoint.pop(f"blocks.{i}.mlp.inverted_conv.conv.weight") + converted_state_dict[f"transformer_blocks.{i}.ff.conv_inverted.bias"] = checkpoint.pop(f"blocks.{i}.mlp.inverted_conv.conv.bias") + converted_state_dict[f"transformer_blocks.{i}.ff.conv_depth.weight"] = checkpoint.pop(f"blocks.{i}.mlp.depth_conv.conv.weight") + converted_state_dict[f"transformer_blocks.{i}.ff.conv_depth.bias"] = checkpoint.pop(f"blocks.{i}.mlp.depth_conv.conv.bias") + converted_state_dict[f"transformer_blocks.{i}.ff.conv_point.weight"] = checkpoint.pop(f"blocks.{i}.mlp.point_conv.conv.weight") + + # Final layer + converted_state_dict["proj_out.weight"] = checkpoint.pop("final_layer.linear.weight") + converted_state_dict["proj_out.bias"] = checkpoint.pop("final_layer.linear.bias") + converted_state_dict["scale_shift_table"] = checkpoint.pop("final_layer.scale_shift_table") + + return converted_state_dict \ No newline at end of file From a990c1c9e1cc06808ee2fece1813163b27f1da0b Mon Sep 17 00:00:00 2001 From: ishan-modi 
Date: Tue, 4 Mar 2025 11:47:44 +0530 Subject: [PATCH 3/7] added testcase --- src/diffusers/loaders/single_file_utils.py | 4 ++- tests/single_file/test_sana_transformer.py | 37 ++++++++++++++++++++++ 2 files changed, 40 insertions(+), 1 deletion(-) create mode 100644 tests/single_file/test_sana_transformer.py diff --git a/src/diffusers/loaders/single_file_utils.py b/src/diffusers/loaders/single_file_utils.py index 9fe2d4726a36..e2d4e6ce81cd 100644 --- a/src/diffusers/loaders/single_file_utils.py +++ b/src/diffusers/loaders/single_file_utils.py @@ -182,7 +182,7 @@ "hunyuan-video": {"pretrained_model_name_or_path": "hunyuanvideo-community/HunyuanVideo"}, "instruct-pix2pix": {"pretrained_model_name_or_path": "timbrooks/instruct-pix2pix"}, "lumina2": {"pretrained_model_name_or_path": "Alpha-VLLM/Lumina-Image-2.0"}, - "sana": {"pretrained_model_name_or_path": "Efficient-Large-Model/Sana_1600M_1024px"}, + "sana": {"pretrained_model_name_or_path": "Efficient-Large-Model/Sana_1600M_1024px_diffusers"}, } # Use to configure model sample size when original config is provided @@ -2878,6 +2878,7 @@ def convert_sana_transformer_to_diffusers(checkpoint, **kwargs): num_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint if "blocks" in k))[-1] + 1 # noqa: C401 + # Positional and patch embeddings. checkpoint.pop("pos_embed") converted_state_dict["patch_embed.proj.weight"] = checkpoint.pop("x_embedder.proj.weight") @@ -2892,6 +2893,7 @@ def convert_sana_transformer_to_diffusers(checkpoint, **kwargs): converted_state_dict["time_embed.linear.bias"] = checkpoint.pop("t_block.1.bias") # Caption Projection. + checkpoint.pop("y_embedder.y_embedding") converted_state_dict["caption_proj.linear_1.weight"] = checkpoint.pop("y_embedder.y_proj.fc1.weight") converted_state_dict["caption_proj.linear_1.bias"] = checkpoint.pop("y_embedder.y_proj.fc1.bias") converted_state_dict["caption_proj.linear_2.weight"] = checkpoint.pop("y_embedder.y_proj.fc2.weight") diff --git a/tests/single_file/test_sana_transformer.py b/tests/single_file/test_sana_transformer.py new file mode 100644 index 000000000000..a1ea87d3f891 --- /dev/null +++ b/tests/single_file/test_sana_transformer.py @@ -0,0 +1,37 @@ +import gc +import unittest + +import torch + +from diffusers import ( + SanaTransformer2DModel, +) +from diffusers.utils.testing_utils import ( + backend_empty_cache, + enable_full_determinism, + require_torch_accelerator, + torch_device, +) + + +enable_full_determinism() + + +@require_torch_accelerator +class SanaTransformer2DModelSingleFileTests(unittest.TestCase): + model_class = SanaTransformer2DModel + + repo_id = "Efficient-Large-Model/Sana_1600M_1024px_diffusers" + + def setUp(self): + super().setUp() + gc.collect() + backend_empty_cache(torch_device) + + def tearDown(self): + super().tearDown() + gc.collect() + backend_empty_cache(torch_device) + + def test_single_file_components(self): + _ = self.model_class.from_pretrained(self.repo_id, subfolder="transformer") From 17edc43970e3c4e434154cd6517417c359f2aa78 Mon Sep 17 00:00:00 2001 From: ishan-modi Date: Tue, 4 Mar 2025 14:16:46 +0530 Subject: [PATCH 4/7] bug fix --- src/diffusers/loaders/single_file_utils.py | 8 ++++---- tests/single_file/test_sana_transformer.py | 2 -- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/src/diffusers/loaders/single_file_utils.py b/src/diffusers/loaders/single_file_utils.py index e2d4e6ce81cd..53af5bceb2f6 100644 --- a/src/diffusers/loaders/single_file_utils.py +++ b/src/diffusers/loaders/single_file_utils.py @@ -2894,10 
+2894,10 @@ def convert_sana_transformer_to_diffusers(checkpoint, **kwargs): # Caption Projection. checkpoint.pop("y_embedder.y_embedding") - converted_state_dict["caption_proj.linear_1.weight"] = checkpoint.pop("y_embedder.y_proj.fc1.weight") - converted_state_dict["caption_proj.linear_1.bias"] = checkpoint.pop("y_embedder.y_proj.fc1.bias") - converted_state_dict["caption_proj.linear_2.weight"] = checkpoint.pop("y_embedder.y_proj.fc2.weight") - converted_state_dict["caption_proj.linear_2.bias"] = checkpoint.pop("y_embedder.y_proj.fc2.bias") + converted_state_dict["caption_projection.linear_1.weight"] = checkpoint.pop("y_embedder.y_proj.fc1.weight") + converted_state_dict["caption_projection.linear_1.bias"] = checkpoint.pop("y_embedder.y_proj.fc1.bias") + converted_state_dict["caption_projection.linear_2.weight"] = checkpoint.pop("y_embedder.y_proj.fc2.weight") + converted_state_dict["caption_projection.linear_2.bias"] = checkpoint.pop("y_embedder.y_proj.fc2.bias") converted_state_dict["caption_norm.weight"] = checkpoint.pop("attention_y_norm.weight") diff --git a/tests/single_file/test_sana_transformer.py b/tests/single_file/test_sana_transformer.py index a1ea87d3f891..dedf9145b56b 100644 --- a/tests/single_file/test_sana_transformer.py +++ b/tests/single_file/test_sana_transformer.py @@ -1,8 +1,6 @@ import gc import unittest -import torch - from diffusers import ( SanaTransformer2DModel, ) From 56a77183f68a2817873cfec3fe1258160ab53d91 Mon Sep 17 00:00:00 2001 From: ishan-modi Date: Tue, 4 Mar 2025 14:30:47 +0530 Subject: [PATCH 5/7] updated tests --- tests/single_file/test_sana_transformer.py | 26 +++++++++++++++++++++- 1 file changed, 25 insertions(+), 1 deletion(-) diff --git a/tests/single_file/test_sana_transformer.py b/tests/single_file/test_sana_transformer.py index dedf9145b56b..c0b1a29e982e 100644 --- a/tests/single_file/test_sana_transformer.py +++ b/tests/single_file/test_sana_transformer.py @@ -1,6 +1,8 @@ import gc import unittest +import torch + from diffusers import ( SanaTransformer2DModel, ) @@ -18,6 +20,10 @@ @require_torch_accelerator class SanaTransformer2DModelSingleFileTests(unittest.TestCase): model_class = SanaTransformer2DModel + ckpt_path = "https://huggingface.co/Efficient-Large-Model/Sana_1600M_1024px/blob/main/checkpoints/Sana_1600M_1024px.pth" + alternate_keys_ckpt_paths = [ + "https://huggingface.co/Efficient-Large-Model/Sana_1600M_1024px/blob/main/checkpoints/Sana_1600M_1024px.pth" + ] repo_id = "Efficient-Large-Model/Sana_1600M_1024px_diffusers" @@ -32,4 +38,22 @@ def tearDown(self): backend_empty_cache(torch_device) def test_single_file_components(self): - _ = self.model_class.from_pretrained(self.repo_id, subfolder="transformer") + model = self.model_class.from_pretrained(self.repo_id, subfolder="transformer") + model_single_file = self.model_class.from_single_file(self.ckpt_path) + + PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values", "_diffusers_version"] + for param_name, param_value in model_single_file.config.items(): + if param_name in PARAMS_TO_IGNORE: + continue + assert ( + model.config[param_name] == param_value + ), f"{param_name} differs between single file loading and pretrained loading" + + def test_checkpoint_loading(self): + for ckpt_path in self.alternate_keys_ckpt_paths: + torch.cuda.empty_cache() + model = self.model_class.from_single_file(ckpt_path) + + del model + gc.collect() + torch.cuda.empty_cache() From d3de540f7f83716da2cb09079c1bb2bef0bded79 Mon Sep 17 00:00:00 2001 From: ishan-modi Date: Tue, 4 Mar 2025 
22:30:57 +0530 Subject: [PATCH 6/7] corrected code quality --- src/diffusers/loaders/single_file_model.py | 2 +- src/diffusers/loaders/single_file_utils.py | 10 +++++----- src/diffusers/models/transformers/sana_transformer.py | 2 +- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/src/diffusers/loaders/single_file_model.py b/src/diffusers/loaders/single_file_model.py index f85e6c30036c..faa8aca8774d 100644 --- a/src/diffusers/loaders/single_file_model.py +++ b/src/diffusers/loaders/single_file_model.py @@ -36,8 +36,8 @@ convert_ltx_transformer_checkpoint_to_diffusers, convert_ltx_vae_checkpoint_to_diffusers, convert_lumina2_to_diffusers, - convert_sana_transformer_to_diffusers, convert_mochi_transformer_checkpoint_to_diffusers, + convert_sana_transformer_to_diffusers, convert_sd3_transformer_checkpoint_to_diffusers, convert_stable_cascade_unet_single_file_to_diffusers, create_controlnet_diffusers_config_from_ldm, diff --git a/src/diffusers/loaders/single_file_utils.py b/src/diffusers/loaders/single_file_utils.py index 53af5bceb2f6..6b375c2808f6 100644 --- a/src/diffusers/loaders/single_file_utils.py +++ b/src/diffusers/loaders/single_file_utils.py @@ -117,7 +117,7 @@ "hunyuan-video": "txt_in.individual_token_refiner.blocks.0.adaLN_modulation.1.bias", "instruct-pix2pix": "model.diffusion_model.input_blocks.0.0.weight", "lumina2": ["model.diffusion_model.cap_embedder.0.weight", "cap_embedder.0.weight"], - "sana": [ + "sana": [ "blocks.0.cross_attn.q_linear.weight", "blocks.0.cross_attn.q_linear.bias", "blocks.0.cross_attn.kv_linear.weight", @@ -2877,7 +2877,7 @@ def convert_sana_transformer_to_diffusers(checkpoint, **kwargs): checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k) num_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint if "blocks" in k))[-1] + 1 # noqa: C401 - + # Positional and patch embeddings. checkpoint.pop("pos_embed") @@ -2891,7 +2891,7 @@ def convert_sana_transformer_to_diffusers(checkpoint, **kwargs): converted_state_dict["time_embed.emb.timestep_embedder.linear_2.bias"] = checkpoint.pop("t_embedder.mlp.2.bias") converted_state_dict["time_embed.linear.weight"] = checkpoint.pop("t_block.1.weight") converted_state_dict["time_embed.linear.bias"] = checkpoint.pop("t_block.1.bias") - + # Caption Projection. 
checkpoint.pop("y_embedder.y_embedding") converted_state_dict["caption_projection.linear_1.weight"] = checkpoint.pop("y_embedder.y_proj.fc1.weight") @@ -2935,10 +2935,10 @@ def convert_sana_transformer_to_diffusers(checkpoint, **kwargs): converted_state_dict[f"transformer_blocks.{i}.ff.conv_depth.weight"] = checkpoint.pop(f"blocks.{i}.mlp.depth_conv.conv.weight") converted_state_dict[f"transformer_blocks.{i}.ff.conv_depth.bias"] = checkpoint.pop(f"blocks.{i}.mlp.depth_conv.conv.bias") converted_state_dict[f"transformer_blocks.{i}.ff.conv_point.weight"] = checkpoint.pop(f"blocks.{i}.mlp.point_conv.conv.weight") - + # Final layer converted_state_dict["proj_out.weight"] = checkpoint.pop("final_layer.linear.weight") converted_state_dict["proj_out.bias"] = checkpoint.pop("final_layer.linear.bias") converted_state_dict["scale_shift_table"] = checkpoint.pop("final_layer.scale_shift_table") - return converted_state_dict \ No newline at end of file + return converted_state_dict diff --git a/src/diffusers/models/transformers/sana_transformer.py b/src/diffusers/models/transformers/sana_transformer.py index 4cb3c78fca87..b8cc96d3532c 100644 --- a/src/diffusers/models/transformers/sana_transformer.py +++ b/src/diffusers/models/transformers/sana_transformer.py @@ -18,7 +18,7 @@ from torch import nn from ...configuration_utils import ConfigMixin, register_to_config -from ...loaders import PeftAdapterMixin, FromOriginalModelMixin +from ...loaders import FromOriginalModelMixin, PeftAdapterMixin from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers from ..attention_processor import ( Attention, From e3046a5280d5c973b54d52fb6ccf54dc5080f649 Mon Sep 17 00:00:00 2001 From: ishan-modi Date: Wed, 5 Mar 2025 09:11:22 +0530 Subject: [PATCH 7/7] corrected code quality --- src/diffusers/loaders/single_file_model.py | 2 +- src/diffusers/loaders/single_file_utils.py | 70 +++++++++++++++------- tests/single_file/test_sana_transformer.py | 4 +- 3 files changed, 54 insertions(+), 22 deletions(-) diff --git a/src/diffusers/loaders/single_file_model.py b/src/diffusers/loaders/single_file_model.py index faa8aca8774d..8cdecb1f8fbc 100644 --- a/src/diffusers/loaders/single_file_model.py +++ b/src/diffusers/loaders/single_file_model.py @@ -121,7 +121,7 @@ "SanaTransformer2DModel": { "checkpoint_mapping_fn": convert_sana_transformer_to_diffusers, "default_subfolder": "transformer", - } + }, } diff --git a/src/diffusers/loaders/single_file_utils.py b/src/diffusers/loaders/single_file_utils.py index 63a69cda7c12..15aa1ed2d8d3 100644 --- a/src/diffusers/loaders/single_file_utils.py +++ b/src/diffusers/loaders/single_file_utils.py @@ -121,7 +121,7 @@ "blocks.0.cross_attn.q_linear.weight", "blocks.0.cross_attn.q_linear.bias", "blocks.0.cross_attn.kv_linear.weight", - "blocks.0.cross_attn.kv_linear.bias" + "blocks.0.cross_attn.kv_linear.bias", ], } @@ -182,7 +182,7 @@ "hunyuan-video": {"pretrained_model_name_or_path": "hunyuanvideo-community/HunyuanVideo"}, "instruct-pix2pix": {"pretrained_model_name_or_path": "timbrooks/instruct-pix2pix"}, "lumina2": {"pretrained_model_name_or_path": "Alpha-VLLM/Lumina-Image-2.0"}, - "sana": {"pretrained_model_name_or_path": "Efficient-Large-Model/Sana_1600M_1024px_diffusers"}, + "sana": {"pretrained_model_name_or_path": "Efficient-Large-Model/Sana_1600M_1024px_diffusers"}, } # Use to configure model sample size when original config is provided @@ -2878,16 +2878,19 @@ def convert_sana_transformer_to_diffusers(checkpoint, **kwargs): num_layers = 
list(set(int(k.split(".", 2)[1]) for k in checkpoint if "blocks" in k))[-1] + 1 # noqa: C401 - # Positional and patch embeddings. checkpoint.pop("pos_embed") converted_state_dict["patch_embed.proj.weight"] = checkpoint.pop("x_embedder.proj.weight") converted_state_dict["patch_embed.proj.bias"] = checkpoint.pop("x_embedder.proj.bias") # Timestep embeddings. - converted_state_dict["time_embed.emb.timestep_embedder.linear_1.weight"] = checkpoint.pop("t_embedder.mlp.0.weight") + converted_state_dict["time_embed.emb.timestep_embedder.linear_1.weight"] = checkpoint.pop( + "t_embedder.mlp.0.weight" + ) converted_state_dict["time_embed.emb.timestep_embedder.linear_1.bias"] = checkpoint.pop("t_embedder.mlp.0.bias") - converted_state_dict["time_embed.emb.timestep_embedder.linear_2.weight"] = checkpoint.pop("t_embedder.mlp.2.weight") + converted_state_dict["time_embed.emb.timestep_embedder.linear_2.weight"] = checkpoint.pop( + "t_embedder.mlp.2.weight" + ) converted_state_dict["time_embed.emb.timestep_embedder.linear_2.bias"] = checkpoint.pop("t_embedder.mlp.2.bias") converted_state_dict["time_embed.linear.weight"] = checkpoint.pop("t_block.1.weight") converted_state_dict["time_embed.linear.bias"] = checkpoint.pop("t_block.1.bias") @@ -2900,9 +2903,10 @@ def convert_sana_transformer_to_diffusers(checkpoint, **kwargs): converted_state_dict["caption_projection.linear_2.bias"] = checkpoint.pop("y_embedder.y_proj.fc2.bias") converted_state_dict["caption_norm.weight"] = checkpoint.pop("attention_y_norm.weight") - for i in range(num_layers): - converted_state_dict[f"transformer_blocks.{i}.scale_shift_table"] = checkpoint.pop(f"blocks.{i}.scale_shift_table") + converted_state_dict[f"transformer_blocks.{i}.scale_shift_table"] = checkpoint.pop( + f"blocks.{i}.scale_shift_table" + ) # Self-Attention sample_q, sample_k, sample_v = torch.chunk(checkpoint.pop(f"blocks.{i}.attn.qkv.weight"), 3, dim=0) @@ -2911,30 +2915,56 @@ def convert_sana_transformer_to_diffusers(checkpoint, **kwargs): converted_state_dict[f"transformer_blocks.{i}.attn1.to_v.weight"] = torch.cat([sample_v]) # Output Projections - converted_state_dict[f"transformer_blocks.{i}.attn1.to_out.0.weight"] = checkpoint.pop(f"blocks.{i}.attn.proj.weight") - converted_state_dict[f"transformer_blocks.{i}.attn1.to_out.0.bias"] = checkpoint.pop(f"blocks.{i}.attn.proj.bias") + converted_state_dict[f"transformer_blocks.{i}.attn1.to_out.0.weight"] = checkpoint.pop( + f"blocks.{i}.attn.proj.weight" + ) + converted_state_dict[f"transformer_blocks.{i}.attn1.to_out.0.bias"] = checkpoint.pop( + f"blocks.{i}.attn.proj.bias" + ) # Cross-Attention - converted_state_dict[f"transformer_blocks.{i}.attn2.to_q.weight"] = checkpoint.pop(f"blocks.{i}.cross_attn.q_linear.weight") - converted_state_dict[f"transformer_blocks.{i}.attn2.to_q.bias"] = checkpoint.pop(f"blocks.{i}.cross_attn.q_linear.bias") + converted_state_dict[f"transformer_blocks.{i}.attn2.to_q.weight"] = checkpoint.pop( + f"blocks.{i}.cross_attn.q_linear.weight" + ) + converted_state_dict[f"transformer_blocks.{i}.attn2.to_q.bias"] = checkpoint.pop( + f"blocks.{i}.cross_attn.q_linear.bias" + ) - linear_sample_k, linear_sample_v = torch.chunk(checkpoint.pop(f"blocks.{i}.cross_attn.kv_linear.weight"), 2, dim=0) - linear_sample_k_bias, linear_sample_v_bias = torch.chunk(checkpoint.pop(f"blocks.{i}.cross_attn.kv_linear.bias"), 2, dim=0) + linear_sample_k, linear_sample_v = torch.chunk( + checkpoint.pop(f"blocks.{i}.cross_attn.kv_linear.weight"), 2, dim=0 + ) + linear_sample_k_bias, linear_sample_v_bias = 
torch.chunk( + checkpoint.pop(f"blocks.{i}.cross_attn.kv_linear.bias"), 2, dim=0 + ) converted_state_dict[f"transformer_blocks.{i}.attn2.to_k.weight"] = linear_sample_k converted_state_dict[f"transformer_blocks.{i}.attn2.to_v.weight"] = linear_sample_v converted_state_dict[f"transformer_blocks.{i}.attn2.to_k.bias"] = linear_sample_k_bias converted_state_dict[f"transformer_blocks.{i}.attn2.to_v.bias"] = linear_sample_v_bias # Output Projections - converted_state_dict[f"transformer_blocks.{i}.attn2.to_out.0.weight"] = checkpoint.pop(f"blocks.{i}.cross_attn.proj.weight") - converted_state_dict[f"transformer_blocks.{i}.attn2.to_out.0.bias"] = checkpoint.pop(f"blocks.{i}.cross_attn.proj.bias") + converted_state_dict[f"transformer_blocks.{i}.attn2.to_out.0.weight"] = checkpoint.pop( + f"blocks.{i}.cross_attn.proj.weight" + ) + converted_state_dict[f"transformer_blocks.{i}.attn2.to_out.0.bias"] = checkpoint.pop( + f"blocks.{i}.cross_attn.proj.bias" + ) # MLP - converted_state_dict[f"transformer_blocks.{i}.ff.conv_inverted.weight"] = checkpoint.pop(f"blocks.{i}.mlp.inverted_conv.conv.weight") - converted_state_dict[f"transformer_blocks.{i}.ff.conv_inverted.bias"] = checkpoint.pop(f"blocks.{i}.mlp.inverted_conv.conv.bias") - converted_state_dict[f"transformer_blocks.{i}.ff.conv_depth.weight"] = checkpoint.pop(f"blocks.{i}.mlp.depth_conv.conv.weight") - converted_state_dict[f"transformer_blocks.{i}.ff.conv_depth.bias"] = checkpoint.pop(f"blocks.{i}.mlp.depth_conv.conv.bias") - converted_state_dict[f"transformer_blocks.{i}.ff.conv_point.weight"] = checkpoint.pop(f"blocks.{i}.mlp.point_conv.conv.weight") + converted_state_dict[f"transformer_blocks.{i}.ff.conv_inverted.weight"] = checkpoint.pop( + f"blocks.{i}.mlp.inverted_conv.conv.weight" + ) + converted_state_dict[f"transformer_blocks.{i}.ff.conv_inverted.bias"] = checkpoint.pop( + f"blocks.{i}.mlp.inverted_conv.conv.bias" + ) + converted_state_dict[f"transformer_blocks.{i}.ff.conv_depth.weight"] = checkpoint.pop( + f"blocks.{i}.mlp.depth_conv.conv.weight" + ) + converted_state_dict[f"transformer_blocks.{i}.ff.conv_depth.bias"] = checkpoint.pop( + f"blocks.{i}.mlp.depth_conv.conv.bias" + ) + converted_state_dict[f"transformer_blocks.{i}.ff.conv_point.weight"] = checkpoint.pop( + f"blocks.{i}.mlp.point_conv.conv.weight" + ) # Final layer converted_state_dict["proj_out.weight"] = checkpoint.pop("final_layer.linear.weight") diff --git a/tests/single_file/test_sana_transformer.py b/tests/single_file/test_sana_transformer.py index c0b1a29e982e..7695e1577711 100644 --- a/tests/single_file/test_sana_transformer.py +++ b/tests/single_file/test_sana_transformer.py @@ -20,7 +20,9 @@ @require_torch_accelerator class SanaTransformer2DModelSingleFileTests(unittest.TestCase): model_class = SanaTransformer2DModel - ckpt_path = "https://huggingface.co/Efficient-Large-Model/Sana_1600M_1024px/blob/main/checkpoints/Sana_1600M_1024px.pth" + ckpt_path = ( + "https://huggingface.co/Efficient-Large-Model/Sana_1600M_1024px/blob/main/checkpoints/Sana_1600M_1024px.pth" + ) alternate_keys_ckpt_paths = [ "https://huggingface.co/Efficient-Large-Model/Sana_1600M_1024px/blob/main/checkpoints/Sana_1600M_1024px.pth" ]
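
For reference, a minimal usage sketch of what this series enables: loading SanaTransformer2DModel directly from the original single-file checkpoint. It assumes the same checkpoint URL exercised by ckpt_path in the tests above; the torch_dtype choice is illustrative, not part of the patches.

import torch

from diffusers import SanaTransformer2DModel

# Original (non-diffusers) Sana checkpoint; same URL as ckpt_path in the test above.
ckpt_path = (
    "https://huggingface.co/Efficient-Large-Model/Sana_1600M_1024px/blob/main/checkpoints/Sana_1600M_1024px.pth"
)

# from_single_file comes from FromOriginalModelMixin (patch 1/7). The checkpoint is
# detected as "sana" via the CHECKPOINT_KEY_NAMES entry and remapped with
# convert_sana_transformer_to_diffusers (patches 2/7 and 4/7) before the weights load.
transformer = SanaTransformer2DModel.from_single_file(ckpt_path, torch_dtype=torch.bfloat16)

Loading the diffusers-format repo (Efficient-Large-Model/Sana_1600M_1024px_diffusers, subfolder="transformer") via from_pretrained is unchanged; test_single_file_components compares the two paths config-for-config.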