diff --git a/changelog_entry.yaml b/changelog_entry.yaml index e69de29b..004e6066 100644 --- a/changelog_entry.yaml +++ b/changelog_entry.yaml @@ -0,0 +1,4 @@ +- bump: minor + changes: + added: + - HuggingFace upload/download functionality. diff --git a/policyengine_core/data/dataset.py b/policyengine_core/data/dataset.py index 7a0f792d..1dd87d89 100644 --- a/policyengine_core/data/dataset.py +++ b/policyengine_core/data/dataset.py @@ -7,6 +7,8 @@ import requests import os import tempfile +from huggingface_hub import HfApi, login, hf_hub_download +import pkg_resources def atomic_write(file: Path, content: bytes) -> None: @@ -53,6 +55,8 @@ class Dataset: """The time period of the dataset. This is used to automatically enter the values in the correct time period if the data type is `Dataset.ARRAYS`.""" url: str = None """The URL to download the dataset from. This is used to download the dataset if it does not exist.""" + huggingface_url: str = None + """The HuggingFace URL to download the dataset from. This is used to download the dataset if it does not exist.""" # Data formats TABLES = "tables" @@ -306,7 +310,7 @@ def store_file(self, file_path: str): raise FileNotFoundError(f"File {file_path} does not exist.") shutil.move(file_path, self.file_path) - def download(self, url: str = None) -> None: + def download(self, url: str = None, version: str = None) -> None: """Downloads a file to the dataset's file path. Args: @@ -314,7 +318,7 @@ def download(self, url: str = None) -> None: """ if url is None: - url = self.url + url = self.huggingface_url or self.url if "POLICYENGINE_GITHUB_MICRODATA_AUTH_TOKEN" not in os.environ: auth_headers = {} @@ -345,6 +349,10 @@ def download(self, url: str = None) -> None: raise ValueError( f"File {file_path} not found in release {release_tag} of {org}/{repo}." ) + elif url.startswith("hf://"): + owner_name, model_name = url.split("/")[2:] + self.download_from_huggingface(owner_name, model_name, version) + return else: url = url @@ -363,6 +371,19 @@ def download(self, url: str = None) -> None: atomic_write(self.file_path, response.content) + def upload(self, url: str = None): + """Uploads the dataset to a URL. + + Args: + url (str): The url to upload. + """ + if url is None: + url = self.huggingface_url or self.url + + if url.startswith("hf://"): + owner_name, model_name = url.split("/")[2:] + self.upload_to_huggingface(owner_name, model_name) + def remove(self): """Removes the dataset from disk.""" if self.exists: @@ -414,3 +435,59 @@ def from_dataframe(dataframe: pd.DataFrame, time_period: str = None): )() return dataset + + def upload_to_huggingface(self, owner_name: str, model_name: str): + """Uploads the dataset to Hugging Face. + + Args: + owner_name (str): The owner name. + model_name (str): The model name. + """ + token = os.environ.get( + "HUGGING_FACE_TOKEN", "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" + ) + login(token=token) + api = HfApi() + + # Add the policyengine-uk-data version and policyengine-uk version to the h5 metadata. + uk_data_version = get_package_version("policyengine-uk-data") + uk_version = get_package_version("policyengine-uk") + with h5py.File(self.file_path, "a") as f: + f.attrs["policyengine-uk-data"] = uk_data_version + f.attrs["policyengine-uk"] = uk_version + + api.upload_file( + path_or_fileobj=self.file_path, + path_in_repo=self.file_path.name, + repo_id=f"{owner_name}/{model_name}", + repo_type="model", + ) + + def download_from_huggingface( + self, owner_name: str, model_name: str, version: str = None + ): + """Downloads the dataset from Hugging Face. + + Args: + owner_name (str): The owner name. + model_name (str): The model name. + """ + token = os.environ.get( + "HUGGING_FACE_TOKEN", "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" + ) + login(token=token) + + hf_hub_download( + repo_id=f"{owner_name}/{model_name}", + repo_type="model", + path=self.file_path, + revision=version, + ) + + +def get_package_version(package_name: str) -> str: + """Get the installed version of a package.""" + try: + return pkg_resources.get_distribution(package_name).version + except pkg_resources.DistributionNotFound: + return "not installed" diff --git a/setup.py b/setup.py index 422ae149..ebad5d89 100644 --- a/setup.py +++ b/setup.py @@ -24,6 +24,7 @@ "ipython>=8,<9", "pyvis>=0.3.2", "microdf_python>=0.4.3", + "huggingface_hub>=0.25.1", ] dev_requirements = [