
[Single File] Add single file loading for SANA Transformer #10947

Merged
merged 15 commits on Mar 10, 2025

Conversation

ishan-modi
Contributor

What does this PR do?

Fixes #10872

Who can review?

@yiyixuxu

@yiyixuxu
Collaborator

yiyixuxu commented Mar 3, 2025

ohh thanks so much!!!!!!

cc @DN6 here! can you do a review?

@yiyixuxu yiyixuxu requested a review from DN6 March 3, 2025 20:35
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@DN6
Collaborator

DN6 commented Mar 4, 2025

Nice work @ishan-modi 👍🏽. Could we add a test similar to the one for Lumina

class Lumina2Transformer2DModelSingleFileTests(unittest.TestCase):

(We can skip the checkpoint loading test if no checkpoint exists with alternative key names)

We can merge once that's done.

@nitinmukesh

nitinmukesh commented Mar 4, 2025

Not working. I tried both with and without weights_only=False:

import torch
from diffusers import SanaPipeline, SanaTransformer2DModel

model_path = "Efficient-Large-Model/Sana_1600M_1024px_diffusers"
dtype = torch.float16

transformer = SanaTransformer2DModel.from_single_file(
    "https://huggingface.co/Swarmeta-AI/Twig-v0-alpha/blob/main/Twig-v0-alpha-1.6B-2048x-fp16.pth",
    torch_dtype=dtype,
    weights_only=False,
)

pipe = SanaPipeline.from_pretrained(
    pretrained_model_name_or_path=model_path,
    transformer=transformer,
    torch_dtype=dtype,
    use_safetensors=True,
)
(venv) C:\aiOWN\quanto>python sana_twig.py
Traceback (most recent call last):
  File "C:\aiOWN\quanto\venv\lib\site-packages\diffusers\models\model_loading_utils.py", line 191, in load_state_dict
    return torch.load(checkpoint_file, map_location=map_location, **weights_only_kwarg, **extra_args)
  File "C:\aiOWN\quanto\venv\lib\site-packages\torch\serialization.py", line 1359, in load
    raise pickle.UnpicklingError(_get_wo_message(str(e))) from None
_pickle.UnpicklingError: Weights only load failed. This file can still be loaded, to do so you have two options, do those steps only if you trust the source of the checkpoint.
        (1) Re-running `torch.load` with `weights_only` set to `False` will likely succeed, but it can result in arbitrary code execution. Do it only if you got the file from a trusted source.
        (2) Alternatively, to load with `weights_only=True` please check the recommended steps in the following error message.
        WeightsUnpickler error: Unsupported global: GLOBAL numpy.core.multiarray._reconstruct was not an allowed global by default. Please use `torch.serialization.add_safe_globals([_reconstruct])` to allowlist this global if you trust this class/function.

Check the documentation of torch.load to learn more about types accepted by default with weights_only https://pytorch.org/docs/stable/generated/torch.load.html.

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "C:\aiOWN\quanto\venv\lib\site-packages\diffusers\models\model_loading_utils.py", line 195, in load_state_dict
    if f.read().startswith("version"):
  File "C:\Program Files\Python310\lib\encodings\cp1252.py", line 23, in decode
    return codecs.charmap_decode(input,self.errors,decoding_table)[0]
UnicodeDecodeError: 'charmap' codec can't decode byte 0x90 in position 724: character maps to <undefined>

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "C:\aiOWN\quanto\sana_twig.py", line 7, in <module>
    transformer = SanaTransformer2DModel.from_single_file (
  File "C:\aiOWN\quanto\venv\lib\site-packages\huggingface_hub\utils\_validators.py", line 114, in _inner_fn
    return fn(*args, **kwargs)
  File "C:\aiOWN\quanto\venv\lib\site-packages\diffusers\loaders\single_file_model.py", line 262, in from_single_file
    checkpoint = load_single_file_checkpoint(
  File "C:\aiOWN\quanto\venv\lib\site-packages\diffusers\loaders\single_file_utils.py", line 418, in load_single_file_checkpoint
    checkpoint = load_state_dict(pretrained_model_link_or_path, disable_mmap=disable_mmap)
  File "C:\aiOWN\quanto\venv\lib\site-packages\diffusers\models\model_loading_utils.py", line 207, in load_state_dict
    raise OSError(
OSError: Unable to load weights from checkpoint file for 'C:\Users\nitin\.cache\huggingface\hub\models--Swarmeta-AI--Twig-v0-alpha\snapshots\0d767344d9949e34f9d32a510a2f1d95ddb731ae\Twig-v0-alpha-1.6B-2048x-fp16.pth' at 'C:\Users\nitin\.cache\huggingface\hub\models--Swarmeta-AI--Twig-v0-alpha\snapshots\0d767344d9949e34f9d32a510a2f1d95ddb731ae\Twig-v0-alpha-1.6B-2048x-fp16.pth'.

@ishan-modi
Contributor Author

@nitinmukesh the problem you are describing is a separate one: the model should load with weights_only=False, but the load_state_dict function called downstream overrides it with weights_only=True.

@DN6, @yiyixuxu, do we want to address the above in this PR and carry these kwargs through the downstream functions?

@DN6
Collaborator

DN6 commented Mar 4, 2025

So it looks like that particular checkpoint has extra serialized objects in the pickle file.

For these cases we recommend either loading the file yourself with torch.load (if you trust the provider) and passing the loaded state_dict directly to from_single_file, or converting the checkpoint to safetensors and then loading that with from_single_file.

We would like to avoid exposing the option to set weights_only=False, since it is a known security hole.

@DN6
Collaborator

DN6 commented Mar 4, 2025

@ishan-modi
Contributor Author

@DN6, I have tested that and it loads correctly

@nitinmukesh

nitinmukesh commented Mar 4, 2025

@DN6

Yes, the default loads fine. I also don't want to use an unsafe way of loading.

@DN6 DN6 changed the title Fixes issue 10872 [Single File] Add single file loading for SANA Transformer Mar 10, 2025
@DN6 DN6 merged commit 0703ce8 into huggingface:main Mar 10, 2025
11 of 12 checks passed