
Commit e1bbf58

Merge pull request #313 from PolicyEngine/nikhilwoodruff/issue312
Allow `Simulation` to load H5 files and HuggingFace H5 files directly
2 parents 239e7b9 + a83fbc8 commit e1bbf58
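
In practice, the change lets a `Simulation` be constructed directly from a local `.h5` path or an `hf://` URI instead of a registered `Dataset` class. A minimal usage sketch; in practice a country package's `Simulation`/`Microsimulation` subclass would supply the tax-benefit system, and the paths, repo and file names below are placeholders:

# Sketch only: paths, repo and filenames are placeholders.
from policyengine_core.simulations import Simulation

# Local H5 file: the data format is now inferred from the file's first key.
sim = Simulation(dataset="/path/to/some_dataset.h5")

# HuggingFace-hosted H5 file, optionally pinned with an @version suffix.
sim = Simulation(dataset="hf://some-owner/some-repo/some_dataset.h5@1.0.0")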

File tree

4 files changed: +55 −22 lines

changelog_entry.yaml (4 additions, 0 deletions)

@@ -0,0 +1,4 @@
+- bump: minor
+  changes:
+    added:
+      - Support for HuggingFace-hosted H5 file inputs.

policyengine_core/data/dataset.py (21 additions, 21 deletions)

@@ -7,8 +7,7 @@
 import requests
 import os
 import tempfile
-from huggingface_hub import HfApi, login, hf_hub_download
-import pkg_resources
+from policyengine_core.tools.hugging_face import *


 def atomic_write(file: Path, content: bytes) -> None:

@@ -318,7 +317,7 @@ def download(self, url: str = None, version: str = None) -> None:
         """

         if url is None:
-            url = self.huggingface_url or self.url
+            url = self.url or self.huggingface_url

         if "POLICYENGINE_GITHUB_MICRODATA_AUTH_TOKEN" not in os.environ:
             auth_headers = {}

@@ -400,13 +399,29 @@ def from_file(file_path: str, time_period: str = None):
             Dataset: The dataset.
         """
         file_path = Path(file_path)
+
+        # If it's a h5 file, check the first key
+
+        if file_path.suffix == ".h5":
+            with h5py.File(file_path, "r") as f:
+                first_key = list(f.keys())[0]
+                first_value = f[first_key]
+                if isinstance(first_value, h5py.Dataset):
+                    data_format = Dataset.ARRAYS
+                else:
+                    data_format = Dataset.TIME_PERIOD_ARRAYS
+                    subkeys = list(first_value.keys())
+                    if len(subkeys) > 0:
+                        time_period = subkeys[0]
+        else:
+            data_format = Dataset.FLAT_FILE
         dataset = type(
             "Dataset",
             (Dataset,),
             {
                 "name": file_path.stem,
                 "label": file_path.stem,
-                "data_format": Dataset.FLAT_FILE,
+                "data_format": data_format,
                 "file_path": file_path,
                 "time_period": time_period,
             },
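
The inference in `from_file` looks at the file's first top-level key: a plain `h5py.Dataset` implies the `ARRAYS` layout, a group implies `TIME_PERIOD_ARRAYS` with the group's first subkey taken as the time period, and any non-`.h5` suffix falls back to `FLAT_FILE`. A sketch of the two H5 layouts; file and variable names are illustrative, and the `policyengine_core.data` import path is assumed:

import h5py
import numpy as np

# ARRAYS layout: each top-level key is a variable holding a single array.
with h5py.File("arrays_example.h5", "w") as f:
    f["employment_income"] = np.array([30_000.0, 45_000.0])

# TIME_PERIOD_ARRAYS layout: each variable is a group keyed by time period.
with h5py.File("time_period_example.h5", "w") as f:
    f["employment_income/2025"] = np.array([30_000.0, 45_000.0])

from policyengine_core.data import Dataset

# Classified as ARRAYS; the time period must still be supplied by the caller.
arrays_dataset = Dataset.from_file("arrays_example.h5", time_period="2025")

# Classified as TIME_PERIOD_ARRAYS; "2025" is read from the first subkey.
tpa_dataset = Dataset.from_file("time_period_example.h5")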

@@ -446,21 +461,14 @@ def upload_to_huggingface(self, owner_name: str, model_name: str):
         token = os.environ.get(
             "HUGGING_FACE_TOKEN",
         )
-        login(token=token)
         api = HfApi()

-        # Add the policyengine-uk-data version and policyengine-uk version to the h5 metadata.
-        uk_data_version = get_package_version("policyengine-uk-data")
-        uk_version = get_package_version("policyengine-uk")
-        with h5py.File(self.file_path, "a") as f:
-            f.attrs["policyengine-uk-data"] = uk_data_version
-            f.attrs["policyengine-uk"] = uk_version
-
         api.upload_file(
             path_or_fileobj=self.file_path,
             path_in_repo=self.file_path.name,
             repo_id=f"{owner_name}/{model_name}",
             repo_type="model",
+            token=token,
         )

     def download_from_huggingface(

@@ -475,19 +483,11 @@ def download_from_huggingface(
         token = os.environ.get(
             "HUGGING_FACE_TOKEN",
         )
-        login(token=token)

         hf_hub_download(
             repo_id=f"{owner_name}/{model_name}",
             repo_type="model",
             path=self.file_path,
             revision=version,
+            token=token,
         )
-
-
-def get_package_version(package_name: str) -> str:
-    """Get the installed version of a package."""
-    try:
-        return pkg_resources.get_distribution(package_name).version
-    except pkg_resources.DistributionNotFound:
-        return "not installed"

policyengine_core/simulations/simulation.py (10 additions, 1 deletion)

@@ -23,6 +23,7 @@
     TracingParameterNodeAtInstant,
 )
 import random
+from policyengine_core.tools.hugging_face import *

 import json

@@ -54,7 +55,7 @@ class Simulation:
     default_dataset: Dataset = None
     """The default dataset class to use if none is provided."""

-    default_role: str = None
+    default_role: str = "member"
     """The default role to assign people to groups if none is provided."""

     default_input_period: str = None

@@ -151,6 +152,14 @@ def __init__(

         if dataset is not None:
             if isinstance(dataset, str):
+                if "hf://" in dataset:
+                    owner, repo, filename = dataset.split("/")[-3:]
+                    if "@" in filename:
+                        version = filename.split("@")[-1]
+                        filename = filename.split("@")[0]
+                    else:
+                        version = None
+                    dataset = download(owner + "/" + repo, filename, version)
                 datasets_by_name = {
                     dataset.name: dataset for dataset in self.datasets
                 }
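
The accepted string form is `hf://{owner}/{repo}/{filename}`, where an optional `@{version}` suffix on the filename is split off and forwarded as the hub revision. A small trace of the parsing, using a placeholder URI:

# Placeholder URI; mirrors the parsing added above.
dataset = "hf://some-owner/some-repo/some_dataset.h5@1.2.0"

owner, repo, filename = dataset.split("/")[-3:]
if "@" in filename:
    version = filename.split("@")[-1]
    filename = filename.split("@")[0]
else:
    version = None

# owner="some-owner", repo="some-repo", filename="some_dataset.h5", version="1.2.0"
print(owner, repo, filename, version)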

policyengine_core/tools/hugging_face.py (new file: 20 additions, 0 deletions)

@@ -0,0 +1,20 @@
+from huggingface_hub import hf_hub_download, login, HfApi
+import os
+import warnings
+
+with warnings.catch_warnings():
+    warnings.simplefilter("ignore")
+
+
+def download(repo: str, repo_filename: str, version: str = None):
+    token = os.environ.get(
+        "HUGGING_FACE_TOKEN",
+    )
+
+    return hf_hub_download(
+        repo_id=repo,
+        repo_type="model",
+        filename=repo_filename,
+        revision=version,
+        token=token,
+    )
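
The helper is a thin wrapper around `hf_hub_download` that reads `HUGGING_FACE_TOKEN` from the environment and returns the local cache path of the downloaded file. A usage sketch with placeholder repo and filename:

from policyengine_core.tools.hugging_face import download

# Downloads some_dataset.h5 from the hub (placeholder names) and returns its
# local cache path; pass version to pin a specific revision.
local_path = download("some-owner/some-repo", "some_dataset.h5", version=None)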
