Skip to content

Commit

Permalink
Mod: AACMetric now always inherit from nn.Module instead of torchmetr…
Browse files Browse the repository at this point in the history
…ics.Metric. Update CHANGELOG with fixed multiple heritance pb with Evaluate class.
  • Loading branch information
Labbeti committed May 26, 2023
1 parent 6aa531b commit 0f83021
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 39 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ All notable changes to this project will be documented in this file.
## [0.4.3] UNRELEASED
### Fixed
- Output name `sbert_sim` in FENSE and SBERTSim classes.
- `Evaluate` class instantiation with `torchmetrics` >= 0.11.

## [0.4.2] 2023-04-19
### Fixed
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ dev = [
"black==22.8.0",
"scikit-image==0.19.2",
"matplotlib==3.5.2",
"torchmetrics>=0.10",
]

[tool.setuptools.packages.find]
Expand Down
64 changes: 25 additions & 39 deletions src/aac_metrics/classes/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,51 +3,37 @@

from typing import Any, Optional

from aac_metrics.utils.imports import _TORCHMETRICS_AVAILABLE
from torch import nn


if _TORCHMETRICS_AVAILABLE:
from torchmetrics import Metric as __BaseMetric # type: ignore
class AACMetric(nn.Module):
"""Base Metric module used when torchmetrics is not installed."""

class AACMetric(__BaseMetric): # type: ignore
"""Base Metric module used when torchmetrics is installed."""
# Global values
full_state_update: Optional[bool] = False
higher_is_better: Optional[bool] = None
is_differentiable: Optional[bool] = False

# The theorical minimal value of the main global score of the metric.
min_value: Optional[float] = None
# The theorical maximal value of the main global score of the metric.
max_value: Optional[float] = None
# The theorical minimal value of the main global score of the metric.
min_value: Optional[float] = None
# The theorical maximal value of the main global score of the metric.
max_value: Optional[float] = None

else:
from torch import nn
def __init__(self, **kwargs: Any) -> None:
super().__init__(**kwargs)

class AACMetric(nn.Module):
"""Base Metric module used when torchmetrics is not installed."""
# Public methods
def compute(self) -> Any:
return None

# Global values
full_state_update: Optional[bool] = False
higher_is_better: Optional[bool] = None
is_differentiable: Optional[bool] = False
def forward(self, *args: Any, **kwargs: Any) -> Any:
self.update(*args, **kwargs)
output = self.compute()
self.reset()
return output

# The theorical minimal value of the main global score of the metric.
min_value: Optional[float] = None
# The theorical maximal value of the main global score of the metric.
max_value: Optional[float] = None
def reset(self) -> None:
pass

def __init__(self, **kwargs: Any) -> None:
super().__init__(**kwargs)

# Public methods
def compute(self) -> Any:
return None

def forward(self, *args: Any, **kwargs: Any) -> Any:
self.update(*args, **kwargs)
output = self.compute()
self.reset()
return output

def reset(self) -> None:
pass

def update(self, *args, **kwargs) -> None:
pass
def update(self, *args, **kwargs) -> None:
pass

0 comments on commit 0f83021

Please sign in to comment.