@@ -82,7 +82,6 @@ def load_checkpoint_model(
82
82
weight_dtype : torch .dtype = torch .float32 ,
83
83
device = "cuda" ,
84
84
) -> tuple [CLIPTokenizer , CLIPTextModel , UNet2DConditionModel , DiffusionPipeline ]:
85
- print (f"Loading checkpoint from { checkpoint_path } " )
86
85
if checkpoint_path == "BAAI/AltDiffusion" :
87
86
pipe = AltDiffusionPipeline .from_pretrained (
88
87
"BAAI/AltDiffusion" ,
@@ -92,13 +91,24 @@ def load_checkpoint_model(
92
91
local_files_only = LOCAL_ONLY ,
93
92
).to (device )
94
93
else :
95
- pipe = StableDiffusionPipeline .from_pretrained (
96
- checkpoint_path ,
97
- upcast_attention = True if v2 else False ,
98
- torch_dtype = weight_dtype ,
99
- cache_dir = DIFFUSERS_CACHE_DIR ,
100
- local_files_only = LOCAL_ONLY ,
101
- ).to (device )
94
+ if checkpoint_path .endswith (".safetensors" ):
95
+ print (f"Loading local checkpoint from { checkpoint_path } " )
96
+ pipe = StableDiffusionPipeline .from_single_file (
97
+ checkpoint_path ,
98
+ upcast_attention = True if v2 else False ,
99
+ torch_dtype = weight_dtype ,
100
+ cache_dir = DIFFUSERS_CACHE_DIR ,
101
+ local_files_only = LOCAL_ONLY ,
102
+ ).to (device )
103
+ else :
104
+ print (f"Loading checkpoint from { checkpoint_path } " )
105
+ pipe = StableDiffusionPipeline .from_pretrained (
106
+ checkpoint_path ,
107
+ upcast_attention = True if v2 else False ,
108
+ torch_dtype = weight_dtype ,
109
+ cache_dir = DIFFUSERS_CACHE_DIR ,
110
+ local_files_only = LOCAL_ONLY ,
111
+ ).to (device )
102
112
103
113
unet = pipe .unet
104
114
tokenizer = pipe .tokenizer
@@ -183,12 +193,23 @@ def load_checkpoint_model_xl(
183
193
weight_dtype : torch .dtype = torch .float32 ,
184
194
device = "cuda" ,
185
195
) -> tuple [list [CLIPTokenizer ], list [SDXL_TEXT_ENCODER_TYPE ], UNet2DConditionModel , DiffusionPipeline , ]:
186
- pipe = StableDiffusionXLPipeline .from_pretrained (
187
- checkpoint_path ,
188
- torch_dtype = weight_dtype ,
189
- cache_dir = DIFFUSERS_CACHE_DIR ,
190
- local_files_only = LOCAL_ONLY ,
191
- ).to (device )
196
+ #Check if checkpoint path ends with ".safetensors"
197
+ if checkpoint_path .endswith (".safetensors" ):
198
+ print (f"Loading local checkpoint from { checkpoint_path } " )
199
+ pipe = StableDiffusionXLPipeline .from_single_file (
200
+ checkpoint_path ,
201
+ torch_dtype = weight_dtype ,
202
+ cache_dir = DIFFUSERS_CACHE_DIR ,
203
+ local_files_only = LOCAL_ONLY ,
204
+ ).to (device )
205
+ else :
206
+ print (f"Loading checkpoint from { checkpoint_path } " )
207
+ pipe = StableDiffusionXLPipeline .from_pretrained (
208
+ checkpoint_path ,
209
+ torch_dtype = weight_dtype ,
210
+ cache_dir = DIFFUSERS_CACHE_DIR ,
211
+ local_files_only = LOCAL_ONLY ,
212
+ ).to (device )
192
213
193
214
unet = pipe .unet
194
215
tokenizers = [pipe .tokenizer , pipe .tokenizer_2 ]
0 commit comments