Skip to content

Commit ca734f3

Browse files
committed
feat: 🎸 allow loading .safetensors
1 parent 0bf9b43 commit ca734f3

File tree

1 file changed

+35
-14
lines changed

1 file changed

+35
-14
lines changed

src/models/model_util.py

+35-14
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,6 @@ def load_checkpoint_model(
8282
weight_dtype: torch.dtype = torch.float32,
8383
device = "cuda",
8484
) -> tuple[CLIPTokenizer, CLIPTextModel, UNet2DConditionModel, DiffusionPipeline]:
85-
print(f"Loading checkpoint from {checkpoint_path}")
8685
if checkpoint_path == "BAAI/AltDiffusion":
8786
pipe = AltDiffusionPipeline.from_pretrained(
8887
"BAAI/AltDiffusion",
@@ -92,13 +91,24 @@ def load_checkpoint_model(
9291
local_files_only=LOCAL_ONLY,
9392
).to(device)
9493
else:
95-
pipe = StableDiffusionPipeline.from_pretrained(
96-
checkpoint_path,
97-
upcast_attention=True if v2 else False,
98-
torch_dtype=weight_dtype,
99-
cache_dir=DIFFUSERS_CACHE_DIR,
100-
local_files_only=LOCAL_ONLY,
101-
).to(device)
94+
if checkpoint_path.endswith(".safetensors"):
95+
print(f"Loading local checkpoint from {checkpoint_path}")
96+
pipe = StableDiffusionPipeline.from_single_file(
97+
checkpoint_path,
98+
upcast_attention=True if v2 else False,
99+
torch_dtype=weight_dtype,
100+
cache_dir=DIFFUSERS_CACHE_DIR,
101+
local_files_only=LOCAL_ONLY,
102+
).to(device)
103+
else:
104+
print(f"Loading checkpoint from {checkpoint_path}")
105+
pipe = StableDiffusionPipeline.from_pretrained(
106+
checkpoint_path,
107+
upcast_attention=True if v2 else False,
108+
torch_dtype=weight_dtype,
109+
cache_dir=DIFFUSERS_CACHE_DIR,
110+
local_files_only=LOCAL_ONLY,
111+
).to(device)
102112

103113
unet = pipe.unet
104114
tokenizer = pipe.tokenizer
@@ -183,12 +193,23 @@ def load_checkpoint_model_xl(
183193
weight_dtype: torch.dtype = torch.float32,
184194
device = "cuda",
185195
) -> tuple[list[CLIPTokenizer], list[SDXL_TEXT_ENCODER_TYPE], UNet2DConditionModel, DiffusionPipeline, ]:
186-
pipe = StableDiffusionXLPipeline.from_pretrained(
187-
checkpoint_path,
188-
torch_dtype=weight_dtype,
189-
cache_dir=DIFFUSERS_CACHE_DIR,
190-
local_files_only=LOCAL_ONLY,
191-
).to(device)
196+
#Check if checkpoint path ends with ".safetensors"
197+
if checkpoint_path.endswith(".safetensors"):
198+
print(f"Loading local checkpoint from {checkpoint_path}")
199+
pipe = StableDiffusionXLPipeline.from_single_file(
200+
checkpoint_path,
201+
torch_dtype=weight_dtype,
202+
cache_dir=DIFFUSERS_CACHE_DIR,
203+
local_files_only=LOCAL_ONLY,
204+
).to(device)
205+
else:
206+
print(f"Loading checkpoint from {checkpoint_path}")
207+
pipe = StableDiffusionXLPipeline.from_pretrained(
208+
checkpoint_path,
209+
torch_dtype=weight_dtype,
210+
cache_dir=DIFFUSERS_CACHE_DIR,
211+
local_files_only=LOCAL_ONLY,
212+
).to(device)
192213

193214
unet = pipe.unet
194215
tokenizers = [pipe.tokenizer, pipe.tokenizer_2]

0 commit comments

Comments
 (0)