Skip to content

Commit

Permalink
Merge pull request #313 from PolicyEngine/nikhilwoodruff/issue312
Browse files Browse the repository at this point in the history
Allow `Simulation` to load H5 files and HuggingFace H5 files directly
  • Loading branch information
anth-volk authored Nov 29, 2024
2 parents 239e7b9 + a83fbc8 commit e1bbf58
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 22 deletions.
4 changes: 4 additions & 0 deletions changelog_entry.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
- bump: minor
  changes:
    added:
      - Support for HuggingFace-hosted H5 file inputs.
42 changes: 21 additions & 21 deletions policyengine_core/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,7 @@
import requests
import os
import tempfile
from huggingface_hub import HfApi, login, hf_hub_download
import pkg_resources
from policyengine_core.tools.hugging_face import *


def atomic_write(file: Path, content: bytes) -> None:
Expand Down Expand Up @@ -318,7 +317,7 @@ def download(self, url: str = None, version: str = None) -> None:
"""

if url is None:
url = self.huggingface_url or self.url
url = self.url or self.huggingface_url

if "POLICYENGINE_GITHUB_MICRODATA_AUTH_TOKEN" not in os.environ:
auth_headers = {}
Expand Down Expand Up @@ -400,13 +399,29 @@ def from_file(file_path: str, time_period: str = None):
Dataset: The dataset.
"""
file_path = Path(file_path)

# If it's a h5 file, check the first key

if file_path.suffix == ".h5":
with h5py.File(file_path, "r") as f:
first_key = list(f.keys())[0]
first_value = f[first_key]
if isinstance(first_value, h5py.Dataset):
data_format = Dataset.ARRAYS
else:
data_format = Dataset.TIME_PERIOD_ARRAYS
subkeys = list(first_value.keys())
if len(subkeys) > 0:
time_period = subkeys[0]
else:
data_format = Dataset.FLAT_FILE
dataset = type(
"Dataset",
(Dataset,),
{
"name": file_path.stem,
"label": file_path.stem,
"data_format": Dataset.FLAT_FILE,
"data_format": data_format,
"file_path": file_path,
"time_period": time_period,
},
Expand Down Expand Up @@ -446,21 +461,14 @@ def upload_to_huggingface(self, owner_name: str, model_name: str):
token = os.environ.get(
"HUGGING_FACE_TOKEN",
)
login(token=token)
api = HfApi()

# Add the policyengine-uk-data version and policyengine-uk version to the h5 metadata.
uk_data_version = get_package_version("policyengine-uk-data")
uk_version = get_package_version("policyengine-uk")
with h5py.File(self.file_path, "a") as f:
f.attrs["policyengine-uk-data"] = uk_data_version
f.attrs["policyengine-uk"] = uk_version

api.upload_file(
path_or_fileobj=self.file_path,
path_in_repo=self.file_path.name,
repo_id=f"{owner_name}/{model_name}",
repo_type="model",
token=token,
)

def download_from_huggingface(
Expand All @@ -475,19 +483,11 @@ def download_from_huggingface(
token = os.environ.get(
"HUGGING_FACE_TOKEN",
)
login(token=token)

hf_hub_download(
repo_id=f"{owner_name}/{model_name}",
repo_type="model",
path=self.file_path,
revision=version,
token=token,
)


def get_package_version(package_name: str) -> str:
    """Return the installed version of a package.

    Args:
        package_name: Distribution name to look up
            (for example ``"policyengine-uk"``).

    Returns:
        The installed version string, or the literal ``"not installed"``
        when no distribution with that name can be found.
    """
    # Imported locally so this function no longer relies on the deprecated
    # ``pkg_resources`` API; ``importlib.metadata`` is the stdlib replacement.
    from importlib.metadata import PackageNotFoundError, version

    try:
        return version(package_name)
    except PackageNotFoundError:
        return "not installed"
11 changes: 10 additions & 1 deletion policyengine_core/simulations/simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
TracingParameterNodeAtInstant,
)
import random
from policyengine_core.tools.hugging_face import *

import json

Expand Down Expand Up @@ -54,7 +55,7 @@ class Simulation:
default_dataset: Dataset = None
"""The default dataset class to use if none is provided."""

default_role: str = None
default_role: str = "member"
"""The default role to assign people to groups if none is provided."""

default_input_period: str = None
Expand Down Expand Up @@ -151,6 +152,14 @@ def __init__(

if dataset is not None:
if isinstance(dataset, str):
if "hf://" in dataset:
owner, repo, filename = dataset.split("/")[-3:]
if "@" in filename:
version = filename.split("@")[-1]
filename = filename.split("@")[0]
else:
version = None
dataset = download(owner + "/" + repo, filename, version)
datasets_by_name = {
dataset.name: dataset for dataset in self.datasets
}
Expand Down
20 changes: 20 additions & 0 deletions policyengine_core/tools/hugging_face.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from huggingface_hub import hf_hub_download, login, HfApi
import os
import warnings

# NOTE(review): catch_warnings() restores the previous warning filters when
# the block exits, so the "ignore" filter below only affects statements placed
# inside this block. Confirm whether the huggingface_hub import was meant to
# live inside it.
with warnings.catch_warnings():
warnings.simplefilter("ignore")


def download(repo: str, repo_filename: str, version: str = None):
    """Download a single file from a Hugging Face model repository.

    Args:
        repo: Repository id, passed to ``hf_hub_download`` as ``repo_id``.
        repo_filename: Name of the file inside the repository.
        version: Revision to fetch; ``None`` is forwarded unchanged.

    Returns:
        Whatever ``hf_hub_download`` returns (per the huggingface_hub
        documentation, the local path of the downloaded file).
    """
    # The access token is optional: ``None`` is forwarded when the
    # environment variable is not set.
    hf_token = os.environ.get("HUGGING_FACE_TOKEN")

    request = {
        "repo_id": repo,
        "repo_type": "model",
        "filename": repo_filename,
        "revision": version,
        "token": hf_token,
    }
    return hf_hub_download(**request)

0 comments on commit e1bbf58

Please sign in to comment.