diff --git a/policyengine_core/tools/hugging_face.py b/policyengine_core/tools/hugging_face.py index 47fb1906..cbf41fdf 100644 --- a/policyengine_core/tools/hugging_face.py +++ b/policyengine_core/tools/hugging_face.py @@ -25,10 +25,9 @@ def download_huggingface_dataset(repo: str, repo_filename: str, version: str = N """ # Attempt connection to Hugging Face model_info endpoint # (https://huggingface.co/docs/huggingface_hub/v0.26.5/en/package_reference/hf_api#huggingface_hub.HfApi.model_info) - # Unfortunately, this endpoint will 401 on a private repo, - # but also on a public repo with a malformed URL, etc. - # Assume a 401 means the token is required. - + # Attempt to fetch model info to determine if repo is private + # A RepositoryNotFoundError & 401 likely means the repo is private, + # but this error will also surface for public repos with malformed URL, etc. try: fetched_model_info: ModelInfo = model_info(repo) is_repo_private: bool = fetched_model_info.private @@ -43,16 +42,16 @@ def download_huggingface_dataset(repo: str, repo_filename: str, version: str = N + "is private, the URL is malformed, or the dataset does not exist." ) - token: str = None + authentication_token: str = None if is_repo_private: - token: str = get_or_prompt_hf_token() + authentication_token: str = get_or_prompt_hf_token() return hf_hub_download( repo_id=repo, repo_type="model", filename=repo_filename, revision=version, - token=token, + token=authentication_token, )