-
Notifications
You must be signed in to change notification settings - Fork 238
/
Copy pathcomposite_loss.py
229 lines (191 loc) · 9.07 KB
/
composite_loss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
# For licensing see accompanying LICENSE file.
# Copyright (C) 2023 Apple Inc. All Rights Reserved.
from __future__ import annotations
import argparse
import copy
import json
from typing import Any, List, Mapping, Tuple
from torch import Tensor
from loss_fn import LOSS_REGISTRY, BaseCriteria, build_loss_fn
from options.utils import flatten_yaml_as_dict
from utils import logger
# CompositeLoss can be used with any task. Therefore, we register both name and type
# as the same.
@LOSS_REGISTRY.register(name="composite_loss", type="composite_loss")
class CompositeLoss(BaseCriteria):
    """Combines different loss functions and returns the weighted sum of these losses.

    `loss_category` and `loss_weight` are two mandatory keys that allow us to combine
    different losses and compute their weighted sum. The `loss_category` specifies the
    category of a loss function and is a string (e.g., classification). The
    `loss_weight` specifies the contribution of a loss function and is a float value
    (e.g., 1.0). The sum of `loss_weight`s corresponding to different loss functions
    is not required to be 1.

    Args:
        opts: command-line arguments

    Example::
        # Example yaml config for combining classification and neural_augmentation loss function is given below.
        # Please note that configuration for each loss function should start with `-` in `composite_loss`.

        loss:
          category: "composite_loss"
          composite_loss:
            - loss_category: "classification"
              loss_weight: 1.0
              classification:
                name: "cross_entropy"
                cross_entropy:
                  label_smoothing: 0.1
            - loss_category: "neural_augmentation"
              loss_weight: 1.0
              neural_augmentation:
                perceptual_metric: "psnr"
                target_value: [ 40, 10 ]
                curriculum_method: "cosine"
    """

    def __init__(
        self,
        opts: argparse.Namespace,
        *args,
        **kwargs,
    ) -> None:
        # Build the sub-criteria before calling the parent constructor so that a
        # mis-configured composite loss fails as early as possible.
        (
            task_loss_fn_mapping,
            task_loss_wts_mapping,
        ) = CompositeLoss.build_composite_loss_fn(opts, *args, **kwargs)
        super().__init__(opts, *args, **kwargs)
        # Maps loss category name (e.g., "classification") -> criteria module.
        self.loss_fns = task_loss_fn_mapping
        # Maps loss category name -> scalar weight used in the weighted sum.
        self.loss_weights = task_loss_wts_mapping

    @classmethod
    def build_composite_loss_fn(
        cls, opts: argparse.Namespace, *args, **kwargs
    ) -> Tuple[Mapping[str, BaseCriteria], Mapping[str, float]]:
        """Build loss functions from command line arguments and loss registry.

        Args:
            opts: command-line arguments

        Returns:
            A tuple of two dictionaries. The first dictionary, `task_loss_fn_mapping`,
            maps a loss function category to its criteria module. The second dictionary,
            `task_loss_wts_mapping`, maps a loss function category to its weight.
        """
        composite_loss_opts = getattr(opts, "loss.composite_loss")
        if composite_loss_opts is None:
            logger.error(
                f"{cls.__name__} can't be None. Please specify --loss.composite-loss using yaml file"
            )
        if not isinstance(composite_loss_opts, List):
            logger.error(
                f"{cls.__name__} options are expected as a List. "
                f"Got type: {type(composite_loss_opts)} and values: {composite_loss_opts}"
            )

        num_loss_fns = len(composite_loss_opts)
        if num_loss_fns < 1:
            logger.error(f"We need at least one loss function if using {cls.__name__}")

        task_loss_fn_mapping = {}
        task_loss_wts_mapping = {}
        for composite_loss_opts_as_dict in composite_loss_opts:
            if "loss_category" not in composite_loss_opts_as_dict:
                logger.error("loss_category is a mandatory key")
            if "loss_weight" not in composite_loss_opts_as_dict:
                logger.error("loss_weight is a mandatory key")

            loss_category = composite_loss_opts_as_dict.pop("loss_category")
            loss_weight = composite_loss_opts_as_dict.pop("loss_weight")
            if not isinstance(loss_weight, (float, int)):
                logger.error(
                    f"loss weight should be either int or float. "
                    f"Got: value={loss_weight}, type={type(loss_weight)}"
                )
            if loss_category in task_loss_fn_mapping:
                # A repeated category would silently overwrite the earlier criteria
                # and its weight, so flag it explicitly.
                logger.error(
                    f"loss_category {loss_category} is repeated in {cls.__name__}. "
                    f"Each loss category can be specified only once."
                )

            # flatten the nested yaml dictionary into dot-separated keys
            composite_loss_opts_as_dict = flatten_yaml_as_dict(
                composite_loss_opts_as_dict
            )

            # `composite_loss_opts_as_dict` only contains the values of command-line arguments that are
            # defined in the yaml file. Therefore, if a user misses few arguments, we won't have access
            # to default values, leading to an error. To avoid this, we create a local copy of global
            # command-line arguments and update it with `composite_loss_opts_as_dict` arguments
            loss_opts = copy.deepcopy(opts)

            # override the global opts with loss_fn specific opts in local copy
            for k, v in composite_loss_opts_as_dict.items():
                # we need to prefix each argument with loss because we define individual losses as
                # `loss.classification.*` and not `classification.*`
                setattr(loss_opts, "loss." + k, v)

            # given the category of a loss function, build the criteria
            task_loss_fn_mapping[loss_category] = build_loss_fn(
                opts=loss_opts, category=loss_category, *args, **kwargs
            )
            task_loss_wts_mapping[loss_category] = loss_weight

        # Sanity check: both mappings must be keyed by exactly the same categories.
        # Note: dict-view equality checks identical key sets; the previous
        # `isdisjoint`-based check only verified that at least one key was shared.
        assert task_loss_fn_mapping.keys() == task_loss_wts_mapping.keys(), (
            f"The keys in task_loss_fn_mapping and task_loss_wts_mapping are not the same. "
            f"Got: {task_loss_fn_mapping.keys()} and {task_loss_wts_mapping.keys()}"
        )
        return task_loss_fn_mapping, task_loss_wts_mapping

    @classmethod
    def add_arguments(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
        """Add criterion-specific arguments to the parser."""
        if cls != CompositeLoss:
            # Don't re-register arguments in subclasses that don't override `add_arguments()`.
            return parser
        group = parser.add_argument_group(cls.__name__)
        # `action="append"` collects one JSON dict per `-` entry under `composite_loss`.
        group.add_argument("--loss.composite-loss", type=json.loads, action="append")
        return parser

    def forward(
        self, input_sample: Any, prediction: Any, target: Any, *args, **kwargs
    ) -> Any:
        """Compute the weighted sum of different loss functions.

        Args:
            input_sample: Input to the model.
            prediction: Model's output
            target: Ground truth labels

        Returns:
            A mapping of the form (str: scalar loss value) with `total_loss` as a mandatory key.
            The other keys corresponds to loss category names and their values contain category-specific
            scalar loss values. total_loss is weighted sum of these category-specific losses.
        """
        outputs = {}
        total_loss = 0.0
        for loss_name, loss_layer in self.loss_fns.items():
            loss_wt = self.loss_weights[loss_name]
            loss_val = loss_layer(
                input_sample=input_sample,
                prediction=prediction,
                target=target,
                *args,
                **kwargs,
            )
            if not isinstance(loss_val, (Tensor, Mapping)):
                logger.error(
                    "Loss value is expected as a scalar or dictionary of scalars with total_loss as a "
                    "mandatory key."
                )
            # Sub-losses may themselves return a mapping; reduce to their total_loss.
            if isinstance(loss_val, Mapping) and "total_loss" in loss_val:
                loss_val = loss_val["total_loss"]
            if not isinstance(loss_val, Tensor):
                logger.error(
                    f"Value corresponding to total_loss key in {loss_val} is expected to be scalar."
                    f"Got: {type(loss_val)}"
                )
            # scale the loss by its configured weight before accumulating
            loss_val = loss_val * loss_wt
            total_loss += loss_val
            outputs[loss_name] = loss_val
        outputs.update({"total_loss": total_loss})
        return outputs

    def train(self, mode: bool = True) -> None:
        """Sets the loss functions in training mode."""
        for loss_name, loss_layer in self.loss_fns.items():
            loss_layer.train(mode=mode)

    def eval(self) -> None:
        """Sets the loss functions in evaluation mode."""
        for loss_name, loss_layer in self.loss_fns.items():
            loss_layer.eval()

    def __repr__(self) -> str:
        """Return a multi-line summary of the sub-losses annotated with their weights."""
        repr_str = f"{self.__class__.__name__}(\n\t"
        for k, v in self.loss_fns.items():
            # Collapse each sub-loss repr onto one line and splice its weight in
            # before the closing parenthesis.
            repr_str += (
                v.__repr__()
                .replace("\n\t", " ")
                .replace("\n)", f" loss_wt={self.loss_weights[k]})")
            )
            repr_str += "\n\t"
        repr_str += "\n)"
        return repr_str