-
Notifications
You must be signed in to change notification settings - Fork 20
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conditionally pass Hugging Face tokens for download when repo is private #321
base: master
Are you sure you want to change the base?
Conditionally pass Hugging Face tokens for download when repo is private #321
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for doing this! Here are suggestions from https://pr-improver.streamlit.app that I haven't reviewed:
Here are 6 specific suggestions to improve this PR:
-
File: policyengine_core/tools/hugging_face.py
Lines: 17-18
Change:def hf_download(repo: str, repo_filename: str, version: str = None):
to:
def download_huggingface_dataset(repository: str, file_name: str, version: str = None):
Explanation: More descriptive function name and parameter names improve clarity, especially for non-native English speakers.
-
File: policyengine_core/tools/hugging_face.py
Lines: 31-34
Add comment:# Attempt to fetch model info to determine if repo is private # A RepositoryNotFoundError likely means the repo is private and requires authentication try: fetched_model_info: ModelInfo = model_info(repository) is_repo_private: bool = fetched_model_info.private except RepositoryNotFoundError: is_repo_private = True
Explanation: This comment clarifies the purpose of the try-except block and explains the logic behind assuming a private repository.
-
File: policyengine_core/tools/hugging_face.py
Lines: 46-50
Change:token: str = None if is_repo_private: token: str = get_or_prompt_hf_token()
to:
authentication_token: str = None if is_repo_private: authentication_token: str = get_or_prompt_huggingface_token()
Explanation: More descriptive variable names and function name improve clarity.
-
File: policyengine_core/simulations/simulation.py
Lines: 161-163
Change:dataset = hf_download( owner + "/" + repo, filename, version )
to:
dataset = download_huggingface_dataset( repository=f"{owner}/{repo}", file_name=filename, version=version )
Explanation: Use the new function name and provide named arguments for better readability.
-
File: tests/core/tools/test_hugging_face.py
Lines: 13-14
Change class name:class TestHfDownload:
to:
class TestHuggingFaceDownload:
Explanation: More descriptive class name improves clarity of test purpose.
-
File: tests/core/tools/test_hugging_face.py
Lines: 70-92
Add a new test case:def test_download_nonexistent_repo(self): """Test handling of a nonexistent repository""" test_repo = "nonexistent_repo" test_filename = "test_filename" with patch("policyengine_core.tools.hugging_face.model_info") as mock_model_info: mock_model_info.side_effect = Exception("Repository not found") with pytest.raises(Exception) as exc_info: download_huggingface_dataset(test_repo, test_filename) assert "Unable to download dataset" in str(exc_info.value)
Explanation: This new test case improves coverage by checking the handling of nonexistent repositories, which is different from private repositories.
These suggestions focus on improving code clarity, adding helpful comments, and enhancing test coverage, which should make the code more maintainable and easier to understand for all contributors.
Wow, this is interesting! Thanks for sending these suggestions. Will incorporate and re-request review. |
Review re-requested. I haven't added the new test from above, as Hugging Face raises the same error for both a nonexistent repo and a private repo without a passed token, and thus the test would be the exact same. |
Fixes #320.
This code does three things:
Dataset
class merely calls under the hood, to avoid code duplication and accord with DRYmodel_info
from Hugging Face. If the model info indicates that the repo is public, the code merely downloads relevant files, whereas if it's private, the function asks the user to input a Hugging Face token before proceeding.A couple things here:
model_info
endpoint raises an error if the repo is private. This code attempts to capitalize on that, but it's inherently messy and means that we're using a raised error to signal "not private," even though at times it means "repo doesn't exist"Requesting both Max and Nikhil, as Nikhil's touched this, but Max asked about the functionality.