@@ -7,15 +7,21 @@
 from pathlib import Path
 from tempfile import TemporaryDirectory
 from typing import Iterable, Optional, Union
+
 import torch
 from torch.hub import HASH_REGEX, download_url_to_file, urlparse
-import safetensors.torch
 
 try:
     from torch.hub import get_dir
 except ImportError:
     from torch.hub import _get_torch_home as get_dir
 
+try:
+    import safetensors.torch
+    _has_safetensors = True
+except ImportError:
+    _has_safetensors = False
+
 if sys.version_info >= (3, 8):
     from typing import Literal
 else:
@@ -45,6 +51,7 @@
 HF_WEIGHTS_NAME = "pytorch_model.bin"  # default pytorch pkl
 HF_SAFE_WEIGHTS_NAME = "model.safetensors"  # safetensors version
 
+
 def get_cache_dir(child_dir=''):
     """
     Returns the location of the directory where models are cached (and creates it if necessary).
@@ -164,21 +171,28 @@ def load_state_dict_from_hf(model_id: str, filename: str = HF_WEIGHTS_NAME):
     hf_model_id, hf_revision = hf_split(model_id)
 
     # Look for .safetensors alternatives and load from it if it exists
-    for safe_filename in _get_safe_alternatives(filename):
-        try:
-            cached_safe_file = hf_hub_download(repo_id=hf_model_id, filename=safe_filename, revision=hf_revision)
-            _logger.info(f"[{model_id}] Safe alternative available for '{filename}' (as '{safe_filename}'). Loading weights using safetensors.")
-            return safetensors.torch.load_file(cached_safe_file, device="cpu")
-        except EntryNotFoundError:
-            pass
+    if _has_safetensors:
+        for safe_filename in _get_safe_alternatives(filename):
+            try:
+                cached_safe_file = hf_hub_download(repo_id=hf_model_id, filename=safe_filename, revision=hf_revision)
+                _logger.info(
+                    f"[{model_id}] Safe alternative available for '{filename}' "
+                    f"(as '{safe_filename}'). Loading weights using safetensors.")
+                return safetensors.torch.load_file(cached_safe_file, device="cpu")
+            except EntryNotFoundError:
+                pass
 
     # Otherwise, load using pytorch.load
     cached_file = hf_hub_download(hf_model_id, filename=filename, revision=hf_revision)
-    _logger.info(f"[{model_id}] Safe alternative not found for '{filename}'. Loading weights using default pytorch.")
+    _logger.debug(f"[{model_id}] Safe alternative not found for '{filename}'. Loading weights using default pytorch.")
     return torch.load(cached_file, map_location='cpu')
 
 
-def save_config_for_hf(model, config_path: str, model_config: Optional[dict] = None):
+def save_config_for_hf(
+        model,
+        config_path: str,
+        model_config: Optional[dict] = None
+):
     model_config = model_config or {}
     hf_config = {}
     pretrained_cfg = filter_pretrained_cfg(model.pretrained_cfg, remove_source=True, remove_null=True)
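
The hunk above gates safetensors loading behind the new _has_safetensors flag: if the library is installed and the repo publishes a safe alternative, that file is preferred; otherwise loading falls back to torch.load on the pickle weights. Below is a minimal standalone sketch of that fallback pattern, with a hypothetical load_weights helper and hard-coded filenames matching the constants defined earlier; it assumes huggingface_hub is installed and network access is available.

import torch
from huggingface_hub import hf_hub_download
from huggingface_hub.utils import EntryNotFoundError

try:
    import safetensors.torch
    _has_safetensors = True
except ImportError:
    _has_safetensors = False


def load_weights(repo_id: str):
    # Prefer the .safetensors artifact when both the library and the file exist.
    if _has_safetensors:
        try:
            safe_file = hf_hub_download(repo_id=repo_id, filename="model.safetensors")
            return safetensors.torch.load_file(safe_file, device="cpu")
        except EntryNotFoundError:
            pass  # repo does not publish a safetensors file
    # Fall back to the legacy pickle-based weights.
    bin_file = hf_hub_download(repo_id=repo_id, filename="pytorch_model.bin")
    return torch.load(bin_file, map_location="cpu")
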
@@ -220,15 +234,16 @@ def save_for_hf(
         model,
         save_directory: str,
         model_config: Optional[dict] = None,
-        safe_serialization: Union[bool, Literal["both"]] = False
-):
+        safe_serialization: Union[bool, Literal["both"]] = False,
+):
     assert has_hf_hub(True)
     save_directory = Path(save_directory)
     save_directory.mkdir(exist_ok=True, parents=True)
 
     # Save model weights, either safely (using safetensors), or using legacy pytorch approach or both.
     tensors = model.state_dict()
     if safe_serialization is True or safe_serialization == "both":
+        assert _has_safetensors, "`pip install safetensors` to use .safetensors"
         safetensors.torch.save_file(tensors, save_directory / HF_SAFE_WEIGHTS_NAME)
     if safe_serialization is False or safe_serialization == "both":
         torch.save(tensors, save_directory / HF_WEIGHTS_NAME)
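
With the _has_safetensors guard in place, save_for_hf can write the legacy pickle file, the safetensors file, or both. A hedged usage sketch follows, assuming this module is importable as timm.models._hub (adjust to wherever the file lives in your checkout) and that safetensors is installed; the model name and output directory are arbitrary.

import timm
from timm.models._hub import save_for_hf  # import path assumed

model = timm.create_model('resnet18', pretrained=False)
save_for_hf(
    model,
    save_directory='./resnet18-export',
    safe_serialization='both',  # writes model.safetensors and pytorch_model.bin
)
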
@@ -238,16 +253,16 @@ def save_for_hf(
 
 
 def push_to_hf_hub(
-    model,
-    repo_id: str,
-    commit_message: str = 'Add model',
-    token: Optional[str] = None,
-    revision: Optional[str] = None,
-    private: bool = False,
-    create_pr: bool = False,
-    model_config: Optional[dict] = None,
-    model_card: Optional[dict] = None,
-    safe_serialization: Union[bool, Literal["both"]] = False
+        model,
+        repo_id: str,
+        commit_message: str = 'Add model',
+        token: Optional[str] = None,
+        revision: Optional[str] = None,
+        private: bool = False,
+        create_pr: bool = False,
+        model_config: Optional[dict] = None,
+        model_card: Optional[dict] = None,
+        safe_serialization: Union[bool, Literal["both"]] = False,
):
     """
     Arguments:
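
The push_to_hf_hub change itself is a re-indent plus a trailing comma, but the safe_serialization argument it forwards to save_for_hf is what controls which weight files land in the uploaded repo. A brief, hedged usage sketch with a hypothetical repo id and the same assumed import path as above:

import timm
from timm.models._hub import push_to_hf_hub  # import path assumed

model = timm.create_model('resnet18', pretrained=False)
push_to_hf_hub(
    model,
    repo_id='your-username/resnet18-demo',  # hypothetical repo; requires a valid HF token
    commit_message='Add model',
    safe_serialization='both',  # upload both weight formats
)
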
@@ -341,6 +356,7 @@ def generate_readme(model_card: dict, model_name: str):
             readme_text += f"```bibtex\n{c}\n```\n"
     return readme_text
 
+
 def _get_safe_alternatives(filename: str) -> Iterable[str]:
     """Returns potential safetensors alternatives for a given filename.
 
@@ -350,5 +366,5 @@ def _get_safe_alternatives(filename: str) -> Iterable[str]:
     """
     if filename == HF_WEIGHTS_NAME:
         yield HF_SAFE_WEIGHTS_NAME
-    if filename.endswith(".bin"):
-        yield filename[:-4] + ".safetensors"
+    if filename != HF_WEIGHTS_NAME and filename.endswith(".bin"):
+        yield filename[:-4] + ".safetensors"
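
For reference, a few illustrative calls to the generator after this change, with expected results shown as comments (the import path is an assumption):

from timm.models._hub import _get_safe_alternatives  # import path assumed

print(list(_get_safe_alternatives("pytorch_model.bin")))   # ['model.safetensors']
print(list(_get_safe_alternatives("custom_weights.bin")))  # ['custom_weights.safetensors']
print(list(_get_safe_alternatives("weights.pth")))         # []
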