From 85b1e19c4577db287d1ac35e00b09ae9d3c8efe3 Mon Sep 17 00:00:00 2001 From: Nikhil Woodruff Date: Wed, 27 Nov 2024 22:29:19 +0000 Subject: [PATCH 1/3] Allow `Simulation` to load H5 files and HuggingFace H5 files directly Fixes #312 --- changelog_entry.yaml | 4 +++ policyengine_core/data/dataset.py | 36 ++++++++++----------- policyengine_core/simulations/simulation.py | 11 ++++++- policyengine_core/tools/hugging_face.py | 20 ++++++++++++ 4 files changed, 52 insertions(+), 19 deletions(-) create mode 100644 policyengine_core/tools/hugging_face.py diff --git a/changelog_entry.yaml b/changelog_entry.yaml index e69de29b..61f3f7e0 100644 --- a/changelog_entry.yaml +++ b/changelog_entry.yaml @@ -0,0 +1,4 @@ +- bump: minor + changes: + added: + - Support for HuggingFace-hosted H5 file inputs. diff --git a/policyengine_core/data/dataset.py b/policyengine_core/data/dataset.py index 9e0558eb..f24cb899 100644 --- a/policyengine_core/data/dataset.py +++ b/policyengine_core/data/dataset.py @@ -7,8 +7,7 @@ import requests import os import tempfile -from huggingface_hub import HfApi, login, hf_hub_download -import pkg_resources +from policyengine_core.tools.hugging_face import * def atomic_write(file: Path, content: bytes) -> None: @@ -400,13 +399,29 @@ def from_file(file_path: str, time_period: str = None): Dataset: The dataset. """ file_path = Path(file_path) + + # If it's a h5 file, check the first key + + if file_path.suffix == ".h5": + with h5py.File(file_path, "r") as f: + first_key = list(f.keys())[0] + first_value = f[first_key] + if isinstance(first_value, h5py.Dataset): + data_format = Dataset.ARRAYS + else: + data_format = Dataset.TIME_PERIOD_ARRAYS + subkeys = list(first_value.keys()) + if len(subkeys) > 0: + time_period = subkeys[0] + else: + data_format = Dataset.FLAT_FILE dataset = type( "Dataset", (Dataset,), { "name": file_path.stem, "label": file_path.stem, - "data_format": Dataset.FLAT_FILE, + "data_format": data_format, "file_path": file_path, "time_period": time_period, }, @@ -449,13 +464,6 @@ def upload_to_huggingface(self, owner_name: str, model_name: str): login(token=token) api = HfApi() - # Add the policyengine-uk-data version and policyengine-uk version to the h5 metadata. - uk_data_version = get_package_version("policyengine-uk-data") - uk_version = get_package_version("policyengine-uk") - with h5py.File(self.file_path, "a") as f: - f.attrs["policyengine-uk-data"] = uk_data_version - f.attrs["policyengine-uk"] = uk_version - api.upload_file( path_or_fileobj=self.file_path, path_in_repo=self.file_path.name, @@ -483,11 +491,3 @@ def download_from_huggingface( path=self.file_path, revision=version, ) - - -def get_package_version(package_name: str) -> str: - """Get the installed version of a package.""" - try: - return pkg_resources.get_distribution(package_name).version - except pkg_resources.DistributionNotFound: - return "not installed" diff --git a/policyengine_core/simulations/simulation.py b/policyengine_core/simulations/simulation.py index a6b5579b..3971fb90 100644 --- a/policyengine_core/simulations/simulation.py +++ b/policyengine_core/simulations/simulation.py @@ -23,6 +23,7 @@ TracingParameterNodeAtInstant, ) import random +from policyengine_core.tools.hugging_face import * import json @@ -54,7 +55,7 @@ class Simulation: default_dataset: Dataset = None """The default dataset class to use if none is provided.""" - default_role: str = None + default_role: str = "member" """The default role to assign people to groups if none is provided.""" default_input_period: str = None @@ -151,6 +152,14 @@ def __init__( if dataset is not None: if isinstance(dataset, str): + if "hf://" in dataset: + owner, repo, filename = dataset.split("/")[-3:] + if "@" in filename: + version = filename.split("@")[-1] + filename = filename.split("@")[0] + else: + version = None + dataset = download(owner + "/" + repo, filename, version) datasets_by_name = { dataset.name: dataset for dataset in self.datasets } diff --git a/policyengine_core/tools/hugging_face.py b/policyengine_core/tools/hugging_face.py new file mode 100644 index 00000000..b43df592 --- /dev/null +++ b/policyengine_core/tools/hugging_face.py @@ -0,0 +1,20 @@ +from huggingface_hub import hf_hub_download, login, HfApi +import os +import warnings + +with warnings.catch_warnings(): + warnings.simplefilter("ignore") + + +def download(repo: str, repo_filename: str, version: str = None): + token = os.environ.get( + "HUGGING_FACE_TOKEN", + ) + + return hf_hub_download( + repo_id=repo, + repo_type="model", + filename=repo_filename, + revision=version, + token=token, + ) From 4bdebb05d390a0851cd2623764feaaba536ceaa1 Mon Sep 17 00:00:00 2001 From: Nikhil Woodruff Date: Fri, 29 Nov 2024 09:30:09 +0000 Subject: [PATCH 2/3] Make default URL the GH url for backwards compatibility --- policyengine_core/data/dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/policyengine_core/data/dataset.py b/policyengine_core/data/dataset.py index f24cb899..9f3b1848 100644 --- a/policyengine_core/data/dataset.py +++ b/policyengine_core/data/dataset.py @@ -317,7 +317,7 @@ def download(self, url: str = None, version: str = None) -> None: """ if url is None: - url = self.huggingface_url or self.url + url = self.url or self.huggingface_url if "POLICYENGINE_GITHUB_MICRODATA_AUTH_TOKEN" not in os.environ: auth_headers = {} From a83fbc842e3ebd493fc9a43116bd0a5684fb027e Mon Sep 17 00:00:00 2001 From: Nikhil Woodruff Date: Fri, 29 Nov 2024 14:06:21 +0000 Subject: [PATCH 3/3] Don't re-login every time --- policyengine_core/data/dataset.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/policyengine_core/data/dataset.py b/policyengine_core/data/dataset.py index 9f3b1848..d648417e 100644 --- a/policyengine_core/data/dataset.py +++ b/policyengine_core/data/dataset.py @@ -461,7 +461,6 @@ def upload_to_huggingface(self, owner_name: str, model_name: str): token = os.environ.get( "HUGGING_FACE_TOKEN", ) - login(token=token) api = HfApi() api.upload_file( @@ -469,6 +468,7 @@ def upload_to_huggingface(self, owner_name: str, model_name: str): path_in_repo=self.file_path.name, repo_id=f"{owner_name}/{model_name}", repo_type="model", + token=token, ) def download_from_huggingface( @@ -483,11 +483,11 @@ def download_from_huggingface( token = os.environ.get( "HUGGING_FACE_TOKEN", ) - login(token=token) hf_hub_download( repo_id=f"{owner_name}/{model_name}", repo_type="model", path=self.file_path, revision=version, + token=token, )