Skip to content

Commit

Permalink
Add sorting options for non-time-constrained WERs (#49)
Browse files Browse the repository at this point in the history
* Add sort options to non time-constrained WERs

* Fix --regex option in _merge

* Convert boolean args (True, true, False, false) to bool

* Add burn tests for sort options

* Improve error message for seglst loading

* Remove debug help text

* Improve seglst format check
  • Loading branch information
thequilo authored Jan 25, 2024
1 parent b481983 commit 1d3c9ac
Show file tree
Hide file tree
Showing 4 changed files with 177 additions and 53 deletions.
3 changes: 2 additions & 1 deletion meeteval/io/load_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,8 @@ def load(path: 'Path | list[Path]', parse_float=decimal.Decimal, format: 'str |
# Guess the type from the file content. Only support Chime7 JSON / SegLST format.
try:
return meeteval.io.SegLST.load(path, parse_float=parse_float)
except Exception as e:
except ValueError as e:
# Catches simplejson's JSONDecodeError and our own ValueErrors
raise ValueError(f'Unknown JSON format: {path}. Only SegLST format is supported.') from e
else:
raise ValueError(f'Unknown file type: {path}')
Expand Down
38 changes: 30 additions & 8 deletions meeteval/io/seglst.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,14 +44,13 @@ class SegLstSegment(TypedDict, total=False):
@dataclasses.dataclass(frozen=True)
class SegLST(BaseABC):
"""
A collection of segments in SegLST format. This the input type to most
functions in MeetEval that process transcript segments.
Segment-wise Long-form Speech Transcription annotation (SegLST) format
This is the input type to most functions in MeetEval that process transcript
segments.
"""
segments: 'list[SegLstSegment]'

# Caches
_unique = None

@classmethod
def load(
cls,
Expand All @@ -75,6 +74,11 @@ def parse(cls, s: str, parse_float=decimal.Decimal) -> 'Self':
>>> SegLST.parse('[{"words": "a b c", "segment_index": 0, "speaker": 0}]')
SegLST(segments=[{'words': 'a b c', 'segment_index': 0, 'speaker': 0}])
>>> SegLST.parse('{"a": {"words": "a b c", "segment_index": 0, "speaker": 0}}')
Traceback (most recent call last):
...
ValueError: Invalid JSON format for SegLST: Expected a list of segments, but found a dict.
"""
import simplejson

Expand All @@ -85,9 +89,27 @@ def fix_floats(s):
s[k] = parse_float(s[k])
return s

return cls([
fix_floats(s) for s in simplejson.loads(s, parse_float=parse_float)
])
loaded = simplejson.loads(s, parse_float=parse_float)

if not isinstance(loaded, list):
raise ValueError(
'Invalid JSON format for SegLST: Expected a list of segments, '
'but found a dict.'
)

# Check if the first and last entry have the correct format. We here
# require that the "session_id" key is present in all segments.
if (
loaded
and not isinstance(loaded[0], dict) and 'session_id' in loaded[0]
and not isinstance(loaded[-1], dict) and 'session_id' in loaded[-1]
):
raise ValueError(
f'Invalid JSON format for SegLST: Expected a list of segments '
f'(as dicts), but found a list of {type(loaded[0])}.'
)

return cls([fix_floats(s) for s in loaded])

def dump(self, file):
from meeteval.io.base import _open
Expand Down
183 changes: 139 additions & 44 deletions meeteval/wer/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def _load(path: Path):

def _load_texts(
reference_paths: 'list[str]', hypothesis_paths: 'list[str]', regex,
file_format=None,
reference_sort=False, hypothesis_sort=False, file_format=None,
) -> 'tuple[meeteval.io.SegLST, list[Path], meeteval.io.SegLST, list[Path]]':
"""Load and validate reference and hypothesis texts.
Expand Down Expand Up @@ -125,6 +125,46 @@ def filter(s):
reference = filter(reference)
hypothesis = filter(hypothesis)

# Sort
if reference_sort == 'segment':
if 'start_time' in reference.T.keys():
reference = reference.sorted('start_time')
else:
logging.warning(
'Ignoring --reference-sort="segment" because no start_time is '
'found in the reference'
)
elif not reference_sort:
pass
elif reference_sort in ('word', True):
raise ValueError(
f'reference_sort={reference_sort} is only supported for'
f'time-constrained WERs.'
)
else:
raise ValueError(
f'Unknown choice for reference_sort: {reference_sort}'
)
if hypothesis_sort == 'segment':
if 'start_time' in hypothesis.T.keys():
hypothesis = hypothesis.sorted('start_time')
else:
logging.warning(
'Ignoring --hypothesis-sort="segment" because no start_time is '
'found in the hypothesis'
)
elif not hypothesis_sort:
pass
elif hypothesis_sort in ('word', True):
raise ValueError(
f'hypothesis_sort={hypothesis_sort} is only supported for'
f'time-constrained WERs.'
)
else:
raise ValueError(
f'Unknown choice for hypothesis_sort: {hypothesis_sort}'
)

return reference, reference_paths, hypothesis, hypothesis_paths


Expand Down Expand Up @@ -202,11 +242,15 @@ def orcwer(
average_out='{parent}/{stem}_orcwer.json',
per_reco_out='{parent}/{stem}_orcwer_per_reco.json',
regex=None,
reference_sort='segment',
hypothesis_sort='segment',
):
"""Computes the Optimal Reference Combination Word Error Rate (ORC WER)"""
from meeteval.wer.wer.orc import orc_word_error_rate_multifile
reference, _, hypothesis, hypothesis_paths = _load_texts(
reference, hypothesis, regex=regex)
reference, hypothesis, regex=regex,
reference_sort=reference_sort, hypothesis_sort=hypothesis_sort
)
results = orc_word_error_rate_multifile(reference, hypothesis)
_save_results(results, hypothesis_paths, per_reco_out, average_out)

Expand All @@ -216,11 +260,15 @@ def cpwer(
average_out='{parent}/{stem}_cpwer.json',
per_reco_out='{parent}/{stem}_cpwer_per_reco.json',
regex=None,
reference_sort='segment',
hypothesis_sort='segment',
):
"""Computes the Concatenated minimum-Permutation Word Error Rate (cpWER)"""
from meeteval.wer.wer.cp import cp_word_error_rate_multifile
reference, _, hypothesis, hypothesis_paths = _load_texts(
reference, hypothesis, regex)
reference, hypothesis, regex=regex,
reference_sort=reference_sort, hypothesis_sort=hypothesis_sort
)
results = cp_word_error_rate_multifile(reference, hypothesis)
_save_results(results, hypothesis_paths, per_reco_out, average_out)

Expand All @@ -230,11 +278,15 @@ def mimower(
average_out='{parent}/{stem}_mimower.json',
per_reco_out='{parent}/{stem}_mimower_per_reco.json',
regex=None,
reference_sort='segment',
hypothesis_sort='segment',
):
"""Computes the MIMO WER"""
from meeteval.wer.wer.mimo import mimo_word_error_rate_multifile
reference, _, hypothesis, hypothesis_paths = _load_texts(
reference, hypothesis, regex=regex)
reference, hypothesis, regex=regex,
reference_sort=reference_sort, hypothesis_sort=hypothesis_sort
)
results = mimo_word_error_rate_multifile(reference, hypothesis)
_save_results(results, hypothesis_paths, per_reco_out, average_out)

Expand Down Expand Up @@ -271,12 +323,16 @@ def tcpwer(
def _merge(
files: 'list[str]',
out: str = None,
average: bool = None
average: bool = None,
regex: str = None,
):
# Load input files
files = [Path(f) for f in files]
data = [_load(f) for f in files]

if regex is not None:
regex = re.compile(regex)

import meeteval
ers = []

Expand All @@ -287,6 +343,8 @@ def _merge(
ers.append([None, ErrorRate.from_dict(d)])
else:
for k, v in d.items(): # Details file
if regex is not None and not regex.fullmatch(k):
continue
if 'errors' in v:
ers.append([k, ErrorRate.from_dict(v)])

Expand All @@ -308,16 +366,30 @@ def merge(files, out):
return _merge(files, out, average=None)


def average(files, out, regex=None):
    """Computes the average over one or multiple per-reco files.

    Args:
        files: Paths of the per-reco result files to average over.
        out: Output path (template) for the averaged result.
        regex: Optional regular expression. If given, only example keys that
            fully match the pattern are included in the average.
    """
    return _merge(files, out, average=True, regex=regex)


class SmartFormatter(argparse.ArgumentDefaultsHelpFormatter):
    """Help formatter that respects explicit newlines in help texts.

    Each '\\n'-separated line of the help text is wrapped to ``width``
    individually; wrapped continuation lines of every line except the first
    are indented by a single space so list-style entries stay readable.
    Based on https://stackoverflow.com/a/22157136/5766934
    """

    def _split_lines(self, text, width):
        import textwrap
        wrapped = []
        for index, line in enumerate(text.split('\n')):
            # Continuation lines of all but the first line get a one-space
            # hanging indent (matches the original behavior exactly).
            indent = ' ' if index > 0 else ''
            wrapped.extend(
                textwrap.wrap(line, width, subsequent_indent=indent)
            )
        return wrapped

class CLI:
def __init__(self):

# Define argument parser and commands
self.parser = argparse.ArgumentParser()
self.parser = argparse.ArgumentParser(
formatter_class=SmartFormatter
)
self.parser.add_argument('--version', action='store_true',
help='Show version')

Expand All @@ -329,7 +401,9 @@ def __init__(self):
choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL', 'SILENT']
)

self.commands = self.parser.add_subparsers(title='Subcommands')
self.commands = self.parser.add_subparsers(
title='Subcommands',
)

@staticmethod
def positive_number(x: str):
Expand All @@ -343,6 +417,17 @@ def positive_number(x: str):

return x

@staticmethod
def str_or_bool(x: str):
"""Convert common boolean strings to bool and pass through other
strings"""
if x in ('true', 'True'):
return True
elif x in ('false', 'False'):
return False
else:
return x

def add_argument(self, command_parser, name, p):
if name == 'reference':
command_parser.add_argument(
Expand Down Expand Up @@ -391,55 +476,65 @@ def add_argument(self, command_parser, name, p):
)
elif name == 'hyp_pseudo_word_timing':
command_parser.add_argument(
'--hyp-pseudo-word-timing', choices=pseudo_word_level_strategies.keys(),
'--hyp-pseudo-word-timing',
choices=pseudo_word_level_strategies.keys(),
help='Specifies how word-level timings are '
'determined from segment-level timing '
'for the hypothesis. Choices: '
'equidistant_intervals: Divide segment-level timing into equally sized intervals; '
'equidistant_points: Place time points equally spaded int the segment-level intervals; '
'full_segment: Use the full segment for each word that belongs to that segment;'
'character_based: Estimate the word length based on the number of characters; '
'character_based_points: Estimates the word length based on the number of characters and '
'creates a point in the center of each word; '
'none: Do not estimate word-level timings but assume that the provided timings are already '
'for the hypothesis.\n'
'Choices:\n'
'- equidistant_intervals: Divide segment-level timing into equally sized intervals\n'
'- equidistant_points: Place time points equally spaded int the segment-level intervals\n'
'- full_segment: Use the full segment for each word that belongs to that segment\n'
'- character_based: Estimate the word length based on the number of characters\n'
'- character_based_points: Estimates the word length based on the number of characters and '
'creates a point in the center of each word\n'
'- none: Do not estimate word-level timings but assume that the provided timings are already '
'given on a word level.'
)
elif name == 'ref_pseudo_word_timing':
command_parser.add_argument(
'--ref-pseudo-word-timing', choices=pseudo_word_level_strategies.keys(),
'--ref-pseudo-word-timing',
choices=pseudo_word_level_strategies.keys(),
help='Specifies how word-level timings are '
'determined from segment-level timing '
'for the reference. Choices: '
'equidistant_intervals: Divide segment-level timing into equally sized intervals; '
'equidistant_points: Place time points equally spaded int the segment-level intervals; '
'full_segment: Use the full segment for each word that belongs to that segment. '
'character_based: Estimate the word length based on the number of characters; '
'character_based_points: Estimates the word length based on the number of characters and '
'creates a point in the center of each word; '
'none: Do not estimate word-level timings but assume that the provided timings are already '
'for the reference.\n'
'Choices:\n'
'- equidistant_intervals: Divide segment-level timing into equally sized intervals\n'
'- equidistant_points: Place time points equally spaded int the segment-level intervals\n'
'- full_segment: Use the full segment for each word that belongs to that segment.\n'
'- character_based: Estimate the word length based on the number of characters\n'
'- character_based_points: Estimates the word length based on the number of characters and '
'creates a point in the center of each word\n'
'- none: Do not estimate word-level timings but assume that the provided timings are already '
'given on a word level.'
)
elif name == 'reference_sort':
command_parser.add_argument(
'--reference-sort', choices=[True, False, 'word', 'segment'],
help='How to sort words/segments in the reference; '
'True: sort by segment start time and assert that the word-level timings are sorted by start '
'time; '
'False: do not sort and do not check word order. Segment order is taken from input file '
'and sorting is up to the user; '
'segment: sort segments by start time and do not check word order'
'word: sort words by start time'
'--reference-sort',
choices=[True, False, 'word', 'segment'],
type=self.str_or_bool,
help='How to sort words/segments in the reference.\n'
'Choices:\n'
'- segment: Sort segments by start time and do not check word order\n'
'- False: Do not sort and do not check word order. Segment order is taken from input file '
'and sorting is up to the user\n'
'- True: Sort by segment start time and assert that the word-level timings are sorted by start '
'time. Only supported for time-constrained WERs\n'
'- word: sort words by start time. Only supported for time-constrained WERs'
)
elif name == 'hypothesis_sort':
command_parser.add_argument(
'--hypothesis-sort', choices=[True, False, 'word', 'segment'],
help='How to sort words/segments in the reference; '
'True: sort by segment start time and assert that the word-level timings are sorted by start '
'time; '
'False: do not sort and do not check word order. Segment order is taken from input file '
'and sorting is up to the user; '
'segment: sort segments by start time and do not check word order'
'word: sort words by start time'
'--hypothesis-sort',
choices=[True, False, 'word', 'segment'],
type=self.str_or_bool,
help='How to sort words/segments in the reference.\n'
'Choices:\n'
'- segment: Sort segments by start time and do not check word order\n'
'- False: Do not sort and do not check word order. Segment order is taken from input file '
'and sorting is up to the user\n'
'- True: Sort by segment start time and assert that the word-level timings are sorted by start '
'time. Only supported for time-constrained WERs\n'
'- word: sort words by start time. Only supported for time-constrained WERs'
)
elif name == 'files':
command_parser.add_argument('files', nargs='+')
Expand All @@ -452,7 +547,7 @@ def add_command(self, fn, command_name=None):
command_parser = self.commands.add_parser(
command_name,
add_help=False,
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
formatter_class=SmartFormatter,
help=fn.__doc__,
)
command_parser.add_argument(
Expand Down
Loading

0 comments on commit 1d3c9ac

Please sign in to comment.