Skip to content

Commit

Permalink
Merge pull request #43 from boeddeker/main
Browse files Browse the repository at this point in the history
add der md_eval_22
  • Loading branch information
boeddeker authored Dec 7, 2023
2 parents 88eb16b + d6fd7dd commit 76022ca
Show file tree
Hide file tree
Showing 8 changed files with 415 additions and 159 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/noneeditable.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ jobs:
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
with:
python-version: '>=3.10'
python-version: '3.11'
cache: 'pip'
- name: Install dependencies
run: |
Expand Down
Empty file added meeteval/der/__init__.py
Empty file.
14 changes: 14 additions & 0 deletions meeteval/der/__main__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@

def cli():
    """Command-line entry point of the ``meeteval.der`` package.

    Creates the ``CLI`` helper from ``meeteval.wer.__main__``, registers
    ``_md_eval_22`` under the command name ``md_eval_22`` and runs it.
    """
    # The imports are kept local to the function, as in the original code.
    from meeteval.wer.__main__ import CLI
    from meeteval.der.md_eval import _md_eval_22

    command_line = CLI()
    command_line.add_command(_md_eval_22, 'md_eval_22')
    command_line.run()


# Run the command-line interface when this module is executed as a script.
if __name__ == '__main__':
    cli()
170 changes: 170 additions & 0 deletions meeteval/der/md_eval.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
import logging
import re
import decimal
import shutil
import tempfile
import subprocess
import dataclasses
from pathlib import Path

import meeteval.io
from meeteval.wer.wer.error_rate import ErrorRate


def _fix_channel(r):
    """Return an RTTM built from ``r`` with every line's channel set to '1'.

    The channel field of the inputs usually carries some arbitrary value
    (e.g. ``<NA>`` for the hypothesis and ``0`` for the reference), while
    dscore enforces ``1`` as the default channel.
    """
    normalized = []
    for entry in r:
        normalized.append(entry.replace(channel='1'))
    return meeteval.io.rttm.RTTM(normalized)


@dataclasses.dataclass(frozen=True)
class DiaErrorRate:
    """Diarization error rate together with the times it is derived from.

    The error rate is defined as::

        (missed_speaker_time + falarm_speaker_time + speaker_error_time)
            / scored_speaker_time

    Pass ``error_rate=None`` to have it computed in ``__post_init__``.
    When a value is given, it is checked against the computed rate after
    rounding the computed rate to the number of digits of the given value.

    Instances support ``+`` and ``sum``: the times are added up and the
    error rate is recomputed from the totals.
    """
    # `None` requests that the rate is computed from the times below.
    error_rate: 'float | decimal.Decimal | None'

    scored_speaker_time: 'float | decimal.Decimal'
    missed_speaker_time: 'float | decimal.Decimal'  # deletions
    falarm_speaker_time: 'float | decimal.Decimal'  # insertions
    speaker_error_time: 'float | decimal.Decimal'  # substitutions

    @classmethod
    def zero(cls):
        """Return the neutral element of the addition (all entries are 0)."""
        return cls(0, 0, 0, 0, 0)

    def __post_init__(self):
        """Validate the times and compute or verify ``error_rate``."""
        assert self.scored_speaker_time >= 0
        assert self.missed_speaker_time >= 0
        assert self.falarm_speaker_time >= 0
        assert self.speaker_error_time >= 0
        errors = (
            self.speaker_error_time
            + self.falarm_speaker_time
            + self.missed_speaker_time
        )
        if self.scored_speaker_time == 0 and errors == 0:
            # Nothing was scored and nothing went wrong (e.g. `zero()`).
            # Define the rate as 0 instead of dividing 0 by 0, which used
            # to raise a ZeroDivisionError.
            error_rate = decimal.Decimal(0)
        else:
            error_rate = decimal.Decimal(errors / self.scored_speaker_time)
        if self.error_rate is None:
            # The dataclass is frozen, hence the detour via object.__setattr__.
            object.__setattr__(self, 'error_rate', error_rate)
        else:
            # decimal.Decimal.quantize rounds to the same number of digits
            # as self.error_rate has.
            error_rate = error_rate.quantize(self.error_rate)
            assert self.error_rate == error_rate, (error_rate, self)

    def __radd__(self, other: 'int') -> 'DiaErrorRate':
        """Support ``sum``, which starts the accumulation with the int 0."""
        if isinstance(other, int) and other == 0:
            # Special case to support sum.
            return self
        return NotImplemented

    def __add__(self, other: 'DiaErrorRate') -> 'DiaErrorRate':
        """Add the times of both operands and recompute the error rate.

        Raises:
            ValueError: If ``other`` is not a ``DiaErrorRate``.
        """
        if not isinstance(other, self.__class__):
            raise ValueError(
                f'Cannot add {type(other).__name__} to '
                f'{self.__class__.__name__}.'
            )

        return DiaErrorRate(
            error_rate=None,
            scored_speaker_time=(
                self.scored_speaker_time + other.scored_speaker_time),
            missed_speaker_time=(
                self.missed_speaker_time + other.missed_speaker_time),
            falarm_speaker_time=(
                self.falarm_speaker_time + other.falarm_speaker_time),
            speaker_error_time=(
                self.speaker_error_time + other.speaker_error_time),
        )


def _md_eval_22(
        reference,
        hypothesis,
        average_out='{parent}/{stem}_md_eval_22.json',
        per_reco_out='{parent}/{stem}_md_eval_22_per_reco.json',
        collar=0,
        regex=None,
):
    """Compute the diarization error rate with ``md-eval-22.pl``.

    Reference and hypothesis are loaded with ``_load_texts``, converted to
    RTTM with the channel forced to '1' and grouped by filename. The perl
    script is executed once per filename that is present in both inputs and
    once on all of these filenames together; the aggregated per-recording
    result must agree with the result of the joint run. Finally the results
    are written with ``_save_results``.

    ``md-eval-22.pl`` is taken from ``PATH`` if available, otherwise from
    the directory of this file, where it is downloaded to if missing.

    Args:
        reference: Forwarded to ``_load_texts``.
        hypothesis: Forwarded to ``_load_texts``.
        average_out: Forwarded to ``_save_results``.
        per_reco_out: Forwarded to ``_save_results``.
        collar: Value passed to ``md-eval-22.pl`` via its ``-c`` option.
        regex: Forwarded to ``_load_texts``.

    Raises:
        ValueError: If reference and hypothesis share no filename.
        RuntimeError: If the aggregated per-recording error rate differs
            from the error rate of the joint run.
        subprocess.CalledProcessError: If ``md-eval-22.pl`` fails.
    """
    from meeteval.wer.__main__ import _load_texts

    r, _, h, hypothesis_paths = _load_texts(
        reference, hypothesis, regex)

    r = _fix_channel(r.to_rttm())
    h = _fix_channel(h.to_rttm())

    r = r.grouped_by_filename()
    h = h.grouped_by_filename()

    # Sorted, so that processing and output order do not depend on the
    # iteration order of a set.
    keys = sorted(set(r.keys()) & set(h.keys()))
    missing = set(r.keys()) ^ set(h.keys())
    if len(missing) > 0:
        # The values are passed as %-style arguments. Passing them without
        # a placeholder in the message makes logging fail to format it.
        logging.warning('Missing %d filenames: %s', len(missing), missing)
        logging.warning('Found %d filenames: %s', len(keys), keys)

    if not keys:
        # Without this check, the code fails later with an AttributeError,
        # because sum() of no results is the int 0.
        raise ValueError(
            'Reference and hypothesis have no filename in common.'
        )

    md_eval_22 = shutil.which('md-eval-22.pl')
    if not md_eval_22:
        md_eval_22 = Path(__file__).parent / 'md-eval-22.pl'
        if not md_eval_22.exists():
            url = 'https://github.com/nryant/dscore/raw/master/scorelib/md-eval-22.pl'
            logging.info(f'md-eval-22.pl not found. Trying to download it from {url}.')
            import urllib.request
            urllib.request.urlretrieve(url, md_eval_22)
            logging.info(f'Wrote {md_eval_22}')

    def get_details(r, h, key, tmpdir):
        """Run md-eval-22.pl on one reference/hypothesis pair and parse it."""
        r_file = tmpdir / f'{key}.ref.rttm'
        h_file = tmpdir / f'{key}.hyp.rttm'
        r.dump(r_file)
        h.dump(h_file)

        cmd = [
            'perl', f'{md_eval_22}',
            '-c', f'{collar}',
            '-r', f'{r_file}',
            '-s', f'{h_file}',
        ]

        cp = subprocess.run(cmd, stdout=subprocess.PIPE,
                            check=True, universal_newlines=True)
        # Example of the relevant lines in the output:
        # SCORED SPEAKER TIME =4309.340250 secs
        # MISSED SPEAKER TIME =4309.340250 secs
        # FALARM SPEAKER TIME =0.000000 secs
        # SPEAKER ERROR TIME =0.000000 secs
        # OVERALL SPEAKER DIARIZATION ERROR = 100.00 percent of scored speaker time (ALL)

        # Each unpacking requires exactly one match in the output.
        error_rate, = re.findall(r'OVERALL SPEAKER DIARIZATION ERROR = ([\d.]+) percent of scored speaker time', cp.stdout)
        length, = re.findall(r'SCORED SPEAKER TIME =([\d.]+) secs', cp.stdout)
        deletions, = re.findall(r'MISSED SPEAKER TIME =([\d.]+) secs', cp.stdout)
        insertions, = re.findall(r'FALARM SPEAKER TIME =([\d.]+) secs', cp.stdout)
        substitutions, = re.findall(r'SPEAKER ERROR TIME =([\d.]+) secs', cp.stdout)

        return DiaErrorRate(
            scored_speaker_time=decimal.Decimal(length),
            missed_speaker_time=decimal.Decimal(deletions),
            falarm_speaker_time=decimal.Decimal(insertions),
            speaker_error_time=decimal.Decimal(substitutions),
            # The script reports percent, DiaErrorRate stores the ratio.
            error_rate=decimal.Decimal(error_rate) / 100,
        )

    per_reco = {}
    with tempfile.TemporaryDirectory() as tmpdir:
        tmpdir = Path(tmpdir)
        for key in keys:
            per_reco[key] = get_details(r[key], h[key], key, tmpdir)

        # Joint run over all common recordings, used as a consistency check.
        md_eval = get_details(
            meeteval.io.RTTM([line for key in keys for line in r[key]]),
            meeteval.io.RTTM([line for key in keys for line in h[key]]),
            '',
            tmpdir,
        )

    summary = sum(per_reco.values())
    # Compare with the number of digits that md-eval-22.pl reported.
    error_rate = summary.error_rate.quantize(md_eval.error_rate)
    if error_rate != md_eval.error_rate:
        raise RuntimeError(
            f'The error rate of md-eval-22.pl on all recordings '
            f'({md_eval.error_rate})\n'
            f'does not match the average error rate of md-eval-22.pl '
            f'applied to each recording ({summary.error_rate}).'
        )

    from meeteval.wer.__main__ import _save_results
    _save_results(per_reco, hypothesis_paths, per_reco_out, average_out)
10 changes: 9 additions & 1 deletion meeteval/io/rttm.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,18 @@ class RTTMLine(BaseLine):
SPEAKER CMU_20020319-1400_d01_NONE 1 130.430000 2.350 <NA> <NA> juliet <NA> <NA>
SPEAKER CMU_20020319-1400_d01_NONE 1 157.610000 3.060 <NA> <NA> tbc <NA> <NA>
SPEAKER CMU_20020319-1400_d01_NONE 1 130.490000 0.450 <NA> <NA> chek <NA> <NA>
Note:
The RTTM definition (Appendix A in "The 2009 (RT-09) Rich Transcription
Meeting Recognition Evaluation Plan") doesn't say anything about the
channel format or defaults, but dscore enforces a "1" for the channel
(https://github.com/nryant/dscore#rttm).
Hence, the default here is "1" for the channel.
"""
type: str = 'SPEAKER'
filename: str = '<NA>'
channel: str = '<NA>'
channel: str = '1'
begin_time: 'float | int | str | decimal.Decimal' = 0
duration: 'float | int | str | decimal.Decimal' = 0
othography: 'str' = '<NA>'
Expand Down
Loading

0 comments on commit 76022ca

Please sign in to comment.