
Commit

Mod: Update argument typing for bert score, bleu, meteor and vocab metrics.
Labbeti committed Mar 5, 2024
1 parent 1a4ba37 commit b030c8e
Showing 12 changed files with 54 additions and 55 deletions.
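The pattern is the same across the changed files: keyword arguments previously typed as plain str (or a str/callable Union) are narrowed to Literal-based aliases such as BleuOption, Language, Reduction, Selection, and PopStrategy, so a static type checker can reject misspelled option names before runtime. A minimal sketch of the idea, reusing the BleuOption alias defined in this commit (the simplified signature and the checker behavior shown are illustrative, not the library's actual function):

from typing import Literal

BleuOption = Literal["shortest", "average", "closest"]  # alias added by this commit

def pick_option(option: BleuOption = "closest") -> str:
    # Hypothetical stand-in for the real bleu() signature.
    return option

pick_option("average")   # accepted: value is one of the declared literals
pick_option("closests")  # flagged by mypy/pyright: not a valid literal (still runs at runtime)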
14 changes: 7 additions & 7 deletions src/aac_metrics/classes/bert_score_mrefs.py
@@ -1,18 +1,18 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-

-from typing import Callable, Union
+from typing import Union

import torch

-from torch import nn, Tensor
+from torch import Tensor, nn

from aac_metrics.classes.base import AACMetric
from aac_metrics.functional.bert_score_mrefs import (
-    bert_score_mrefs,
-    _load_model_and_tokenizer,
    DEFAULT_BERT_SCORE_MODEL,
    REDUCTIONS,
+    Reduction,
+    _load_model_and_tokenizer,
+    bert_score_mrefs,
)
from aac_metrics.utils.globals import _get_device

@@ -44,7 +44,7 @@ def __init__(
        max_length: int = 64,
        reset_state: bool = True,
        idf: bool = False,
-        reduction: Union[str, Callable[[Tensor, ...], Tensor]] = "max",
+        reduction: Reduction = "max",
        filter_nan: bool = True,
        verbose: int = 0,
    ) -> None:
@@ -110,7 +110,7 @@ def extra_repr(self) -> str:
    def get_output_names(self) -> tuple[str, ...]:
        return (
            "bert_score.precision",
-            "bert_score.recalll",
+            "bert_score.recall",
            "bert_score.f1",
        )

3 changes: 2 additions & 1 deletion src/aac_metrics/classes/bleu.py
@@ -8,6 +8,7 @@
from aac_metrics.classes.base import AACMetric
from aac_metrics.functional.bleu import (
    BLEU_OPTIONS,
+    BleuOption,
    _bleu_compute,
    _bleu_update,
)
@@ -32,7 +33,7 @@ def __init__(
        self,
        return_all_scores: bool = True,
        n: int = 4,
-        option: str = "closest",
+        option: BleuOption = "closest",
        verbose: int = 0,
        tokenizer: Callable[[str], list[str]] = str.split,
    ) -> None:
4 changes: 2 additions & 2 deletions src/aac_metrics/classes/meteor.py
@@ -7,7 +7,7 @@
from torch import Tensor

from aac_metrics.classes.base import AACMetric
-from aac_metrics.functional.meteor import meteor
+from aac_metrics.functional.meteor import Language, meteor


class METEOR(AACMetric[Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]]):
@@ -32,7 +32,7 @@ def __init__(
        cache_path: Union[str, Path, None] = None,
        java_path: Union[str, Path, None] = None,
        java_max_memory: str = "2G",
-        language: str = "en",
+        language: Language = "en",
        use_shell: Optional[bool] = None,
        params: Optional[Iterable[float]] = None,
        weights: Optional[Iterable[float]] = None,
7 changes: 2 additions & 5 deletions src/aac_metrics/classes/vocab.py
@@ -3,16 +3,13 @@

import logging
import math
-
from typing import Callable, Union

import torch
-
from torch import Tensor

from aac_metrics.classes.base import AACMetric
-from aac_metrics.functional.vocab import vocab
-
+from aac_metrics.functional.vocab import PopStrategy, vocab

pylog = logging.getLogger(__name__)

@@ -36,7 +33,7 @@ def __init__(
        seed: Union[None, int, torch.Generator] = 1234,
        tokenizer: Callable[[str], list[str]] = str.split,
        dtype: torch.dtype = torch.float64,
-        pop_strategy: str = "max",
+        pop_strategy: PopStrategy = "max",
        verbose: int = 0,
    ) -> None:
        super().__init__()
7 changes: 4 additions & 3 deletions src/aac_metrics/functional/bert_score_mrefs.py
@@ -1,7 +1,7 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-

-from typing import Callable, Optional, Union
+from typing import Callable, Literal, Optional, Union

import torch
import torchmetrics
@@ -18,6 +18,7 @@

DEFAULT_BERT_SCORE_MODEL = _DEFAULT_MODEL
REDUCTIONS = ("mean", "max", "min")
+Reduction = Union[Literal["mean", "max", "min"], Callable[..., Tensor]]


def bert_score_mrefs(
@@ -32,7 +33,7 @@ def bert_score_mrefs(
    max_length: int = 64,
    reset_state: bool = True,
    idf: bool = False,
-    reduction: Union[str, Callable[..., Tensor]] = "max",
+    reduction: Reduction = "max",
    filter_nan: bool = True,
    verbose: int = 0,
) -> Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]:
@@ -62,7 +63,7 @@ def bert_score_mrefs(
    :param verbose: The verbose level. defaults to 0.
    :returns: A tuple of globals and locals scores or a scalar tensor with the main global score.
    """
-    check_metric_inputs(candidates, mult_references)
+    check_metric_inputs(candidates, mult_references, min_length=1)

    if isinstance(model, str):
        if tokenizer is not None:
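The new Reduction alias keeps the old flexibility: it accepts one of the three named reductions or any callable returning a Tensor. A hedged sketch of how such a parameter can be dispatched (apply_reduction is a hypothetical helper, not part of aac_metrics):

from typing import Callable, Literal, Union

import torch
from torch import Tensor

Reduction = Union[Literal["mean", "max", "min"], Callable[..., Tensor]]

def apply_reduction(scores: Tensor, reduction: Reduction = "max") -> Tensor:
    # Dispatch on the literal name, otherwise treat the argument as a callable.
    if reduction == "mean":
        return scores.mean(dim=-1)
    if reduction == "max":
        return scores.max(dim=-1).values
    if reduction == "min":
        return scores.min(dim=-1).values
    return reduction(scores)

print(apply_reduction(torch.tensor([1.0, 2.0, 3.0])))  # tensor(3.)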
24 changes: 11 additions & 13 deletions src/aac_metrics/functional/bleu.py
@@ -3,28 +3,26 @@

import logging
import math
-
from collections import Counter
-from typing import Any, Callable, Optional, Union
+from typing import Any, Callable, Literal, Optional, Union

import torch
-
from torch import Tensor

from aac_metrics.utils.checks import check_metric_inputs
-

pylog = logging.getLogger(__name__)

BLEU_OPTIONS = ("shortest", "average", "closest")
+BleuOption = Literal["shortest", "average", "closest"]


def bleu(
    candidates: list[str],
    mult_references: list[list[str]],
    return_all_scores: bool = True,
    n: int = 4,
-    option: str = "closest",
+    option: BleuOption = "closest",
    verbose: int = 0,
    tokenizer: Callable[[str], list[str]] = str.split,
    return_1_to_n: bool = False,
@@ -72,7 +70,7 @@ def bleu_1(
    candidates: list[str],
    mult_references: list[list[str]],
    return_all_scores: bool = True,
-    option: str = "closest",
+    option: BleuOption = "closest",
    verbose: int = 0,
    tokenizer: Callable[[str], list[str]] = str.split,
    return_1_to_n: bool = False,
@@ -93,7 +91,7 @@ def bleu_2(
    candidates: list[str],
    mult_references: list[list[str]],
    return_all_scores: bool = True,
-    option: str = "closest",
+    option: BleuOption = "closest",
    verbose: int = 0,
    tokenizer: Callable[[str], list[str]] = str.split,
    return_1_to_n: bool = False,
@@ -114,7 +112,7 @@ def bleu_3(
    candidates: list[str],
    mult_references: list[list[str]],
    return_all_scores: bool = True,
-    option: str = "closest",
+    option: BleuOption = "closest",
    verbose: int = 0,
    tokenizer: Callable[[str], list[str]] = str.split,
    return_1_to_n: bool = False,
@@ -135,7 +133,7 @@ def bleu_4(
    candidates: list[str],
    mult_references: list[list[str]],
    return_all_scores: bool = True,
-    option: str = "closest",
+    option: BleuOption = "closest",
    verbose: int = 0,
    tokenizer: Callable[[str], list[str]] = str.split,
    return_1_to_n: bool = False,
@@ -179,7 +177,7 @@ def _bleu_compute(
    cooked_mrefs: list,
    return_all_scores: bool = True,
    n: int = 4,
-    option: str = "closest",
+    option: BleuOption = "closest",
    verbose: int = 0,
    return_1_to_n: bool = False,
) -> Union[Tensor, tuple[dict[str, Tensor], dict[str, Tensor]]]:
@@ -189,7 +187,7 @@ def _bleu_compute(
    bleu_1_to_n_score, bleu_1_to_n_scores = __compute_bleu_score(
        cooked_cands,
        cooked_mrefs,
-        n,
+        n=n,
        option=option,
        verbose=verbose,
    )
@@ -300,7 +298,7 @@ def __compute_bleu_score(
    cooked_cands: list,
    cooked_mrefs: list,
    n: int,
-    option: Optional[str] = "closest",
+    option: BleuOption = "closest",
    verbose: int = 0,
) -> tuple[list[float], list[list[float]]]:
    SMALL = 1e-9
@@ -373,7 +371,7 @@ def __compute_bleu_score(

def __single_reflen(
    reflens: list[int],
-    option: Optional[str] = None,
+    option: BleuOption,
    testlen: Optional[int] = None,
) -> float:
    if option == "shortest":
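A usage sketch against the updated bleu signature (sentences are toy data; per the library's docstrings, return_all_scores=True yields a tuple of corpus-level and sentence-level score dicts):

from aac_metrics.functional.bleu import bleu

candidates = ["a man is speaking near a river"]  # toy example
mult_references = [["a man speaks by a river", "someone talks near water"]]

# `option` is now typed as BleuOption; anything outside
# {"shortest", "average", "closest"} is a static type error.
corpus_scores, sents_scores = bleu(
    candidates,
    mult_references,
    return_all_scores=True,
    n=4,
    option="closest",
)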
8 changes: 3 additions & 5 deletions src/aac_metrics/functional/meteor.py
@@ -5,25 +5,23 @@
import os.path as osp
import platform
import subprocess
-
from pathlib import Path
from subprocess import Popen
-from typing import Iterable, Optional, Union
+from typing import Iterable, Literal, Optional, Union

import torch
-
from torch import Tensor

from aac_metrics.utils.checks import check_java_path, check_metric_inputs
from aac_metrics.utils.globals import _get_cache_path, _get_java_path
-

pylog = logging.getLogger(__name__)


DNAME_METEOR_CACHE = osp.join("aac-metrics", "meteor")
FNAME_METEOR_JAR = osp.join(DNAME_METEOR_CACHE, "meteor-1.5.jar")
SUPPORTED_LANGUAGES = ("en", "cz", "de", "es", "fr")
+Language = Literal["en", "cz", "de", "es", "fr"]


def meteor(
@@ -33,7 +31,7 @@ def meteor(
    cache_path: Union[str, Path, None] = None,
    java_path: Union[str, Path, None] = None,
    java_max_memory: str = "2G",
-    language: str = "en",
+    language: Language = "en",
    use_shell: Optional[bool] = None,
    params: Optional[Iterable[float]] = None,
    weights: Optional[Iterable[float]] = None,
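meteor.py keeps both the runtime tuple SUPPORTED_LANGUAGES and the new static alias Language. As a hedged aside (a suggestion, not something this commit does), typing.get_args can derive the tuple from the alias so the two can never drift apart:

from typing import Literal, get_args

Language = Literal["en", "cz", "de", "es", "fr"]
SUPPORTED_LANGUAGES = get_args(Language)  # ("en", "cz", "de", "es", "fr")

assert "en" in SUPPORTED_LANGUAGES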
13 changes: 6 additions & 7 deletions src/aac_metrics/functional/mult_cands.py
@@ -1,17 +1,16 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-

-from typing import Callable, Union
+from typing import Callable, Literal, Union

import torch
import tqdm
-
from torch import Tensor

from aac_metrics.utils.checks import is_mult_sents
-

SELECTIONS = ("max", "min", "mean")
+Selection = Literal["max", "min", "mean"]


def mult_cands_metric(
@@ -21,8 +20,8 @@ def mult_cands_metric(
    mult_references: list[list[str]],
    return_all_scores: bool = True,
    return_all_cands_scores: bool = False,
-    selection: str = "max",
-    reduction: Callable[[Tensor], Tensor] = torch.mean,
+    selection: Selection = "max",
+    reduction_fn: Callable[[Tensor], Tensor] = torch.mean,
    **kwargs,
) -> Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]:
    """Multiple candidates metric wrapper.
@@ -32,7 +31,7 @@ def mult_cands_metric(
    :param mult_candidates: The list of list of sentences to evaluate.
    :param mult_references: The references input.
    :param selection: The selection to apply. Can be "max", "min" or "mean". defaults to "max".
-    :param reduction: The reduction function to apply to local scores. defaults to torch.mean.
+    :param reduction_fn: The reduction function to apply to local scores. defaults to torch.mean.
    :param **kwargs: The keywords arguments given to the metric call.
    :returns: A tuple of globals and locals scores or a scalar tensor with the main global score.
    """
@@ -113,7 +112,7 @@ def mult_cands_metric(
f"{k}_all": scores.transpose(0, 1) for k, scores in all_sents_scores.items()
}

reduction_fn = reduction
reduction_fn = reduction_fn
outs_corpus = {k: reduction_fn(scores) for k, scores in outs_sents.items()}

    if return_all_scores:
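Two things change here: selection gains the Selection literal type, and the reduction parameter is renamed to reduction_fn (the spider_max call site below is updated to match, so external callers passing reduction= as a keyword must do the same). A self-contained sketch of the selection-then-reduction semantics described in the docstring (select_then_reduce is a hypothetical mini-wrapper, not the library function):

from typing import Callable, Literal

import torch
from torch import Tensor

Selection = Literal["max", "min", "mean"]

def select_then_reduce(
    all_scores: Tensor,                                   # shape (n_items, n_candidates)
    selection: Selection = "max",
    reduction_fn: Callable[[Tensor], Tensor] = torch.mean,
) -> Tensor:
    # Pick one score per item across its candidates...
    if selection == "max":
        per_item = all_scores.max(dim=1).values
    elif selection == "min":
        per_item = all_scores.min(dim=1).values
    else:  # "mean"
        per_item = all_scores.mean(dim=1)
    # ...then reduce the per-item scores to a corpus-level score.
    return reduction_fn(per_item)

scores = torch.tensor([[0.2, 0.8], [0.5, 0.3]])
print(select_then_reduce(scores))  # tensor(0.6500) = mean of [0.8, 0.5]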
3 changes: 1 addition & 2 deletions src/aac_metrics/functional/spider_max.py
@@ -5,7 +5,6 @@
from typing import Callable, Iterable, Optional, Union

import torch
-
from torch import Tensor

from aac_metrics.functional.mult_cands import mult_cands_metric
@@ -74,7 +73,7 @@ def spider_max(
        return_all_scores=return_all_scores,
        return_all_cands_scores=return_all_cands_scores,
        selection="max",
-        reduction=torch.mean,
+        reduction_fn=torch.mean,
        # CIDEr args
        n=n,
        sigma=sigma,
12 changes: 6 additions & 6 deletions src/aac_metrics/functional/vocab.py
@@ -2,27 +2,28 @@
# -*- coding: utf-8 -*-

import logging
-
-from typing import Callable, Union
+from typing import Callable, Literal, Union

import torch
-
from torch import Tensor

from aac_metrics.utils.checks import check_metric_inputs, is_mono_sents
-

pylog = logging.getLogger(__name__)


POP_STRATEGIES = ("max", "min")
+PopStrategy = Literal["max", "min"]


def vocab(
    candidates: list[str],
    mult_references: Union[list[list[str]], None],
    return_all_scores: bool = True,
    seed: Union[None, int, torch.Generator] = 1234,
    tokenizer: Callable[[str], list[str]] = str.split,
    dtype: torch.dtype = torch.float64,
-    pop_strategy: str = "max",
+    pop_strategy: PopStrategy = "max",
    verbose: int = 0,
) -> Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]:
"""Compute vocabulary statistics.
@@ -84,7 +85,6 @@ def vocab(
    elif isinstance(pop_strategy, int):
        n_samples = pop_strategy
    else:
-        POP_STRATEGIES = ("max", "min")
        raise ValueError(
            f"Invalid argument {pop_strategy=}. (expected one of {POP_STRATEGIES} or an integer value)"
        )
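One wrinkle worth noting: the runtime branches above still accept an integer pop_strategy via isinstance(pop_strategy, int), while the new PopStrategy annotation only names the two literals, so integer arguments keep working at runtime but would now be flagged by a type checker. A sketch of that dispatch shape (resolve_n_samples and its max_refs/min_refs parameters are illustrative placeholders, not the library's actual computation):

from typing import Literal, Union

PopStrategy = Literal["max", "min"]
POP_STRATEGIES = ("max", "min")

def resolve_n_samples(pop_strategy: Union[PopStrategy, int], max_refs: int, min_refs: int) -> int:
    # Mirrors the branch structure shown in the diff; values are placeholders.
    if pop_strategy == "max":
        n_samples = max_refs
    elif pop_strategy == "min":
        n_samples = min_refs
    elif isinstance(pop_strategy, int):
        n_samples = pop_strategy
    else:
        raise ValueError(
            f"Invalid argument {pop_strategy=}. (expected one of {POP_STRATEGIES} or an integer value)"
        )
    return n_samples

print(resolve_n_samples("max", max_refs=5, min_refs=2))  # 5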