Skip to content

Commit

Permalink
Del: Remove Version dataclass and use packaging instead to handle jav…
Browse files Browse the repository at this point in the history
…a and torchmetrics versions.
  • Loading branch information
Labbeti committed Feb 2, 2024
1 parent 0e1d3b3 commit dd30ec8
Show file tree
Hide file tree
Showing 5 changed files with 18 additions and 119 deletions.
3 changes: 0 additions & 3 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,6 @@
All notable changes to this project will be documented in this file.

## [0.5.4] UNRELEASED
### Added
- `Version` class to handle versionning.

### Fixed
- Backward compatibility of `BERTScoreMrefs` with torchmetrics prior to 1.0.0.

Expand Down
7 changes: 4 additions & 3 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
# -*- coding: utf-8 -*-

torch>=1.10.1
numpy>=1.21.2
packaging>=23
pyyaml>=6.0
tqdm>=4.64.0
sentence-transformers>=2.2.2
transformers
torch>=1.10.1
torchmetrics>=0.11.4
tqdm>=4.64.0
transformers
16 changes: 7 additions & 9 deletions src/aac_metrics/functional/bert_score_mrefs.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,16 @@

import torch
import torchmetrics

from torch import nn, Tensor
from torchmetrics.functional.text.bert import bert_score, _DEFAULT_MODEL
from packaging.version import Version
from torch import Tensor, nn
from torchmetrics.functional.text.bert import _DEFAULT_MODEL, bert_score
from transformers import logging as tfmers_logging
from transformers.models.auto.modeling_auto import AutoModel
from transformers.models.auto.tokenization_auto import AutoTokenizer
from transformers import logging as tfmers_logging

from aac_metrics.utils.checks import check_metric_inputs
from aac_metrics.utils.collections import flat_list, unflat_list, duplicate_list
from aac_metrics.utils.collections import duplicate_list, flat_list, unflat_list
from aac_metrics.utils.globals import _get_device
from aac_metrics.utils.packaging import Version


DEFAULT_BERT_SCORE_MODEL = _DEFAULT_MODEL
REDUCTIONS = ("mean", "max", "min")
Expand Down Expand Up @@ -146,8 +144,8 @@ def bert_score_mrefs(
reduction_fn = reduction

if len(sizes) > 0 and all(size == sizes[0] for size in sizes):
torchmetrics_version = Version.from_str(torchmetrics.__version__)
if torchmetrics_version < Version(1, 0, 0):
torchmetrics_version = Version(torchmetrics.__version__)
if torchmetrics_version < Version("1.0.0"):
# backward compatibility
sents_scores = {
k: reduction_fn(torch.as_tensor(v, dtype=dtype), dim=1) for k, v in sents_scores.items() # type: ignore
Expand Down
16 changes: 7 additions & 9 deletions src/aac_metrics/utils/checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,12 @@

import logging
import subprocess

from pathlib import Path
from subprocess import CalledProcessError
from typing import Any, Union
from typing_extensions import TypeGuard

from aac_metrics.utils.packaging import Version

from packaging.version import Version
from typing_extensions import TypeGuard

pylog = logging.getLogger(__name__)

Expand Down Expand Up @@ -104,12 +102,12 @@ def _get_java_version(java_path: str) -> str:


def _check_java_version(version_str: str, min_major: int, max_major: int) -> bool:
version = Version.from_str(version_str)
version = Version(version_str)

if version.major == 1 and version.minor <= 8:
# java <= 8 use versioning "1.MAJOR.MINOR" and > 8 use "MAJOR.MINOR.PATCH"
# java <= 8 use versioning "1.MAJOR.MINOR" and > 8 use "MAJOR.MINOR.MICRO"
version.major = version.minor
version.minor = version.patch
version.patch = 0 # unknown patch, but it does not matter here
version.minor = version.micro
version.micro = 0 # unknown micro, but it does not matter here

return Version(min_major) <= version < Version(max_major + 1)
return Version(f"{min_major}") <= version < Version(f"{max_major + 1}")
95 changes: 0 additions & 95 deletions src/aac_metrics/utils/packaging.py

This file was deleted.

0 comments on commit dd30ec8

Please sign in to comment.