Skip to content

Commit

Permalink
Fix: FER name in comment and internal variables. Also fix internal ca…
Browse files Browse the repository at this point in the history
…ll in vocab metric.
  • Loading branch information
Labbeti committed Nov 3, 2023
1 parent 9521615 commit 7d3297d
Show file tree
Hide file tree
Showing 7 changed files with 12 additions and 12 deletions.
2 changes: 1 addition & 1 deletion src/aac_metrics/classes/fer.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ class FER(AACMetric[Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]])
- Paper: https://arxiv.org/abs/2110.04684
- Original implementation: https://github.com/blmoistawinde/fense
For more information, see :func:`~aac_metrics.functional.fluerr.fluerr`.
For more information, see :func:`~aac_metrics.functional.fer.fer`.
"""

full_state_update = False
Expand Down
2 changes: 1 addition & 1 deletion src/aac_metrics/classes/vocab.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def compute(self) -> Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]:

def get_output_names(self) -> tuple[str, ...]:
return (
"vocab.cands",
"vocab",
"vocab.mrefs_full",
"vocab.ratio_full",
"vocab.mrefs_avg",
Expand Down
12 changes: 6 additions & 6 deletions src/aac_metrics/functional/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import time

from functools import partial
from typing import Any, Callable, Iterable, Union
from typing import Any, Callable, Iterable, Optional, Union

import torch

Expand Down Expand Up @@ -53,7 +53,7 @@
# DCASE challenge task6a metrics for 2023
"dcase2023": (
"meteor",
"spider_fl", # includes cider_d, spice, spider, fluerr
"spider_fl", # includes cider_d, spice, spider, fer
),
# All metrics
"all": (
Expand All @@ -63,8 +63,8 @@
"bleu_4",
"meteor",
"rouge_l",
"fense", # includes sbert, fluerr
"spider_fl", # includes cider_d, spice, spider, fluerr
"fense", # includes sbert, fer
"spider_fl", # includes cider_d, spice, spider, fer
"vocab",
),
}
Expand Down Expand Up @@ -243,9 +243,9 @@ def _get_metric_factory_functions(
tmp_path: str = ...,
device: Union[str, torch.device, None] = "auto",
verbose: int = 0,
init_kwds: dict[str, Any] = ...,
init_kwds: Optional[dict[str, Any]] = ...,
) -> dict[str, Callable[[list[str], list[list[str]]], Any]]:
if init_kwds is ...:
if init_kwds is None or init_kwds is ...:
init_kwds = {}

init_kwds = init_kwds | dict(return_all_scores=return_all_scores)
Expand Down
2 changes: 1 addition & 1 deletion src/aac_metrics/functional/fense.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def _fense_from_outputs(
fer_outs: tuple[dict[str, Tensor], dict[str, Tensor]],
penalty: float = 0.9,
) -> tuple[dict[str, Tensor], dict[str, Tensor]]:
"""Combines SBERT and FluErr outputs.
"""Combines SBERT and FER outputs.
Based on https://github.com/blmoistawinde/fense/blob/main/fense/evaluator.py#L121
"""
Expand Down
2 changes: 1 addition & 1 deletion src/aac_metrics/functional/fer.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ def fer(
)
fer_scores = (probs_outs_sents["error"] > error_threshold).astype(float)

fer_scores = torch.from_numpy(fluerr_scores)
fer_scores = torch.from_numpy(fer_scores)
fer_score = fer_scores.mean()

if return_all_scores:
Expand Down
2 changes: 1 addition & 1 deletion src/aac_metrics/functional/spider_fl.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ def _spider_fl_from_outputs(
fer_outs: tuple[dict[str, Tensor], dict[str, Tensor]],
penalty: float = 0.9,
) -> tuple[dict[str, Tensor], dict[str, Tensor]]:
"""Combines SPIDEr and FluErr outputs.
"""Combines SPIDEr and FER outputs.
Based on https://github.com/felixgontier/dcase-2023-baseline/blob/main/metrics.py#L48
"""
Expand Down
2 changes: 1 addition & 1 deletion src/aac_metrics/functional/vocab.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def vocab(
for refs in tok_mrefs
]
popped_refs = [refs[idx] for idx, refs in zip(indexes, tok_mrefs)]
vocab_mrefs_len_i = _corpus_vocab(popped_refs)
vocab_mrefs_len_i = _corpus_vocab(popped_refs, dtype)
vocab_mrefs_lens[i] = vocab_mrefs_len_i

vocab_mrefs_avg = vocab_mrefs_lens.mean()
Expand Down

0 comments on commit 7d3297d

Please sign in to comment.