7
7
import requests
8
8
import os
9
9
import tempfile
10
- from huggingface_hub import HfApi , login , hf_hub_download
11
- import pkg_resources
10
+ from policyengine_core .tools .hugging_face import *
12
11
13
12
14
13
def atomic_write (file : Path , content : bytes ) -> None :
@@ -318,7 +317,7 @@ def download(self, url: str = None, version: str = None) -> None:
318
317
"""
319
318
320
319
if url is None :
321
- url = self .huggingface_url or self .url
320
+ url = self .url or self .huggingface_url
322
321
323
322
if "POLICYENGINE_GITHUB_MICRODATA_AUTH_TOKEN" not in os .environ :
324
323
auth_headers = {}
@@ -400,13 +399,29 @@ def from_file(file_path: str, time_period: str = None):
400
399
Dataset: The dataset.
401
400
"""
402
401
file_path = Path (file_path )
402
+
403
+ # If it's a h5 file, check the first key
404
+
405
+ if file_path .suffix == ".h5" :
406
+ with h5py .File (file_path , "r" ) as f :
407
+ first_key = list (f .keys ())[0 ]
408
+ first_value = f [first_key ]
409
+ if isinstance (first_value , h5py .Dataset ):
410
+ data_format = Dataset .ARRAYS
411
+ else :
412
+ data_format = Dataset .TIME_PERIOD_ARRAYS
413
+ subkeys = list (first_value .keys ())
414
+ if len (subkeys ) > 0 :
415
+ time_period = subkeys [0 ]
416
+ else :
417
+ data_format = Dataset .FLAT_FILE
403
418
dataset = type (
404
419
"Dataset" ,
405
420
(Dataset ,),
406
421
{
407
422
"name" : file_path .stem ,
408
423
"label" : file_path .stem ,
409
- "data_format" : Dataset . FLAT_FILE ,
424
+ "data_format" : data_format ,
410
425
"file_path" : file_path ,
411
426
"time_period" : time_period ,
412
427
},
@@ -446,21 +461,14 @@ def upload_to_huggingface(self, owner_name: str, model_name: str):
446
461
token = os .environ .get (
447
462
"HUGGING_FACE_TOKEN" ,
448
463
)
449
- login (token = token )
450
464
api = HfApi ()
451
465
452
- # Add the policyengine-uk-data version and policyengine-uk version to the h5 metadata.
453
- uk_data_version = get_package_version ("policyengine-uk-data" )
454
- uk_version = get_package_version ("policyengine-uk" )
455
- with h5py .File (self .file_path , "a" ) as f :
456
- f .attrs ["policyengine-uk-data" ] = uk_data_version
457
- f .attrs ["policyengine-uk" ] = uk_version
458
-
459
466
api .upload_file (
460
467
path_or_fileobj = self .file_path ,
461
468
path_in_repo = self .file_path .name ,
462
469
repo_id = f"{ owner_name } /{ model_name } " ,
463
470
repo_type = "model" ,
471
+ token = token ,
464
472
)
465
473
466
474
def download_from_huggingface (
@@ -475,19 +483,11 @@ def download_from_huggingface(
475
483
token = os .environ .get (
476
484
"HUGGING_FACE_TOKEN" ,
477
485
)
478
- login (token = token )
479
486
480
487
hf_hub_download (
481
488
repo_id = f"{ owner_name } /{ model_name } " ,
482
489
repo_type = "model" ,
483
490
path = self .file_path ,
484
491
revision = version ,
492
+ token = token ,
485
493
)
486
-
487
-
488
- def get_package_version (package_name : str ) -> str :
489
- """Get the installed version of a package."""
490
- try :
491
- return pkg_resources .get_distribution (package_name ).version
492
- except pkg_resources .DistributionNotFound :
493
- return "not installed"
0 commit comments