-
Notifications
You must be signed in to change notification settings - Fork 238
/
Copy pathbase_criteria.py
61 lines (48 loc) · 1.73 KB
/
base_criteria.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
#
# For licensing see accompanying LICENSE file.
# Copyright (C) 2023 Apple Inc. All Rights Reserved.
#
import abc
import argparse
from typing import Any
from torch import nn
from utils import logger
class BaseCriteria(nn.Module, abc.ABC):
    """Base class for defining loss functions.

    Sub-classes must implement the abstract :meth:`forward` method, which
    computes the loss. Instantiating a class that does not implement it raises
    ``TypeError`` (enforced by ``abc.ABC``).

    Args:
        opts: command line arguments
    """

    def __init__(self, opts: argparse.Namespace, *args, **kwargs) -> None:
        # Extra positional/keyword arguments are accepted and ignored here so
        # that sub-classes can forward their own arguments through super().
        super().__init__()
        self.opts = opts
        # Small value for numerical stability purposes that sub-classes may want to use.
        self.eps = 1e-7

    @classmethod
    def add_arguments(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
        """Add criterion-specific arguments to the parser.

        Only ``BaseCriteria`` itself registers ``--loss.category``. A sub-class
        that inherits this method without overriding it gets ``parser`` back
        unchanged, so the option is not registered a second time (argparse
        rejects duplicate option strings by default).

        Args:
            parser: The argument parser to extend.

        Returns:
            The same ``parser`` instance, to allow chaining.
        """
        if cls is not BaseCriteria:
            # Don't re-register arguments in subclasses that don't override `add_arguments()`.
            return parser

        group = parser.add_argument_group(cls.__name__)
        group.add_argument(
            "--loss.category",
            type=str,
            default=None,
            help="Loss function category (e.g., classification). Defaults to None.",
        )
        return parser

    @abc.abstractmethod
    def forward(
        self, input_sample: Any, prediction: Any, target: Any, *args, **kwargs
    ) -> Any:
        """Compute the loss.

        Args:
            input_sample: Input to the model.
            prediction: Model's output.
            target: Ground truth labels.

        Returns:
            The loss value; its exact type is defined by the sub-class.
        """
        raise NotImplementedError

    def extra_repr(self) -> str:
        """Return extra details embedded in :meth:`__repr__` (empty by default).

        Sub-classes may override this to describe their configuration.
        """
        return ""

    def __repr__(self) -> str:
        """Return ``"<ClassName>(<extra_repr>\\n)"``.

        This replaces ``nn.Module``'s default repr, which lists child modules.
        """
        return "{}({}\n)".format(self.__class__.__name__, self.extra_repr())