Skip to content

Commit 239e7b9

Browse files
Add huggingface URLs (#311)
Fixes #309
1 parent 9fbe198 commit 239e7b9

File tree

4 files changed

+93
-3
lines changed

4 files changed

+93
-3
lines changed

CHANGELOG.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,12 @@ All notable changes to this project will be documented in this file.
55
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
66
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
77

8+
## [3.13.0] - 2024-11-27 13:12:44
9+
10+
### Added
11+
12+
- HuggingFace upload/download functionality.
13+
814
## [3.12.5] - 2024-11-20 13:13:13
915

1016
### Changed
@@ -932,6 +938,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
932938

933939

934940

941+
[3.13.0]: https://github.com/PolicyEngine/policyengine-core/compare/3.12.5...3.13.0
935942
[3.12.5]: https://github.com/PolicyEngine/policyengine-core/compare/3.12.4...3.12.5
936943
[3.12.4]: https://github.com/PolicyEngine/policyengine-core/compare/3.12.3...3.12.4
937944
[3.12.3]: https://github.com/PolicyEngine/policyengine-core/compare/3.12.2...3.12.3

changelog.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -755,3 +755,8 @@
755755
- update the furo requirment to <2025
756756
- update the markupsafe requirement to <3
757757
date: 2024-11-20 13:13:13
758+
- bump: minor
759+
changes:
760+
added:
761+
- HuggingFace upload/download functionality.
762+
date: 2024-11-27 13:12:44

policyengine_core/data/dataset.py

Lines changed: 79 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
import requests
88
import os
99
import tempfile
10+
from huggingface_hub import HfApi, login, hf_hub_download
11+
import pkg_resources
1012

1113

1214
def atomic_write(file: Path, content: bytes) -> None:
@@ -53,6 +55,8 @@ class Dataset:
5355
"""The time period of the dataset. This is used to automatically enter the values in the correct time period if the data type is `Dataset.ARRAYS`."""
5456
url: str = None
5557
"""The URL to download the dataset from. This is used to download the dataset if it does not exist."""
58+
huggingface_url: str = None
59+
"""The HuggingFace URL to download the dataset from. This is used to download the dataset if it does not exist."""
5660

5761
# Data formats
5862
TABLES = "tables"
@@ -306,15 +310,15 @@ def store_file(self, file_path: str):
306310
raise FileNotFoundError(f"File {file_path} does not exist.")
307311
shutil.move(file_path, self.file_path)
308312

309-
def download(self, url: str = None) -> None:
313+
def download(self, url: str = None, version: str = None) -> None:
310314
"""Downloads a file to the dataset's file path.
311315
312316
Args:
313317
url (str): The url to download.
314318
"""
315319

316320
if url is None:
317-
url = self.url
321+
url = self.huggingface_url or self.url
318322

319323
if "POLICYENGINE_GITHUB_MICRODATA_AUTH_TOKEN" not in os.environ:
320324
auth_headers = {}
@@ -345,6 +349,10 @@ def download(self, url: str = None) -> None:
345349
raise ValueError(
346350
f"File {file_path} not found in release {release_tag} of {org}/{repo}."
347351
)
352+
elif url.startswith("hf://"):
353+
owner_name, model_name = url.split("/")[2:]
354+
self.download_from_huggingface(owner_name, model_name, version)
355+
return
348356
else:
349357
url = url
350358

@@ -363,6 +371,19 @@ def download(self, url: str = None) -> None:
363371

364372
atomic_write(self.file_path, response.content)
365373

374+
def upload(self, url: str = None):
375+
"""Uploads the dataset to a URL.
376+
377+
Args:
378+
url (str): The url to upload.
379+
"""
380+
if url is None:
381+
url = self.huggingface_url or self.url
382+
383+
if url.startswith("hf://"):
384+
owner_name, model_name = url.split("/")[2:]
385+
self.upload_to_huggingface(owner_name, model_name)
386+
366387
def remove(self):
367388
"""Removes the dataset from disk."""
368389
if self.exists:
@@ -414,3 +435,59 @@ def from_dataframe(dataframe: pd.DataFrame, time_period: str = None):
414435
)()
415436

416437
return dataset
438+
439+
def upload_to_huggingface(self, owner_name: str, model_name: str):
440+
"""Uploads the dataset to Hugging Face.
441+
442+
Args:
443+
owner_name (str): The owner name.
444+
model_name (str): The model name.
445+
"""
446+
token = os.environ.get(
447+
"HUGGING_FACE_TOKEN",
448+
)
449+
login(token=token)
450+
api = HfApi()
451+
452+
# Add the policyengine-uk-data version and policyengine-uk version to the h5 metadata.
453+
uk_data_version = get_package_version("policyengine-uk-data")
454+
uk_version = get_package_version("policyengine-uk")
455+
with h5py.File(self.file_path, "a") as f:
456+
f.attrs["policyengine-uk-data"] = uk_data_version
457+
f.attrs["policyengine-uk"] = uk_version
458+
459+
api.upload_file(
460+
path_or_fileobj=self.file_path,
461+
path_in_repo=self.file_path.name,
462+
repo_id=f"{owner_name}/{model_name}",
463+
repo_type="model",
464+
)
465+
466+
def download_from_huggingface(
467+
self, owner_name: str, model_name: str, version: str = None
468+
):
469+
"""Downloads the dataset from Hugging Face.
470+
471+
Args:
472+
owner_name (str): The owner name.
473+
model_name (str): The model name.
474+
"""
475+
token = os.environ.get(
476+
"HUGGING_FACE_TOKEN",
477+
)
478+
login(token=token)
479+
480+
hf_hub_download(
481+
repo_id=f"{owner_name}/{model_name}",
482+
repo_type="model",
483+
path=self.file_path,
484+
revision=version,
485+
)
486+
487+
488+
def get_package_version(package_name: str) -> str:
489+
"""Get the installed version of a package."""
490+
try:
491+
return pkg_resources.get_distribution(package_name).version
492+
except pkg_resources.DistributionNotFound:
493+
return "not installed"

setup.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
"ipython>=8,<9",
2525
"pyvis>=0.3.2",
2626
"microdf_python>=0.4.3",
27+
"huggingface_hub>=0.25.1",
2728
]
2829

2930
dev_requirements = [
@@ -48,7 +49,7 @@
4849

4950
setup(
5051
name="policyengine-core",
51-
version="3.12.5",
52+
version="3.13.0",
5253
author="PolicyEngine",
5354
author_email="[email protected]",
5455
classifiers=[

0 commit comments

Comments
 (0)