diff --git a/paperqa/clients/openalex.py b/paperqa/clients/openalex.py index dd5bd94e..c42ce7e7 100644 --- a/paperqa/clients/openalex.py +++ b/paperqa/clients/openalex.py @@ -3,7 +3,6 @@ import json import logging import os -import re from collections.abc import Collection from datetime import datetime from typing import Any @@ -20,9 +19,24 @@ OPENALEX_BASE_URL = "https://api.openalex.org" OPENALEX_API_REQUEST_TIMEOUT = 5.0 + logger = logging.getLogger(__name__) +# author_name will be FamilyName, GivenName Middle initial. (if available) +# there is no standalone "FamilyName" or "GivenName" fields +# this manually constructs the name into the format the other clients use +def reformat_name(name: str) -> str: + if "," not in name: + return name + family, given = name.split(",", 1) + family = family.strip() + given_names = given.strip() + + # Return the reformatted name + return f"{given_names} {family}" + + async def get_openalex_mailto() -> str | None: """Get the OpenAlex mailto address. @@ -64,7 +78,7 @@ async def get_doc_details_from_openalex( mailto = await get_openalex_mailto() params = {"mailto": mailto} if mailto else {} - if doi is None and title is None: + if doi is title is None: raise ValueError("Either a DOI or title must be provided.") url = f"{OPENALEX_BASE_URL}/works" @@ -125,28 +139,6 @@ async def parse_openalex_to_doc_details(message: dict[str, Any]) -> DocDetails: Returns: Parsed document details. """ - - # author_name will be FamilyName, GivenName Middle initial. (if available) - # there is no standalone "FamilyName" or "GivenName" fields - # this manually constructs the name into the format the other clients use - def reformat_name(name: str) -> str: - # https://regex101.com/r/74vR57/1 - pattern = r"^([^,]+),\s*(.+?)(?:\s+(\w+\.?))?$" - match = re.match(pattern, name) - if match: - family_name, given_name, middle = match.groups() - - family_name = family_name.strip() - given_name = given_name.strip() - - reformatted = f"{given_name}" - if middle: - reformatted += f" {middle.strip()}" - reformatted += f" {family_name}" - return reformatted.strip() - - return name - authors = [ authorship.get("raw_author_name") for authorship in message.get("authorships", []) @@ -161,6 +153,10 @@ def reformat_name(name: str) -> str: ) journal = message.get("primary_location", {}).get("source", {}).get("display_name") + best_oa_location = message.get("best_oa_location", {}) + pdf_url = best_oa_location.get("pdf_url") + oa_license = best_oa_location.get("license") + return DocDetails( # type: ignore[call-arg] key=None, bibtex_type=BIBTEX_MAPPING.get(message.get("type", "other"), "misc"), @@ -178,6 +174,8 @@ def reformat_name(name: str) -> str: title=message.get("title"), citation_count=message.get("cited_by_count"), doi=message.get("doi"), + license=oa_license, + pdf_url=pdf_url, other=message, ) diff --git a/tests/test_clients.py b/tests/test_clients.py index 2154a5ac..7395fc10 100644 --- a/tests/test_clients.py +++ b/tests/test_clients.py @@ -18,6 +18,7 @@ ) from paperqa.clients.client_models import MetadataPostProcessor, MetadataProvider from paperqa.clients.journal_quality import JournalQualityPostProcessor +from paperqa.clients.openalex import reformat_name from paperqa.clients.retractions import RetractionDataPostProcessor @@ -601,3 +602,25 @@ async def test_crossref_retraction_status(stub_data_dir: Path) -> None: in crossref_details.formatted_citation ) assert crossref_details.is_retracted is True, "Should be retracted" + + +def test_reformat_name(): + test_cases = [ + ("Doe, John", "John Doe"), + ("Doe, Jane Mary", "Jane Mary Doe"), + ("O'Doe, John", "John O'Doe"), + ("Doe, Jane", "Jane Doe"), + ("Family, Jane Mary Elizabeth", "Jane Mary Elizabeth Family"), + ("O'Doe, Jane", "Jane O'Doe"), + ("Family, John Jr.", "John Jr. Family"), + ("Family", "Family"), + ("Jane Doe", "Jane Doe"), + ("Doe, Jöhn", "Jöhn Doe"), + ("Doe, Jòhn", "Jòhn Doe"), + ] + + for name, expected in test_cases: + result = reformat_name(name) + assert ( + result == expected + ), f"Expected '{expected}', but got '{result}' for '{name}'"