Skip to content

Commit

Permalink
Fix lint, fix format name logic, add UTs for format name logic
Browse files Browse the repository at this point in the history
  • Loading branch information
nadolskit committed Oct 10, 2024
1 parent 5ee88da commit 62fd389
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 24 deletions.
46 changes: 22 additions & 24 deletions paperqa/clients/openalex.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import json
import logging
import os
import re
from collections.abc import Collection
from datetime import datetime
from typing import Any
Expand All @@ -20,9 +19,24 @@

OPENALEX_BASE_URL = "https://api.openalex.org"
OPENALEX_API_REQUEST_TIMEOUT = 5.0

logger = logging.getLogger(__name__)


# author_name will be FamilyName, GivenName Middle initial. (if available)
# there is no standalone "FamilyName" or "GivenName" fields
# this manually constructs the name into the format the other clients use
def reformat_name(name: str) -> str:
if "," not in name:
return name
family, given = name.split(",", 1)
family = family.strip()
given_names = given.strip()

# Return the reformatted name
return f"{given_names} {family}"


async def get_openalex_mailto() -> str | None:
"""Get the OpenAlex mailto address.
Expand Down Expand Up @@ -64,7 +78,7 @@ async def get_doc_details_from_openalex(
mailto = await get_openalex_mailto()
params = {"mailto": mailto} if mailto else {}

if doi is None and title is None:
if doi is title is None:
raise ValueError("Either a DOI or title must be provided.")

url = f"{OPENALEX_BASE_URL}/works"
Expand Down Expand Up @@ -125,28 +139,6 @@ async def parse_openalex_to_doc_details(message: dict[str, Any]) -> DocDetails:
Returns:
Parsed document details.
"""

# author_name will be FamilyName, GivenName Middle initial. (if available)
# there is no standalone "FamilyName" or "GivenName" fields
# this manually constructs the name into the format the other clients use
def reformat_name(name: str) -> str:
# https://regex101.com/r/74vR57/1
pattern = r"^([^,]+),\s*(.+?)(?:\s+(\w+\.?))?$"
match = re.match(pattern, name)
if match:
family_name, given_name, middle = match.groups()

family_name = family_name.strip()
given_name = given_name.strip()

reformatted = f"{given_name}"
if middle:
reformatted += f" {middle.strip()}"
reformatted += f" {family_name}"
return reformatted.strip()

return name

authors = [
authorship.get("raw_author_name")
for authorship in message.get("authorships", [])
Expand All @@ -161,6 +153,10 @@ def reformat_name(name: str) -> str:
)
journal = message.get("primary_location", {}).get("source", {}).get("display_name")

best_oa_location = message.get("best_oa_location", {})
pdf_url = best_oa_location.get("pdf_url")
oa_license = best_oa_location.get("license")

return DocDetails( # type: ignore[call-arg]
key=None,
bibtex_type=BIBTEX_MAPPING.get(message.get("type", "other"), "misc"),
Expand All @@ -178,6 +174,8 @@ def reformat_name(name: str) -> str:
title=message.get("title"),
citation_count=message.get("cited_by_count"),
doi=message.get("doi"),
license=oa_license,
pdf_url=pdf_url,
other=message,
)

Expand Down
23 changes: 23 additions & 0 deletions tests/test_clients.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
)
from paperqa.clients.client_models import MetadataPostProcessor, MetadataProvider
from paperqa.clients.journal_quality import JournalQualityPostProcessor
from paperqa.clients.openalex import reformat_name
from paperqa.clients.retractions import RetractionDataPostProcessor


Expand Down Expand Up @@ -601,3 +602,25 @@ async def test_crossref_retraction_status(stub_data_dir: Path) -> None:
in crossref_details.formatted_citation
)
assert crossref_details.is_retracted is True, "Should be retracted"


def test_reformat_name():
test_cases = [
("Doe, John", "John Doe"),
("Doe, Jane Mary", "Jane Mary Doe"),
("O'Doe, John", "John O'Doe"),
("Doe, Jane", "Jane Doe"),
("Family, Jane Mary Elizabeth", "Jane Mary Elizabeth Family"),
("O'Doe, Jane", "Jane O'Doe"),
("Family, John Jr.", "John Jr. Family"),
("Family", "Family"),
("Jane Doe", "Jane Doe"),
("Doe, Jöhn", "Jöhn Doe"),
("Doe, Jòhn", "Jòhn Doe"),
]

for name, expected in test_cases:
result = reformat_name(name)
assert (
result == expected
), f"Expected '{expected}', but got '{result}' for '{name}'"

0 comments on commit 62fd389

Please sign in to comment.