From 42f6ce2a0a733f0b2d0ec4316782eb394fb5f966 Mon Sep 17 00:00:00 2001
From: Nikhil Woodruff
Date: Wed, 27 Nov 2024 13:20:35 +0000
Subject: [PATCH] Add huggingface URLs

Fixes #309
---
 CHANGELOG.md                      |  7 +++
 changelog.yaml                    |  5 ++
 policyengine_core/data/dataset.py | 83 ++++++++++++++++++++++++++++++-
 setup.py                          |  3 +-
 4 files changed, 95 insertions(+), 3 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index a189cdc0..e2914c38 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -5,6 +5,12 @@ All notable changes to this project will be documented in this file.
 The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
 and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
 
+## [3.13.0] - 2024-11-27 13:12:44
+
+### Added
+
+- HuggingFace upload/download functionality.
+
 ## [3.12.5] - 2024-11-20 13:13:13
 
 ### Changed
@@ -932,6 +938,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 
 
+[3.13.0]: https://github.com/PolicyEngine/policyengine-core/compare/3.12.5...3.13.0
 [3.12.5]: https://github.com/PolicyEngine/policyengine-core/compare/3.12.4...3.12.5
 [3.12.4]: https://github.com/PolicyEngine/policyengine-core/compare/3.12.3...3.12.4
 [3.12.3]: https://github.com/PolicyEngine/policyengine-core/compare/3.12.2...3.12.3
diff --git a/changelog.yaml b/changelog.yaml
index 41fec5c7..be727d65 100644
--- a/changelog.yaml
+++ b/changelog.yaml
@@ -755,3 +755,8 @@
     - update the furo requirment to <2025
     - update the markupsafe requirement to <3
   date: 2024-11-20 13:13:13
+- bump: minor
+  changes:
+    added:
+    - HuggingFace upload/download functionality.
+  date: 2024-11-27 13:12:44
diff --git a/policyengine_core/data/dataset.py b/policyengine_core/data/dataset.py
index 7a0f792d..9e0558eb 100644
--- a/policyengine_core/data/dataset.py
+++ b/policyengine_core/data/dataset.py
@@ -7,6 +7,8 @@
 import requests
 import os
 import tempfile
+from huggingface_hub import HfApi, login, hf_hub_download
+import pkg_resources
 
 
 def atomic_write(file: Path, content: bytes) -> None:
@@ -53,6 +55,8 @@ class Dataset:
     """The time period of the dataset. This is used to automatically enter the values in the correct time period if the data type is `Dataset.ARRAYS`."""
     url: str = None
     """The URL to download the dataset from. This is used to download the dataset if it does not exist."""
+    huggingface_url: str = None
+    """The Hugging Face URL (in the form `hf://owner/model`) to download the dataset from if it does not exist locally."""
 
     # Data formats
     TABLES = "tables"
@@ -306,7 +310,7 @@ def store_file(self, file_path: str):
             raise FileNotFoundError(f"File {file_path} does not exist.")
         shutil.move(file_path, self.file_path)
 
-    def download(self, url: str = None) -> None:
+    def download(self, url: str = None, version: str = None) -> None:
         """Downloads a file to the dataset's file path.
 
         Args:
@@ -314,7 +318,7 @@ def download(self, url: str = None) -> None:
         """
 
         if url is None:
-            url = self.url
+            url = self.huggingface_url or self.url
 
         if "POLICYENGINE_GITHUB_MICRODATA_AUTH_TOKEN" not in os.environ:
            auth_headers = {}
@@ -345,6 +349,10 @@ def download(self, url: str = None) -> None:
                 raise ValueError(
                     f"File {file_path} not found in release {release_tag} of {org}/{repo}."
                 )
+        elif url.startswith("hf://"):
+            owner_name, model_name = url.split("/")[2:]
+            self.download_from_huggingface(owner_name, model_name, version)
+            return
         else:
             url = url
 
@@ -363,6 +371,19 @@ def download(self, url: str = None) -> None:
 
         atomic_write(self.file_path, response.content)
 
+    def upload(self, url: str = None):
+        """Uploads the dataset to a URL.
+
+        Args:
+            url (str): The URL to upload the dataset to.
+        """
+        if url is None:
+            url = self.huggingface_url or self.url
+
+        if url.startswith("hf://"):
+            owner_name, model_name = url.split("/")[2:]
+            self.upload_to_huggingface(owner_name, model_name)
+
     def remove(self):
         """Removes the dataset from disk."""
         if self.exists:
@@ -414,3 +435,61 @@
             )()
 
         return dataset
+
+    def upload_to_huggingface(self, owner_name: str, model_name: str):
+        """Uploads the dataset to Hugging Face.
+
+        Args:
+            owner_name (str): The owner name.
+            model_name (str): The model name.
+        """
+        token = os.environ.get(
+            "HUGGING_FACE_TOKEN",
+        )
+        login(token=token)
+        api = HfApi()
+
+        # Add the policyengine-uk-data and policyengine-uk versions to the h5 metadata.
+        uk_data_version = get_package_version("policyengine-uk-data")
+        uk_version = get_package_version("policyengine-uk")
+        with h5py.File(self.file_path, "a") as f:
+            f.attrs["policyengine-uk-data"] = uk_data_version
+            f.attrs["policyengine-uk"] = uk_version
+
+        api.upload_file(
+            path_or_fileobj=self.file_path,
+            path_in_repo=self.file_path.name,
+            repo_id=f"{owner_name}/{model_name}",
+            repo_type="model",
+        )
+
+    def download_from_huggingface(
+        self, owner_name: str, model_name: str, version: str = None
+    ):
+        """Downloads the dataset from Hugging Face.
+
+        Args:
+            owner_name (str): The owner name.
+            model_name (str): The model name.
+            version (str): The revision to download. Defaults to the latest.
+        """
+        token = os.environ.get(
+            "HUGGING_FACE_TOKEN",
+        )
+        login(token=token)
+
+        hf_hub_download(
+            repo_id=f"{owner_name}/{model_name}",
+            repo_type="model",
+            filename=self.file_path.name,
+            local_dir=self.file_path.parent,
+            revision=version,
+        )
+
+
+def get_package_version(package_name: str) -> str:
+    """Get the installed version of a package."""
+    try:
+        return pkg_resources.get_distribution(package_name).version
+    except pkg_resources.DistributionNotFound:
+        return "not installed"
diff --git a/setup.py b/setup.py
index 422ae149..fa168135 100644
--- a/setup.py
+++ b/setup.py
@@ -24,6 +24,7 @@
     "ipython>=8,<9",
     "pyvis>=0.3.2",
     "microdf_python>=0.4.3",
+    "huggingface_hub>=0.25.1",
 ]
 
 dev_requirements = [
@@ -48,7 +49,7 @@
 
 setup(
     name="policyengine-core",
-    version="3.12.5",
+    version="3.13.0",
     author="PolicyEngine",
     author_email="hello@policyengine.org",
     classifiers=[
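
Reviewer note (not part of the patch): a minimal usage sketch of the new hf:// support, assuming the usual policyengine_core.data.Dataset import. The subclass, the repo id "example-org/example-dataset", the cache path, and the revision tag are hypothetical placeholders; the sketch also assumes the Hugging Face repo contains a file whose name matches file_path.name and that HUGGING_FACE_TOKEN is set in the environment for authenticated access.

    from pathlib import Path

    from policyengine_core.data import Dataset


    class ExampleDataset(Dataset):
        # All identifiers here are illustrative, not shipped with this patch.
        name = "example_dataset"
        label = "Example dataset"
        data_format = Dataset.ARRAYS
        time_period = "2024"
        file_path = Path("~/.cache/example_dataset.h5").expanduser()
        # Must follow the hf://owner/model convention parsed by Dataset.download.
        huggingface_url = "hf://example-org/example-dataset"


    dataset = ExampleDataset()

    # Pulls example_dataset.h5 from the Hugging Face model repo into file_path.
    dataset.download()                  # latest revision
    dataset.download(version="v1.0.0")  # or pin a tag/branch/commit

    # Writes package-version metadata into the .h5 file, then uploads it to the
    # same repo (authenticated via HUGGING_FACE_TOKEN).
    dataset.upload()

Because download() prefers huggingface_url over url when both are set, existing GitHub-release datasets keep working unchanged while new datasets can opt into Hugging Face hosting by setting only the hf:// attribute.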