Skip to content

Commit

Permalink
Mod: Clean internal code for tokenization, add warning about newline …
Browse files Browse the repository at this point in the history
…character (#6).
  • Loading branch information
Labbeti committed Sep 18, 2023
1 parent 1a81a51 commit 9e19816
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 27 deletions.
16 changes: 8 additions & 8 deletions src/aac_metrics/utils/checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import re
import subprocess

from functools import cache
from pathlib import Path
from subprocess import CalledProcessError
from typing import Any, Union
Expand Down Expand Up @@ -54,17 +53,19 @@ def check_java_path(java_path: Union[str, Path]) -> bool:


def is_mono_sents(sents: Any) -> bool:
"""Returns True if input is list[str]."""
return isinstance(sents, list) and all(isinstance(sent, str) for sent in sents)
"""Returns True if input is list[str] containing sentences."""
valid = isinstance(sents, list) and all(isinstance(sent, str) for sent in sents)
return valid


def is_mult_sents(mult_sents: Any) -> bool:
"""Returns True if input is list[list[str]]."""
return (
"""Returns True if input is list[list[str]] containing multiple sentences."""
valid = (
isinstance(mult_sents, list)
and all(isinstance(sents, list) for sents in mult_sents)
and all(isinstance(sent, str) for sents in mult_sents for sent in sents)
)
return valid


def _get_java_version(java_path: str) -> str:
Expand Down Expand Up @@ -106,9 +107,8 @@ def _check_java_version(version: str, min_major: int, max_major: int) -> bool:
major_version = int(result["major"])
minor_version = int(result["minor"])

if (
major_version == 1 and minor_version <= 8
): # java <= 8 use versioning "1.MAJOR.MINOR" and > 8 use "MAJOR.MINOR.PATCH"
if major_version == 1 and minor_version <= 8:
# java <= 8 use versioning "1.MAJOR.MINOR" and > 8 use "MAJOR.MINOR.PATCH"
major_version = minor_version

return min_major <= major_version <= max_major
34 changes: 15 additions & 19 deletions src/aac_metrics/utils/tokenization.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from typing import Any, Hashable, Iterable, Optional

from aac_metrics.utils.checks import check_java_path
from aac_metrics.utils.checks import check_java_path, is_mono_sents
from aac_metrics.utils.collections import flat_list, unflat_list
from aac_metrics.utils.paths import (
_get_cache_path,
Expand Down Expand Up @@ -70,12 +70,17 @@ def ptb_tokenize_batch(
:returns: The sentences tokenized as list[list[str]].
"""
sentences = list(sentences)

if not is_mono_sents(sentences):
raise ValueError(f"Invalid argument sentences. (not a list[str] of sentences)")

if len(sentences) == 0:
return []

cache_path = _get_cache_path(cache_path)
java_path = _get_java_path(java_path)
tmp_path = _get_tmp_path(tmp_path)
punctuations = list(punctuations)

# Based on https://github.com/audio-captioning/caption-evaluation-tools/blob/c1798df4c91e29fe689b1ccd4ce45439ec966417/caption/pycocoevalcap/tokenizer/ptbtokenizer.py#L30

Expand Down Expand Up @@ -111,9 +116,6 @@ def ptb_tokenize_batch(
"-lowerCase",
]

# ======================================================
# Prepare data for PTB AACTokenizer
# ======================================================
if audio_ids is None:
audio_ids = list(range(len(sentences)))
else:
Expand All @@ -135,9 +137,6 @@ def ptb_tokenize_batch(
for old, new in replaces.items():
sentences = sentences.replace(old, new)

# ======================================================
# Save sentences to temporary file
# ======================================================
tmp_file = tempfile.NamedTemporaryFile(
delete=False,
dir=tmp_path,
Expand All @@ -147,38 +146,35 @@ def ptb_tokenize_batch(
tmp_file.write(sentences.encode())
tmp_file.close()

# ======================================================
# Tokenize sentence
# ======================================================
cmd.append(osp.basename(tmp_file.name))
p_tokenizer = subprocess.Popen(
cmd,
cwd=tmp_path,
stdout=subprocess.PIPE,
stderr=subprocess.DEVNULL if verbose <= 2 else None,
)
token_lines = p_tokenizer.communicate(input=sentences.rstrip().encode())[0]
encoded_sentences = sentences.rstrip().encode()
token_lines = p_tokenizer.communicate(input=encoded_sentences)[0]
token_lines = token_lines.decode()
lines = token_lines.split("\n")
# remove temp file
os.remove(tmp_file.name)

# ======================================================
# Create dictionary for tokenized captions
# ======================================================
outs: Any = [None for _ in range(len(lines))]
if len(audio_ids) != len(lines):
raise RuntimeError(
f"PTB tokenize error: expected {len(audio_ids)} lines in output file but found {len(lines)}."
f"Maybe check if there is any newline character '\\n' in your sentences or disable preprocessing tokenization."
)

punctuations = list(punctuations)
outs: Any = [None for _ in range(len(lines))]
for k, line in zip(audio_ids, lines):
tokenized_caption = [
w for w in line.rstrip().split(" ") if w not in punctuations
]
outs[k] = tokenized_caption
assert all(out is not None for out in outs)
assert all(
out is not None for out in outs
), f"INTERNAL ERROR: PTB tokenizer output."

if verbose >= 2:
duration = time.perf_counter() - start_time
Expand All @@ -199,9 +195,9 @@ def preprocess_mono_sents(
"""Tokenize sentences using PTB Tokenizer then merge them by space.
.. warning::
PTB tokenizer is a java program that takes a list[str] as input, so calling several times `preprocess_mono_sents` is slow on list[list[str]].
PTB tokenizer is a java program that takes a list[str] as input, so calling several times this function is slow on list[list[str]].
If you want to process multiple sentences (list[list[str]]), use `preprocess_mult_sents` instead.
If you want to process multiple sentences (list[list[str]]), use :func:`~aac_metrics.utils.tokenization.preprocess_mult_sents` instead.
:param sentences: The list of sentences to process.
:param cache_path: The path to the external code directory. defaults to the value returned by :func:`~aac_metrics.utils.paths.get_default_cache_path`.
Expand Down

0 comments on commit 9e19816

Please sign in to comment.