Skip to content

Commit

Permalink
Fix: BERTScoreMrefs backward compatibility with older torchmetrics ve…
Browse files Browse the repository at this point in the history
…rsions.
  • Loading branch information
Labbeti committed Jan 11, 2024
1 parent e671ed9 commit 04858ae
Show file tree
Hide file tree
Showing 7 changed files with 131 additions and 46 deletions.
7 changes: 7 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,13 @@

All notable changes to this project will be documented in this file.

## [0.5.4] UNRELEASED
### Added
- `Version` class to handle versionning.

### Fixed
- Backward compatibility of `BERTScoreMrefs` with torchmetrics prior to 1.0.0.

## [0.5.3] 2024-01-09
### Fixed
- Fix `BERTScoreMrefs` computation when all multiple references sizes are equal.
Expand Down
2 changes: 2 additions & 0 deletions src/aac_metrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@

__author__ = "Etienne Labbé (Labbeti)"
__author_email__ = "[email protected]"
__docs__ = "Audio Captioning Metrics"
__docs_url__ = "https://aac-metrics.readthedocs.io/en/stable/"
__license__ = "MIT"
__maintainer__ = "Etienne Labbé (Labbeti)"
__name__ = "aac-metrics"
Expand Down
17 changes: 13 additions & 4 deletions src/aac_metrics/functional/bert_score_mrefs.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import Callable, Optional, Union

import torch
import torchmetrics

from torch import nn, Tensor
from torchmetrics.functional.text.bert import bert_score, _DEFAULT_MODEL
Expand All @@ -14,6 +15,7 @@
from aac_metrics.utils.checks import check_metric_inputs
from aac_metrics.utils.collections import flat_list, unflat_list, duplicate_list
from aac_metrics.utils.globals import _get_device
from aac_metrics.utils.packaging import Version


DEFAULT_BERT_SCORE_MODEL = _DEFAULT_MODEL
Expand All @@ -32,7 +34,7 @@ def bert_score_mrefs(
max_length: int = 64,
reset_state: bool = True,
idf: bool = False,
reduction: Union[str, Callable[[Tensor, ...], Tensor]] = "max",
reduction: Union[str, Callable[..., Tensor]] = "max",
filter_nan: bool = True,
verbose: int = 0,
) -> Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]:
Expand Down Expand Up @@ -144,9 +146,16 @@ def bert_score_mrefs(
reduction_fn = reduction

if len(sizes) > 0 and all(size == sizes[0] for size in sizes):
sents_scores = {
k: reduction_fn(torch.stack(v), dim=1) for k, v in sents_scores.items()
}
torchmetrics_version = Version.from_str(torchmetrics.__version__)
if torchmetrics_version < Version(1, 0, 0):
# backward compatibility
sents_scores = {
k: reduction_fn(torch.as_tensor(v, dtype=dtype), dim=1) for k, v in sents_scores.items() # type: ignore
}
else:
sents_scores = {
k: reduction_fn(torch.stack(v), dim=1) for k, v in sents_scores.items() # type: ignore
}
else:
sents_scores = {
k: torch.stack([reduction_fn(torch.as_tensor(vi, dtype=dtype)) for vi in v])
Expand Down
23 changes: 9 additions & 14 deletions src/aac_metrics/utils/checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,18 @@
# -*- coding: utf-8 -*-

import logging
import re
import subprocess

from pathlib import Path
from subprocess import CalledProcessError
from typing import Any, Union
from typing_extensions import TypeGuard

from aac_metrics.utils.packaging import Version


pylog = logging.getLogger(__name__)

VERSION_PATTERN = r"(?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+).*"
MIN_JAVA_MAJOR_VERSION = 8
MAX_JAVA_MAJOR_VERSION = 13

Expand Down Expand Up @@ -103,18 +103,13 @@ def _get_java_version(java_path: str) -> str:
return version


def _check_java_version(version: str, min_major: int, max_major: int) -> bool:
result = re.match(VERSION_PATTERN, version)
if result is None:
raise ValueError(
f"Invalid Java version {version=}. (expected version with pattern={VERSION_PATTERN})"
)

major_version = int(result["major"])
minor_version = int(result["minor"])
def _check_java_version(version_str: str, min_major: int, max_major: int) -> bool:
version = Version.from_str(version_str)

if major_version == 1 and minor_version <= 8:
if version.major == "1" and version.minor <= "8":
# java <= 8 use versioning "1.MAJOR.MINOR" and > 8 use "MAJOR.MINOR.PATCH"
major_version = minor_version
version.major = version.minor
version.minor = version.patch
version.patch = "0" # unknown patch, but it does not matter here

return min_major <= major_version <= max_major
return Version(min_major, 0, 0) <= version < Version(max_major, 0, 0)
21 changes: 0 additions & 21 deletions src/aac_metrics/utils/imports.py

This file was deleted.

98 changes: 98 additions & 0 deletions src/aac_metrics/utils/packaging.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-

import re

from dataclasses import dataclass, asdict, astuple
from functools import cache
from importlib.util import find_spec
from typing import Any, ClassVar, Mapping


@cache
def _package_is_available(package_name: str) -> bool:
"""Returns True if package is installed in the current python environment."""
try:
return find_spec(package_name) is not None
except AttributeError:
# Python 3.6
return False
except (ImportError, ModuleNotFoundError):
# Python 3.7+
return False


@dataclass(init=True, repr=True, eq=True, slots=True)
class Version:
_VERSION_FORMAT: ClassVar[str] = "{major}.{minor}.{patch}"
_VERSION_PATTERN: ClassVar[
str
] = r"(?P<major>[^\.]+)\.(?P<minor>[^\.]+)\.(?P<patch>[^\.]+).*"

major: str
minor: str
patch: str

def __init__(self, major: Any, minor: Any, patch: Any) -> None:
major = str(major)
minor = str(minor)
patch = str(patch)

self.major = major
self.minor = minor
self.patch = patch

@classmethod
def from_dict(cls, version: Mapping[str, Any]) -> "Version":
major = version["major"]
minor = version["minor"]
patch = version["patch"]
return Version(major, minor, patch)

@classmethod
def from_str(cls, version: str) -> "Version":
matched = re.match(Version._VERSION_PATTERN, version)
if matched is None:
raise ValueError(
f"Invalid argument {version=}. (does not match pattern {Version._VERSION_PATTERN})"
)
matched_dict = matched.groupdict()
return cls.from_dict(matched_dict)

@classmethod
def from_tuple(cls, version: tuple[Any, Any, Any]) -> "Version":
major = version[0]
minor = version[1]
patch = version[2]
return Version(major, minor, patch)

def to_dict(self) -> dict[str, str]:
return asdict(self)

def to_str(self) -> str:
return Version._VERSION_FORMAT.format(**self.to_dict())

def to_tuple(self) -> tuple[str, str, str]:
return astuple(self) # type: ignore

def __lt__(self, other: "Version") -> bool:
if self.major < other.major:
return True
if self.major == other.major and self.minor < other.minor:
return True
if (
self.major == other.major
and self.minor == other.minor
and self.patch < other.patch
):
return True
return False

def __le__(self, other: "Version") -> bool:
return self == other or self < other

def __ge__(self, other: "Version") -> bool:
return not (self < other)

def __gt__(self, other: "Version") -> bool:
return not (self <= other)
9 changes: 2 additions & 7 deletions tests/test_bleu_tchmet.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,14 @@

from unittest import TestCase

from aac_metrics.classes.bleu import BLEU
from aac_metrics.utils.imports import _TORCHMETRICS_AVAILABLE
from torchmetrics.text.bleu import BLEUScore

if _TORCHMETRICS_AVAILABLE:
from torchmetrics.text.bleu import BLEUScore
from aac_metrics.classes.bleu import BLEU


class TestBleu(TestCase):
# Tests methods
def test_bleu(self) -> None:
if not _TORCHMETRICS_AVAILABLE:
return None

cands = ["a man is speaking", "birds chirping"]
mrefs = [
[
Expand Down

0 comments on commit 04858ae

Please sign in to comment.