Commit 582ae5c

Mod: Update typeguard for args checks and check if any newline char is in the sentences before PTB tokenization.

Labbeti committed Sep 25, 2023
1 parent cd3ea12 commit 582ae5c
Showing 2 changed files with 20 additions and 7 deletions.
17 changes: 12 additions & 5 deletions src/aac_metrics/utils/checks.py
@@ -7,7 +7,7 @@
 
 from pathlib import Path
 from subprocess import CalledProcessError
-from typing import Any, Union
+from typing import Any, TypeGuard, Union
 
 
 pylog = logging.getLogger(__name__)
@@ -22,11 +22,18 @@ def check_metric_inputs(
     mult_references: Any,
 ) -> None:
     """Raises ValueError if candidates and mult_references does not have a valid type and size."""
+
+    error_msgs = []
     if not is_mono_sents(candidates):
-        raise ValueError("Invalid candidates type. (expected list[str])")
+        error_msg = "Invalid candidates type. (expected list[str])"
+        error_msgs.append(error_msg)
 
     if not is_mult_sents(mult_references):
-        raise ValueError("Invalid mult_references type. (expected list[list[str]])")
+        error_msg = "Invalid mult_references type. (expected list[list[str]])"
+        error_msgs.append(error_msg)
+
+    if len(error_msgs) > 0:
+        raise ValueError("\n".join(error_msgs))
 
     same_len = len(candidates) == len(mult_references)
     if not same_len:
@@ -52,13 +59,13 @@ def check_java_path(java_path: Union[str, Path]) -> bool:
     return valid
 
 
-def is_mono_sents(sents: Any) -> bool:
+def is_mono_sents(sents: Any) -> TypeGuard[list[str]]:
     """Returns True if input is list[str] containing sentences."""
     valid = isinstance(sents, list) and all(isinstance(sent, str) for sent in sents)
     return valid
 
 
-def is_mult_sents(mult_sents: Any) -> bool:
+def is_mult_sents(mult_sents: Any) -> TypeGuard[list[list[str]]]:
     """Returns True if input is list[list[str]] containing multiple sentences."""
     valid = (
         isinstance(mult_sents, list)
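For context on the first change: TypeGuard (in typing since Python 3.10, available from typing_extensions on older interpreters) tells a static type checker that when the function returns True, its argument can be narrowed to the guarded type. A minimal sketch of the effect on a caller; the count_words function below is hypothetical and not part of the repository:

    from typing import Any, TypeGuard

    def is_mono_sents(sents: Any) -> TypeGuard[list[str]]:
        """Returns True if input is list[str] containing sentences."""
        return isinstance(sents, list) and all(isinstance(sent, str) for sent in sents)

    def count_words(candidates: Any) -> int:
        # Hypothetical caller: past the guard, a type checker treats
        # candidates as list[str], so .split() passes static analysis.
        if not is_mono_sents(candidates):
            raise ValueError("Invalid candidates type. (expected list[str])")
        return sum(len(sent.split()) for sent in candidates)

The second change in this file makes check_metric_inputs collect both type errors before raising, so a caller passing invalid candidates and invalid mult_references sees both messages in a single ValueError instead of only the first.
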
10 changes: 8 additions & 2 deletions src/aac_metrics/utils/tokenization.py
@@ -69,6 +69,8 @@ def ptb_tokenize_batch(
     :param verbose: The verbose level. defaults to 0.
     :returns: The sentences tokenized as list[list[str]].
     """
+    # Originally based on https://github.com/audio-captioning/caption-evaluation-tools/blob/c1798df4c91e29fe689b1ccd4ce45439ec966417/caption/pycocoevalcap/tokenizer/ptbtokenizer.py#L30
+
     sentences = list(sentences)
 
     if not is_mono_sents(sentences):
@@ -82,12 +84,16 @@
     tmp_path = _get_tmp_path(tmp_path)
     punctuations = list(punctuations)
 
-    # Based on https://github.com/audio-captioning/caption-evaluation-tools/blob/c1798df4c91e29fe689b1ccd4ce45439ec966417/caption/pycocoevalcap/tokenizer/ptbtokenizer.py#L30
-
     stanford_fpath = osp.join(cache_path, FNAME_STANFORD_CORENLP_3_4_1_JAR)
 
     # Sanity checks
+    if __debug__:
+        newlines_count = sum(sent.count("\n") for sent in sentences)
+        if newlines_count > 0:
+            raise ValueError(
+                f"Invalid argument sentences for tokenization. (found {newlines_count} newlines character '\\n')"
+            )
 
     if not osp.isdir(cache_path):
         raise RuntimeError(f"Cannot find cache directory at {cache_path=}.")
     if not osp.isdir(tmp_path):
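Why reject newlines before tokenizing: judging from the upstream caption-evaluation-tools script referenced in the comment (an assumption, since the I/O code is not shown in this diff), the Stanford PTB tokenizer consumes its input as one sentence per line, so a caption containing "\n" would be split into two inputs and the tokenized output would silently misalign with the original sentence list. A minimal sketch of that failure mode:

    # Hypothetical illustration of the failure mode, not the library's I/O code.
    sentences = ["a dog barks", "rain falls\non the roof"]  # 2 captions

    # One sentence per line is what a PTB-style tokenizer typically consumes:
    payload = "\n".join(sentences)
    print(len(payload.split("\n")))  # 3 lines reach the tokenizer, not 2

    # The guard added in this commit detects the problem up front:
    newlines_count = sum(sent.count("\n") for sent in sentences)
    print(newlines_count)  # 1 -> ptb_tokenize_batch raises ValueError

Note that the check lives under if __debug__, so it is skipped when Python runs with the -O flag, keeping the extra scan out of optimized runs.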
