diff --git a/changelog_entry.yaml b/changelog_entry.yaml index e69de29b..61f3f7e0 100644 --- a/changelog_entry.yaml +++ b/changelog_entry.yaml @@ -0,0 +1,4 @@ +- bump: minor + changes: + added: + - Support for HuggingFace-hosted H5 file inputs. diff --git a/policyengine_core/data/dataset.py b/policyengine_core/data/dataset.py index 9e0558eb..d648417e 100644 --- a/policyengine_core/data/dataset.py +++ b/policyengine_core/data/dataset.py @@ -7,8 +7,7 @@ import requests import os import tempfile -from huggingface_hub import HfApi, login, hf_hub_download -import pkg_resources +from policyengine_core.tools.hugging_face import * def atomic_write(file: Path, content: bytes) -> None: @@ -318,7 +317,7 @@ def download(self, url: str = None, version: str = None) -> None: """ if url is None: - url = self.huggingface_url or self.url + url = self.url or self.huggingface_url if "POLICYENGINE_GITHUB_MICRODATA_AUTH_TOKEN" not in os.environ: auth_headers = {} @@ -400,13 +399,29 @@ def from_file(file_path: str, time_period: str = None): Dataset: The dataset. """ file_path = Path(file_path) + + # If it's a h5 file, check the first key + + if file_path.suffix == ".h5": + with h5py.File(file_path, "r") as f: + first_key = list(f.keys())[0] + first_value = f[first_key] + if isinstance(first_value, h5py.Dataset): + data_format = Dataset.ARRAYS + else: + data_format = Dataset.TIME_PERIOD_ARRAYS + subkeys = list(first_value.keys()) + if len(subkeys) > 0: + time_period = subkeys[0] + else: + data_format = Dataset.FLAT_FILE dataset = type( "Dataset", (Dataset,), { "name": file_path.stem, "label": file_path.stem, - "data_format": Dataset.FLAT_FILE, + "data_format": data_format, "file_path": file_path, "time_period": time_period, }, @@ -446,21 +461,14 @@ def upload_to_huggingface(self, owner_name: str, model_name: str): token = os.environ.get( "HUGGING_FACE_TOKEN", ) - login(token=token) api = HfApi() - # Add the policyengine-uk-data version and policyengine-uk version to the h5 metadata. 
- uk_data_version = get_package_version("policyengine-uk-data") - uk_version = get_package_version("policyengine-uk") - with h5py.File(self.file_path, "a") as f: - f.attrs["policyengine-uk-data"] = uk_data_version - f.attrs["policyengine-uk"] = uk_version - api.upload_file( path_or_fileobj=self.file_path, path_in_repo=self.file_path.name, repo_id=f"{owner_name}/{model_name}", repo_type="model", + token=token, ) def download_from_huggingface( @@ -475,19 +483,11 @@ def download_from_huggingface( token = os.environ.get( "HUGGING_FACE_TOKEN", ) - login(token=token) hf_hub_download( repo_id=f"{owner_name}/{model_name}", repo_type="model", path=self.file_path, revision=version, + token=token, ) - - -def get_package_version(package_name: str) -> str: - """Get the installed version of a package.""" - try: - return pkg_resources.get_distribution(package_name).version - except pkg_resources.DistributionNotFound: - return "not installed" diff --git a/policyengine_core/simulations/simulation.py b/policyengine_core/simulations/simulation.py index a6b5579b..3971fb90 100644 --- a/policyengine_core/simulations/simulation.py +++ b/policyengine_core/simulations/simulation.py @@ -23,6 +23,7 @@ TracingParameterNodeAtInstant, ) import random +from policyengine_core.tools.hugging_face import * import json @@ -54,7 +55,7 @@ class Simulation: default_dataset: Dataset = None """The default dataset class to use if none is provided.""" - default_role: str = None + default_role: str = "member" """The default role to assign people to groups if none is provided.""" default_input_period: str = None @@ -151,6 +152,14 @@ def __init__( if dataset is not None: if isinstance(dataset, str): + if "hf://" in dataset: + owner, repo, filename = dataset.split("/")[-3:] + if "@" in filename: + version = filename.split("@")[-1] + filename = filename.split("@")[0] + else: + version = None + dataset = download(owner + "/" + repo, filename, version) datasets_by_name = { dataset.name: dataset for dataset in 
"""Helpers for fetching files hosted on the Hugging Face Hub."""

import os
import warnings

# The import lives inside the catch_warnings() block on purpose: the "ignore"
# filter is only in force until the block exits, so a block containing nothing
# but simplefilter() would suppress no warnings at all.
# NOTE(review): presumably the intent was to silence import-time warnings from
# huggingface_hub - confirm.
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    from huggingface_hub import hf_hub_download, login, HfApi  # noqa: F401


def download(repo: str, repo_filename: str, version: str = None) -> str:
    """Download a single file from a Hugging Face model repository.

    Args:
        repo (str): The repository id, in ``owner/name`` form.
        repo_filename (str): The path of the file inside the repository.
        version (str, optional): The revision to fetch; forwarded unchanged
            as ``revision``. Defaults to None.

    Returns:
        str: The value returned by ``hf_hub_download`` (the local path of the
            downloaded file).
    """
    # An empty variable (e.g. a CI secret that is not available to the job)
    # is treated the same as an unset one rather than sent as a blank token.
    token = os.environ.get("HUGGING_FACE_TOKEN") or None

    return hf_hub_download(
        repo_id=repo,
        repo_type="model",
        filename=repo_filename,
        revision=version,
        token=token,
    )