diff --git a/CHANGELOG.md b/CHANGELOG.md index 50f7d49..a18612b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,7 @@ All notable changes to this project will be documented in this file. ## [0.4.3] UNRELEASED ### Fixed - Output name `sbert_sim` in FENSE and SBERTSim classes. +- `Evaluate` class instantiation with `torchmetrics` >= 0.11. ## [0.4.2] 2023-04-19 ### Fixed diff --git a/pyproject.toml b/pyproject.toml index e4fef69..18a97cd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -49,6 +49,7 @@ dev = [ "black==22.8.0", "scikit-image==0.19.2", "matplotlib==3.5.2", + "torchmetrics>=0.10", ] [tool.setuptools.packages.find] diff --git a/src/aac_metrics/classes/base.py b/src/aac_metrics/classes/base.py index 110d1f0..aa9253c 100644 --- a/src/aac_metrics/classes/base.py +++ b/src/aac_metrics/classes/base.py @@ -3,51 +3,37 @@ from typing import Any, Optional -from aac_metrics.utils.imports import _TORCHMETRICS_AVAILABLE +from torch import nn -if _TORCHMETRICS_AVAILABLE: - from torchmetrics import Metric as __BaseMetric # type: ignore +class AACMetric(nn.Module): + """Base Metric module used when torchmetrics is not installed.""" - class AACMetric(__BaseMetric): # type: ignore - """Base Metric module used when torchmetrics is installed.""" + # Global values + full_state_update: Optional[bool] = False + higher_is_better: Optional[bool] = None + is_differentiable: Optional[bool] = False - # The theorical minimal value of the main global score of the metric. - min_value: Optional[float] = None - # The theorical maximal value of the main global score of the metric. - max_value: Optional[float] = None + # The theorical minimal value of the main global score of the metric. + min_value: Optional[float] = None + # The theorical maximal value of the main global score of the metric. + max_value: Optional[float] = None -else: - from torch import nn + def __init__(self, **kwargs: Any) -> None: + super().__init__(**kwargs) - class AACMetric(nn.Module): - """Base Metric module used when torchmetrics is not installed.""" + # Public methods + def compute(self) -> Any: + return None - # Global values - full_state_update: Optional[bool] = False - higher_is_better: Optional[bool] = None - is_differentiable: Optional[bool] = False + def forward(self, *args: Any, **kwargs: Any) -> Any: + self.update(*args, **kwargs) + output = self.compute() + self.reset() + return output - # The theorical minimal value of the main global score of the metric. - min_value: Optional[float] = None - # The theorical maximal value of the main global score of the metric. - max_value: Optional[float] = None + def reset(self) -> None: + pass - def __init__(self, **kwargs: Any) -> None: - super().__init__(**kwargs) - - # Public methods - def compute(self) -> Any: - return None - - def forward(self, *args: Any, **kwargs: Any) -> Any: - self.update(*args, **kwargs) - output = self.compute() - self.reset() - return output - - def reset(self) -> None: - pass - - def update(self, *args, **kwargs) -> None: - pass + def update(self, *args, **kwargs) -> None: + pass