Skip to content

Commit

Permalink
Add HuggingFace dataset upload/download (#310)
Browse files Browse the repository at this point in the history
* Add HuggingFace dataset upload/download
Fixes #309

* Changelog

* Add dep
  • Loading branch information
nikhilwoodruff authored Nov 27, 2024
1 parent 9fbe198 commit 711df81
Show file tree
Hide file tree
Showing 3 changed files with 84 additions and 2 deletions.
4 changes: 4 additions & 0 deletions changelog_entry.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
- bump: minor
  changes:
    added:
    - HuggingFace upload/download functionality.
81 changes: 79 additions & 2 deletions policyengine_core/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
import requests
import os
import tempfile
from huggingface_hub import HfApi, login, hf_hub_download
import pkg_resources


def atomic_write(file: Path, content: bytes) -> None:
Expand Down Expand Up @@ -53,6 +55,8 @@ class Dataset:
"""The time period of the dataset. This is used to automatically enter the values in the correct time period if the data type is `Dataset.ARRAYS`."""
url: str = None
"""The URL to download the dataset from. This is used to download the dataset if it does not exist."""
huggingface_url: str = None
"""The HuggingFace URL to download the dataset from. This is used to download the dataset if it does not exist."""

# Data formats
TABLES = "tables"
Expand Down Expand Up @@ -306,15 +310,15 @@ def store_file(self, file_path: str):
raise FileNotFoundError(f"File {file_path} does not exist.")
shutil.move(file_path, self.file_path)

def download(self, url: str = None) -> None:
def download(self, url: str = None, version: str = None) -> None:
"""Downloads a file to the dataset's file path.
Args:
url (str): The url to download.
"""

if url is None:
url = self.url
url = self.huggingface_url or self.url

if "POLICYENGINE_GITHUB_MICRODATA_AUTH_TOKEN" not in os.environ:
auth_headers = {}
Expand Down Expand Up @@ -345,6 +349,10 @@ def download(self, url: str = None) -> None:
raise ValueError(
f"File {file_path} not found in release {release_tag} of {org}/{repo}."
)
elif url.startswith("hf://"):
owner_name, model_name = url.split("/")[2:]
self.download_from_huggingface(owner_name, model_name, version)
return
else:
url = url

Expand All @@ -363,6 +371,19 @@ def download(self, url: str = None) -> None:

atomic_write(self.file_path, response.content)

def upload(self, url: str = None):
    """Uploads the dataset to a URL.

    Only Hugging Face URLs of the form ``hf://<owner>/<model>`` are
    handled. For any other URL scheme this method returns without
    uploading anything.

    Args:
        url (str): The url to upload to. Defaults to
            ``self.huggingface_url``, falling back to ``self.url``.

    Raises:
        ValueError: If no URL was given and none is configured on the
            dataset, or if an ``hf://`` URL is not exactly
            ``hf://<owner>/<model>``.
    """
    if url is None:
        url = self.huggingface_url or self.url

    # Fail with a clear message instead of an AttributeError on None.
    if url is None:
        raise ValueError(
            "No upload URL given and neither `huggingface_url` nor `url` "
            "is set on the dataset."
        )

    hf_prefix = "hf://"
    if not url.startswith(hf_prefix):
        # No uploader exists for other schemes; nothing to do.
        return

    # Tolerate a trailing slash, but insist on exactly owner + model so a
    # malformed URL is reported clearly rather than as an unpacking error.
    parts = url[len(hf_prefix) :].strip("/").split("/")
    if len(parts) != 2 or not all(parts):
        raise ValueError(
            f"Invalid Hugging Face URL {url!r}: expected the form "
            "'hf://<owner>/<model>'."
        )

    owner_name, model_name = parts
    self.upload_to_huggingface(owner_name, model_name)

def remove(self):
"""Removes the dataset from disk."""
if self.exists:
Expand Down Expand Up @@ -414,3 +435,59 @@ def from_dataframe(dataframe: pd.DataFrame, time_period: str = None):
)()

return dataset

def upload_to_huggingface(self, owner_name: str, model_name: str):
    """Uploads the dataset to Hugging Face.

    Before uploading, the installed versions of ``policyengine-uk-data``
    and ``policyengine-uk`` are written into the attributes of the HDF5
    file at ``self.file_path``. The file is then uploaded to the model
    repository ``<owner_name>/<model_name>`` under its own file name.

    Args:
        owner_name (str): The owner name.
        model_name (str): The model name.
    """
    # Credentials must come from the environment, never from source code.
    # When HUGGING_FACE_TOKEN is unset, skip the explicit login and let
    # huggingface_hub use whatever credentials it already has stored.
    token = os.environ.get("HUGGING_FACE_TOKEN")
    if token:
        login(token=token)
    api = HfApi()

    # Add the policyengine-uk-data version and policyengine-uk version to the h5 metadata.
    uk_data_version = get_package_version("policyengine-uk-data")
    uk_version = get_package_version("policyengine-uk")
    with h5py.File(self.file_path, "a") as f:
        f.attrs["policyengine-uk-data"] = uk_data_version
        f.attrs["policyengine-uk"] = uk_version

    api.upload_file(
        path_or_fileobj=self.file_path,
        path_in_repo=self.file_path.name,
        repo_id=f"{owner_name}/{model_name}",
        repo_type="model",
    )

def download_from_huggingface(
    self, owner_name: str, model_name: str, version: str = None
):
    """Downloads the dataset from Hugging Face.

    The file named like ``self.file_path`` is fetched from the model
    repository ``<owner_name>/<model_name>`` and placed in the directory
    of ``self.file_path``, so it ends up at ``self.file_path`` itself.

    Args:
        owner_name (str): The owner name.
        model_name (str): The model name.
        version (str): The repository revision to download. ``None``
            leaves the choice of revision to ``hf_hub_download``.
    """
    # Credentials must come from the environment, never from source code.
    # When HUGGING_FACE_TOKEN is unset, skip the explicit login and let
    # huggingface_hub use whatever credentials it already has stored.
    token = os.environ.get("HUGGING_FACE_TOKEN")
    if token:
        login(token=token)

    # hf_hub_download has no `path` argument: the file inside the repo is
    # selected with `filename` and the destination folder with `local_dir`.
    # The file name mirrors the `path_in_repo` used when uploading.
    target = Path(self.file_path)
    hf_hub_download(
        repo_id=f"{owner_name}/{model_name}",
        repo_type="model",
        filename=target.name,
        local_dir=target.parent,
        revision=version,
    )


def get_package_version(package_name: str) -> str:
    """Get the installed version of a package.

    Args:
        package_name (str): The distribution name, e.g. ``policyengine-uk``.

    Returns:
        str: The installed version string, or ``"not installed"`` if no
            distribution with that name can be found.
    """
    # importlib.metadata is the standard-library replacement for the
    # deprecated setuptools `pkg_resources` API; imported locally so the
    # function does not depend on setuptools being installed.
    from importlib import metadata

    try:
        return metadata.version(package_name)
    except metadata.PackageNotFoundError:
        return "not installed"
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
"ipython>=8,<9",
"pyvis>=0.3.2",
"microdf_python>=0.4.3",
"huggingface_hub>=0.25.1",
]

dev_requirements = [
Expand Down

0 comments on commit 711df81

Please sign in to comment.