Skip to content

Commit

Permalink
Mod: Use kwds arguments instead of pos args in metric classes compute…
Browse files Browse the repository at this point in the history
… and update doc usage page.
  • Loading branch information
Labbeti committed Aug 14, 2023
1 parent fbc0064 commit 3e8ddfa
Show file tree
Hide file tree
Showing 14 changed files with 93 additions and 92 deletions.
8 changes: 4 additions & 4 deletions docs/usage.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,16 @@ Usage
Evaluate default AAC metrics
############################

The full evaluation process to compute AAC metrics can be done with `aac_metrics.aac_evaluate` function.
The full evaluation process to compute AAC metrics can be done with `aac_metrics.dcase2023_evaluate` function.

.. code-block:: python
from aac_metrics import aac_evaluate
from aac_metrics import dcase2023_evaluate
candidates: list[str] = ["a man is speaking", ...]
mult_references: list[list[str]] = [["a man speaks.", "someone speaks.", "a man is speaking while a bird is chirping in the background"], ...]
corpus_scores, _ = aac_evaluate(candidates, mult_references)
corpus_scores, _ = dcase2023_evaluate(candidates, mult_references)
print(corpus_scores)
# dict containing the score of each aac metric: "bleu_1", "bleu_2", "bleu_3", "bleu_4", "rouge_l", "meteor", "cider_d", "spice", "spider"
# {"bleu_1": tensor(0.7), "bleu_2": ..., ...}
Expand All @@ -25,7 +25,7 @@ Evaluate a specific metric
Evaluating a specific metric can be done using the `aac_metrics.functional.<metric_name>.<metric_name>` function or the `aac_metrics.classes.<metric_name>.<metric_name>` class.

.. warning::
Unlike `aac_evaluate`, the tokenization with PTBTokenizer is not done with these functions, but you can do it manually with `preprocess_mono_sents` and `preprocess_mult_sents` functions.
Unlike `dcase2023_evaluate`, the tokenization with PTBTokenizer is not done with these functions, but you can do it manually with `preprocess_mono_sents` and `preprocess_mult_sents` functions.

.. code-block:: python
Expand Down
13 changes: 7 additions & 6 deletions src/aac_metrics/classes/bleu.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,12 +53,13 @@ def __init__(

def compute(self) -> Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]:
return _bleu_compute(
self._cooked_cands,
self._cooked_mrefs,
self._return_all_scores,
self._n,
self._option,
self._verbose,
cooked_cands=self._cooked_cands,
cooked_mrefs=self._cooked_mrefs,
return_all_scores=self._return_all_scores,
n=self._n,
option=self._option,
verbose=self._verbose,
return_1_to_n=False,
)

def extra_repr(self) -> str:
Expand Down
26 changes: 13 additions & 13 deletions src/aac_metrics/classes/cider_d.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,13 +49,13 @@ def __init__(

def compute(self) -> Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]:
return _cider_d_compute(
self._cooked_cands,
self._cooked_mrefs,
self._return_all_scores,
self._n,
self._sigma,
self._return_tfidf,
self._scale,
cooked_cands=self._cooked_cands,
cooked_mrefs=self._cooked_mrefs,
return_all_scores=self._return_all_scores,
n=self._n,
sigma=self._sigma,
return_tfidf=self._return_tfidf,
scale=self._scale,
)

def extra_repr(self) -> str:
Expand All @@ -75,10 +75,10 @@ def update(
mult_references: list[list[str]],
) -> None:
self._cooked_cands, self._cooked_mrefs = _cider_d_update(
candidates,
mult_references,
self._n,
self._tokenizer,
self._cooked_cands,
self._cooked_mrefs,
candidates=candidates,
mult_references=mult_references,
n=self._n,
tokenizer=self._tokenizer,
prev_cooked_cands=self._cooked_cands,
prev_cooked_mrefs=self._cooked_mrefs,
)
2 changes: 1 addition & 1 deletion src/aac_metrics/classes/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def __hash__(self) -> int:
class DCASE2023Evaluate(Evaluate):
"""Evaluate candidates with multiple references with DCASE2023 Audio Captioning metrics.
For more information, see :func:`~aac_metrics.functional.evaluate.aac_evaluate`.
For more information, see :func:`~aac_metrics.functional.evaluate.dcase2023_evaluate`.
"""

def __init__(
Expand Down
26 changes: 13 additions & 13 deletions src/aac_metrics/classes/fense.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,19 +66,19 @@ def __init__(

def compute(self) -> Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]:
return fense(
self._candidates,
self._mult_references,
self._return_all_scores,
self._sbert_model,
self._echecker,
self._echecker_tokenizer,
self._error_threshold,
self._device,
self._batch_size,
self._reset_state,
self._return_probs,
self._penalty,
self._verbose,
candidates=self._candidates,
mult_references=self._mult_references,
return_all_scores=self._return_all_scores,
sbert_model=self._sbert_model,
echecker=self._echecker,
echecker_tokenizer=self._echecker_tokenizer,
error_threshold=self._error_threshold,
device=self._device,
batch_size=self._batch_size,
reset_state=self._reset_state,
return_probs=self._return_probs,
penalty=self._penalty,
verbose=self._verbose,
)

def extra_repr(self) -> str:
Expand Down
20 changes: 10 additions & 10 deletions src/aac_metrics/classes/fluerr.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,16 +64,16 @@ def __init__(

def compute(self) -> Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]:
return fluerr(
self._candidates,
self._return_all_scores,
self._echecker,
self._echecker_tokenizer,
self._error_threshold,
self._device,
self._batch_size,
self._reset_state,
self._return_probs,
self._verbose,
candidates=self._candidates,
return_all_scores=self._return_all_scores,
echecker=self._echecker,
echecker_tokenizer=self._echecker_tokenizer,
error_threshold=self._error_threshold,
device=self._device,
batch_size=self._batch_size,
reset_state=self._reset_state,
return_probs=self._return_probs,
verbose=self._verbose,
)

def extra_repr(self) -> str:
Expand Down
16 changes: 8 additions & 8 deletions src/aac_metrics/classes/meteor.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,14 +47,14 @@ def __init__(

def compute(self) -> Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]:
return meteor(
self._candidates,
self._mult_references,
self._return_all_scores,
self._cache_path,
self._java_path,
self._java_max_memory,
self._language,
self._verbose,
candidates=self._candidates,
mult_references=self._mult_references,
return_all_scores=self._return_all_scores,
cache_path=self._cache_path,
java_path=self._java_path,
java_max_memory=self._java_max_memory,
language=self._language,
verbose=self._verbose,
)

def extra_repr(self) -> str:
Expand Down
14 changes: 7 additions & 7 deletions src/aac_metrics/classes/rouge_l.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,8 @@ def __init__(

def compute(self) -> Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]:
return _rouge_l_compute(
self._rouge_l_scores,
self._return_all_scores,
rouge_l_scs=self._rouge_l_scores,
return_all_scores=self._return_all_scores,
)

def extra_repr(self) -> str:
Expand All @@ -62,9 +62,9 @@ def update(
mult_references: list[list[str]],
) -> None:
self._rouge_l_scores = _rouge_l_update(
candidates,
mult_references,
self._beta,
self._tokenizer,
self._rouge_l_scores,
candidates=candidates,
mult_references=mult_references,
beta=self._beta,
tokenizer=self._tokenizer,
prev_rouge_l_scores=self._rouge_l_scores,
)
16 changes: 8 additions & 8 deletions src/aac_metrics/classes/sbert_sim.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,14 +57,14 @@ def __init__(

def compute(self) -> Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]:
return sbert_sim(
self._candidates,
self._mult_references,
self._return_all_scores,
self._sbert_model,
self._device,
self._batch_size,
self._reset_state,
self._verbose,
candidates=self._candidates,
mult_references=self._mult_references,
return_all_scores=self._return_all_scores,
sbert_model=self._sbert_model,
device=self._device,
batch_size=self._batch_size,
reset_state=self._reset_state,
verbose=self._verbose,
)

def extra_repr(self) -> str:
Expand Down
22 changes: 11 additions & 11 deletions src/aac_metrics/classes/spice.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,17 +57,17 @@ def __init__(

def compute(self) -> Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]:
return spice(
self._candidates,
self._mult_references,
self._return_all_scores,
self._cache_path,
self._java_path,
self._tmp_path,
self._n_threads,
self._java_max_memory,
self._timeout,
self._separate_cache_dir,
self._verbose,
candidates=self._candidates,
mult_references=self._mult_references,
return_all_scores=self._return_all_scores,
cache_path=self._cache_path,
java_path=self._java_path,
tmp_path=self._tmp_path,
n_threads=self._n_threads,
java_max_memory=self._java_max_memory,
timeout=self._timeout,
separate_cache_dir=self._separate_cache_dir,
verbose=self._verbose,
)

def extra_repr(self) -> str:
Expand Down
6 changes: 3 additions & 3 deletions src/aac_metrics/classes/spider.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,9 @@ def __init__(

def compute(self) -> Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]:
return spider(
self._candidates,
self._mult_references,
self._return_all_scores,
candidates=self._candidates,
mult_references=self._mult_references,
return_all_scores=self._return_all_scores,
# CIDEr args
n=self._n,
sigma=self._sigma,
Expand Down
6 changes: 3 additions & 3 deletions src/aac_metrics/classes/spider_fl.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,9 +88,9 @@ def __init__(

def compute(self) -> Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]:
return spider_fl(
self._candidates,
self._mult_references,
self._return_all_scores,
candidates=self._candidates,
mult_references=self._mult_references,
return_all_scores=self._return_all_scores,
# CIDEr args
n=self._n,
sigma=self._sigma,
Expand Down
8 changes: 4 additions & 4 deletions src/aac_metrics/classes/spider_max.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,10 +63,10 @@ def __init__(

def compute(self) -> Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]:
return spider_max(
self._mult_candidates,
self._mult_references,
self._return_all_scores,
self._return_all_cands_scores,
mult_candidates=self._mult_candidates,
mult_references=self._mult_references,
return_all_scores=self._return_all_scores,
return_all_cands_scores=self._return_all_cands_scores,
n=self._n,
sigma=self._sigma,
cache_path=self._cache_path,
Expand Down
2 changes: 1 addition & 1 deletion src/aac_metrics/functional/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def evaluate(
)

if preprocess:
common_kwds = dict(
common_kwds: dict[str, Any] = dict(
cache_path=cache_path,
java_path=java_path,
tmp_path=tmp_path,
Expand Down

0 comments on commit 3e8ddfa

Please sign in to comment.