Mod: Update typing for language servers.
Labbeti committed Apr 12, 2024
1 parent 5fd196e commit 33a745b
Showing 28 changed files with 212 additions and 151 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -2,6 +2,10 @@

 All notable changes to this project will be documented in this file.
 
+## [0.5.5] UNRELEASED
+### Changed
+- Update metric typing for language servers.
+
 ## [0.5.4] 2024-03-04
 ### Fixed
 - Backward compatibility of `BERTScoreMrefs` with torchmetrics prior to 1.0.0.
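The pattern behind this changelog entry is visible in every class diff below: inline `tuple[dict[str, Tensor], ...]` unions are replaced by named `*Outs` aliases imported from the matching functional module, and classes that were still unparameterized (e.g. `BERTScoreMRefs`) now parameterize `AACMetric`, so language servers such as Pyright/Pylance can infer the return type of `compute()`. A minimal sketch of the pattern, with hypothetical `MyMetricOuts`/`MyMetric` names standing in for the real ones:

from typing import Generic, TypeVar, Union

from torch import Tensor

OutType = TypeVar("OutType")


class AACMetric(Generic[OutType]):
    # Sketch of the generic base; the real class in aac_metrics.classes.base
    # also wraps torchmetrics-style update/compute behavior.
    def compute(self) -> OutType:
        raise NotImplementedError


# Named alias for the "all scores" output, assumed to match the inline
# annotation it replaces: (corpus-level scores, sentence-level scores).
MyMetricOuts = tuple[dict[str, Tensor], dict[str, Tensor]]


class MyMetric(AACMetric[Union[MyMetricOuts, Tensor]]):
    def compute(self) -> Union[MyMetricOuts, Tensor]:  # narrowed, hover-friendly type
        return Tensor([0.0])  # placeholder body for the sketch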
8 changes: 6 additions & 2 deletions src/aac_metrics/classes/bert_score_mrefs.py
@@ -10,14 +10,15 @@
 from aac_metrics.functional.bert_score_mrefs import (
     DEFAULT_BERT_SCORE_MODEL,
     REDUCTIONS,
+    BERTScoreMRefsOuts,
     Reduction,
     _load_model_and_tokenizer,
     bert_score_mrefs,
 )
 from aac_metrics.utils.globals import _get_device
 
 
-class BERTScoreMRefs(AACMetric):
+class BERTScoreMRefs(AACMetric[Union[BERTScoreMRefsOuts, Tensor]]):
     """BERTScore metric which supports multiple references.
     The implementation is based on the bert_score implementation of torchmetrics.
@@ -37,6 +38,7 @@ class BERTScoreMRefs(AACMetric):
     def __init__(
         self,
         return_all_scores: bool = True,
+        *,
         model: Union[str, nn.Module] = DEFAULT_BERT_SCORE_MODEL,
         device: Union[str, torch.device, None] = "cuda_if_available",
         batch_size: int = 32,
@@ -79,7 +81,9 @@ def __init__(
         self._candidates = []
         self._mult_references = []
 
-    def compute(self) -> Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]:
+    def compute(
+        self,
+    ) -> Union[BERTScoreMRefsOuts, Tensor]:
         return bert_score_mrefs(
             candidates=self._candidates,
             mult_references=self._mult_references,
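Judging by the annotation it replaces, `BERTScoreMRefsOuts` is presumably an alias for the pair of score dicts previously spelled out inline; the definition below is an assumption for illustration, not copied from `aac_metrics.functional.bert_score_mrefs`. It shows how a caller narrows the union that `compute()` now advertises:

from typing import Union

from torch import Tensor

# Assumed alias, matching the old inline annotation:
# (corpus-level scores, sentence-level scores).
BERTScoreMRefsOuts = tuple[dict[str, Tensor], dict[str, Tensor]]


def pick_corpus_score(outs: Union[BERTScoreMRefsOuts, Tensor]) -> Tensor:
    # Narrow the union the same way user code must after compute().
    if isinstance(outs, Tensor):
        return outs  # return_all_scores=False: a single scalar tensor
    corpus_scores, _sents_scores = outs  # both now inferred as dict[str, Tensor]
    return next(iter(corpus_scores.values()))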
6 changes: 4 additions & 2 deletions src/aac_metrics/classes/bleu.py
@@ -9,12 +9,13 @@
 from aac_metrics.functional.bleu import (
     BLEU_OPTIONS,
     BleuOption,
+    BLEUOuts,
     _bleu_compute,
     _bleu_update,
 )
 
 
-class BLEU(AACMetric[Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]]):
+class BLEU(AACMetric[Union[BLEUOuts, Tensor]]):
     """BiLingual Evaluation Understudy metric class.
     - Paper: https://www.aclweb.org/anthology/P02-1040.pdf
@@ -32,6 +33,7 @@ class BLEU(AACMetric[Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]]
     def __init__(
         self,
         return_all_scores: bool = True,
+        *,
         n: int = 4,
         option: BleuOption = "closest",
         verbose: int = 0,
@@ -52,7 +54,7 @@ def __init__(
         self._cooked_cands = []
         self._cooked_mrefs = []
 
-    def compute(self) -> Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]:
+    def compute(self) -> Union[BLEUOuts, Tensor]:
         return _bleu_compute(
             cooked_cands=self._cooked_cands,
             cooked_mrefs=self._cooked_mrefs,
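Besides the `BLEUOuts` alias, note the bare `*` added to `__init__` here and in every other class in this commit: all arguments after `return_all_scores` become keyword-only, an API-tightening side effect rather than a pure typing change. A toy stand-in showing the effect (the function name is hypothetical):

def init_bleu(return_all_scores: bool = True, *, n: int = 4, option: str = "closest") -> dict:
    # Mimics the shape of the updated BLEU.__init__ signature.
    return {"return_all_scores": return_all_scores, "n": n, "option": option}


init_bleu(True, n=4, option="closest")  # OK: options passed by keyword
# init_bleu(True, 4)  # TypeError: n is now keyword-only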
12 changes: 5 additions & 7 deletions src/aac_metrics/classes/cider_d.py
@@ -1,18 +1,15 @@
 #!/usr/bin/env python
 # -*- coding: utf-8 -*-
 
-from typing import Any, Callable, Union
+from typing import Callable, Union
 
 from torch import Tensor
 
 from aac_metrics.classes.base import AACMetric
-from aac_metrics.functional.cider_d import (
-    _cider_d_compute,
-    _cider_d_update,
-)
+from aac_metrics.functional.cider_d import CIDErDOuts, _cider_d_compute, _cider_d_update
 
 
-class CIDErD(AACMetric[Union[tuple[dict[str, Tensor], dict[str, Any]], Tensor]]):
+class CIDErD(AACMetric[Union[CIDErDOuts, Tensor]]):
     """Consensus-based Image Description Evaluation metric class.
     - Paper: https://arxiv.org/pdf/1411.5726.pdf
@@ -30,6 +27,7 @@ class CIDErD(AACMetric[Union[tuple[dict[str, Tensor], dict[str, Any]], Tensor]])
     def __init__(
         self,
         return_all_scores: bool = True,
+        *,
         n: int = 4,
         sigma: float = 6.0,
         tokenizer: Callable[[str], list[str]] = str.split,
@@ -47,7 +45,7 @@ def __init__(
         self._cooked_cands = []
         self._cooked_mrefs = []
 
-    def compute(self) -> Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]:
+    def compute(self) -> Union[CIDErDOuts, Tensor]:
         return _cider_d_compute(
             cooked_cands=self._cooked_cands,
             cooked_mrefs=self._cooked_mrefs,
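One detail specific to this file: the old class parameter said `dict[str, Any]` while the old `compute()` annotation said `dict[str, Tensor]`, so the two disagreed; routing both through a single `CIDErDOuts` alias removes that drift. A guess at the alias, assumed rather than copied from `aac_metrics.functional.cider_d`:

from typing import Any

from torch import Tensor

# Assumption: the sentence-level dict may carry non-tensor extras, which
# would explain the dict[str, Any] on the old class parameter.
CIDErDOuts = tuple[dict[str, Tensor], dict[str, Any]]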
12 changes: 5 additions & 7 deletions src/aac_metrics/classes/fense.py
@@ -2,28 +2,25 @@
 # -*- coding: utf-8 -*-
 
 import logging
-
 from typing import Union
 
 import torch
-
 from sentence_transformers import SentenceTransformer
 from torch import Tensor
 
 from aac_metrics.classes.base import AACMetric
-from aac_metrics.functional.fense import fense, _load_models_and_tokenizer
+from aac_metrics.functional.fense import FENSEOuts, _load_models_and_tokenizer, fense
 from aac_metrics.functional.fer import (
-    BERTFlatClassifier,
     _ERROR_NAMES,
     DEFAULT_FER_MODEL,
+    BERTFlatClassifier,
 )
 from aac_metrics.functional.sbert_sim import DEFAULT_SBERT_SIM_MODEL
 
 
 pylog = logging.getLogger(__name__)
 
 
-class FENSE(AACMetric[Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]]):
+class FENSE(AACMetric[Union[FENSEOuts, Tensor]]):
     """Fluency ENhanced Sentence-bert Evaluation (FENSE)
     - Paper: https://arxiv.org/abs/2110.04684
@@ -42,6 +39,7 @@ class FENSE(AACMetric[Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]
     def __init__(
         self,
         return_all_scores: bool = True,
+        *,
         sbert_model: Union[str, SentenceTransformer] = DEFAULT_SBERT_SIM_MODEL,
         echecker: Union[str, BERTFlatClassifier] = DEFAULT_FER_MODEL,
         error_threshold: float = 0.9,
@@ -77,7 +75,7 @@ def __init__(
         self._candidates = []
         self._mult_references = []
 
-    def compute(self) -> Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]:
+    def compute(self) -> Union[FENSEOuts, Tensor]:
         return fense(
             candidates=self._candidates,
             mult_references=self._mult_references,
15 changes: 7 additions & 8 deletions src/aac_metrics/classes/fer.py
@@ -2,28 +2,26 @@
 # -*- coding: utf-8 -*-
 
 import logging
-
 from typing import Union
 
 import torch
-
 from torch import Tensor
 
 from aac_metrics.classes.base import AACMetric
 from aac_metrics.functional.fer import (
-    BERTFlatClassifier,
-    fer,
-    _load_echecker_and_tokenizer,
     _ERROR_NAMES,
     DEFAULT_FER_MODEL,
+    BERTFlatClassifier,
+    FEROuts,
+    _load_echecker_and_tokenizer,
+    fer,
 )
 from aac_metrics.utils.globals import _get_device
 
 
 pylog = logging.getLogger(__name__)
 
 
-class FER(AACMetric[Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]]):
+class FER(AACMetric[Union[FEROuts, Tensor]]):
     """Return Fluency Error Rate (FER) detected by a pre-trained BERT model.
     - Paper: https://arxiv.org/abs/2110.04684
def __init__(
self,
return_all_scores: bool = True,
*,
echecker: Union[str, BERTFlatClassifier] = DEFAULT_FER_MODEL,
error_threshold: float = 0.9,
device: Union[str, torch.device, None] = "cuda_if_available",
Expand Down Expand Up @@ -72,7 +71,7 @@ def __init__(

self._candidates = []

def compute(self) -> Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]:
def compute(self) -> Union[FEROuts, Tensor]:
return fer(
candidates=self._candidates,
return_all_scores=self._return_all_scores,
7 changes: 4 additions & 3 deletions src/aac_metrics/classes/meteor.py
@@ -7,10 +7,10 @@
 from torch import Tensor
 
 from aac_metrics.classes.base import AACMetric
-from aac_metrics.functional.meteor import Language, meteor
+from aac_metrics.functional.meteor import Language, METEOROuts, meteor
 
 
-class METEOR(AACMetric[Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]]):
+class METEOR(AACMetric[Union[METEOROuts, Tensor]]):
     """Metric for Evaluation of Translation with Explicit ORdering metric class.
     - Paper: https://dl.acm.org/doi/pdf/10.5555/1626355.1626389
@@ -29,6 +29,7 @@ class METEOR(AACMetric[Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor
     def __init__(
         self,
         return_all_scores: bool = True,
+        *,
         cache_path: Union[str, Path, None] = None,
         java_path: Union[str, Path, None] = None,
         java_max_memory: str = "2G",
@@ -52,7 +53,7 @@ def __init__(
         self._candidates = []
         self._mult_references = []
 
-    def compute(self) -> Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]:
+    def compute(self) -> Union[METEOROuts, Tensor]:
         return meteor(
             candidates=self._candidates,
             mult_references=self._mult_references,
10 changes: 4 additions & 6 deletions src/aac_metrics/classes/rouge_l.py
@@ -6,13 +6,10 @@
 from torch import Tensor
 
 from aac_metrics.classes.base import AACMetric
-from aac_metrics.functional.rouge_l import (
-    _rouge_l_compute,
-    _rouge_l_update,
-)
+from aac_metrics.functional.rouge_l import ROUGELOuts, _rouge_l_compute, _rouge_l_update
 
 
-class ROUGEL(AACMetric[Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]]):
+class ROUGEL(AACMetric[Union[ROUGELOuts, Tensor]]):
     """Recall-Oriented Understudy for Gisting Evaluation class.
     - Paper: https://aclanthology.org/W04-1013.pdf
@@ -30,6 +27,7 @@ class ROUGEL(AACMetric[Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor
     def __init__(
         self,
         return_all_scores: bool = True,
+        *,
         beta: float = 1.2,
         tokenizer: Callable[[str], list[str]] = str.split,
     ) -> None:
@@ -40,7 +38,7 @@ def __init__(

         self._rouge_l_scores = []
 
-    def compute(self) -> Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]:
+    def compute(self) -> Union[ROUGELOuts, Tensor]:
         return _rouge_l_compute(
             rouge_l_scs=self._rouge_l_scores,
             return_all_scores=self._return_all_scores,
13 changes: 6 additions & 7 deletions src/aac_metrics/classes/sbert_sim.py
@@ -2,27 +2,25 @@
 # -*- coding: utf-8 -*-
 
 import logging
-
 from typing import Union
 
 import torch
-
 from sentence_transformers import SentenceTransformer
 from torch import Tensor
 
 from aac_metrics.classes.base import AACMetric
 from aac_metrics.functional.sbert_sim import (
-    sbert_sim,
-    _load_sbert,
     DEFAULT_SBERT_SIM_MODEL,
+    SBERTSimOuts,
+    _load_sbert,
+    sbert_sim,
 )
 from aac_metrics.utils.globals import _get_device
 
 
 pylog = logging.getLogger(__name__)
 
 
-class SBERTSim(AACMetric[Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]]):
+class SBERTSim(AACMetric[Union[SBERTSimOuts, Tensor]]):
     """Cosine-similarity of the Sentence-BERT embeddings.
     - Paper: https://arxiv.org/abs/1908.10084
@@ -41,6 +39,7 @@ class SBERTSim(AACMetric[Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tens
     def __init__(
         self,
         return_all_scores: bool = True,
+        *,
         sbert_model: Union[str, SentenceTransformer] = DEFAULT_SBERT_SIM_MODEL,
         device: Union[str, torch.device, None] = "cuda_if_available",
         batch_size: int = 32,
@@ -65,7 +64,7 @@ def __init__(
         self._candidates = []
         self._mult_references = []
 
-    def compute(self) -> Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]:
+    def compute(self) -> Union[SBERTSimOuts, Tensor]:
         return sbert_sim(
             candidates=self._candidates,
             mult_references=self._mult_references,
9 changes: 4 additions & 5 deletions src/aac_metrics/classes/spice.py
@@ -2,20 +2,18 @@
 # -*- coding: utf-8 -*-
 
 import logging
-
 from pathlib import Path
 from typing import Iterable, Optional, Union
 
 from torch import Tensor
 
 from aac_metrics.classes.base import AACMetric
-from aac_metrics.functional.spice import spice
-
+from aac_metrics.functional.spice import SPICEOuts, spice
 
 pylog = logging.getLogger(__name__)
 
 
-class SPICE(AACMetric[Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]]):
+class SPICE(AACMetric[Union[SPICEOuts, Tensor]]):
     """Semantic Propositional Image Caption Evaluation class.
     - Paper: https://arxiv.org/pdf/1607.08822.pdf
@@ -33,6 +31,7 @@ class SPICE(AACMetric[Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]
     def __init__(
         self,
         return_all_scores: bool = True,
+        *,
         cache_path: Union[str, Path, None] = None,
         java_path: Union[str, Path, None] = None,
         tmp_path: Union[str, Path, None] = None,
@@ -58,7 +57,7 @@ def __init__(
         self._candidates = []
         self._mult_references = []
 
-    def compute(self) -> Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]:
+    def compute(self) -> Union[SPICEOuts, Tensor]:
         return spice(
             candidates=self._candidates,
             mult_references=self._mult_references,
[Diff truncated: 18 of the 28 changed files are not shown above.]
