-
Notifications
You must be signed in to change notification settings - Fork 2
/
metrics.py
257 lines (210 loc) · 10.1 KB
/
metrics.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
import numpy as np
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.utils.extmath import cartesian
from hausdorff import hausdorff_distance
# Public API of this module. Every entry must be the name of an object defined
# in this file, otherwise `from <module> import *` raises AttributeError.
__all__ = ['IOU', 'BinaryDice', 'Dice', 'WeightedHausdorffDistance', 'HD']
class IOU(nn.Module):
    '''
    Calculate Intersection over Union (IoU) for semantic segmentation.

    Constructor args:
        num_classes (int): Number of classes; class ids 0 .. num_classes - 1 are scored.
        ignore_index (sequence of int): Class ids left out of the score. Defaults to (0,).

    Forward args:
        logits (torch.Tensor): Predicted tensor of shape (batch_size, num_classes, height, width, (depth))
        target (torch.Tensor): Ground truth, either with a class axis at dim 1 (same number
            of dims as logits, e.g. one-hot) or as class indices of shape
            (batch_size, height, width, (depth))

    Returns:
        tensor: Mean Intersection over Union (IoU) over the scored classes.
        list: List of IOU score for each scored class

    Raises:
        ZeroDivisionError: if every class in range(num_classes) is ignored.
    '''
    def __init__(self, num_classes, ignore_index=(0,)):
        super(IOU, self).__init__()
        self.num_classes = num_classes
        # Copy so instances never share (or mutate) the default / caller's sequence.
        self.ignore_index = list(ignore_index)

    def forward(self, logits, target):
        pred = logits.argmax(dim=1)
        # A target with a class axis is reduced to class ids; an index map
        # (one dim fewer than logits) is used as it is.
        if target.dim() == logits.dim():
            target = target.argmax(dim=1)

        ious = []
        for cls in range(self.num_classes):
            if cls in self.ignore_index:
                continue
            pred_mask = (pred == cls)
            target_mask = (target == cls)
            union = (pred_mask | target_mask).sum().float()
            if union.item() == 0:
                # Class absent from both prediction and target counts as a perfect match.
                iou = 1.0
            else:
                intersection = (pred_mask & target_mask).sum().float()
                iou = (intersection / union).item()
            ious.append(iou)

        # Average over the classes actually scored: ignore ids outside
        # range(num_classes) (e.g. 255) must not shrink the denominator.
        mean_iou = sum(ious) / len(ious)
        return torch.tensor(mean_iou), ious
class BinaryDice(nn.Module):
    '''
    Calculate Binary Dice score and Dice loss for binary segmentation or each class in Multiclass segmentation

    Constructor args:
        smooth (float): Constant added to numerator and denominator so the ratio stays
            defined when both inputs sum to zero.
        p (number): Exponent applied element-wise to logits and target in the denominator.

    Forward args:
        logits (torch.Tensor): Predicted tensor of shape (batch_size, height, width, (depth))
        target (torch.Tensor): Ground truth tensor of shape (batch_size, height, width, (depth))

    Returns:
        tensor: Dice score
        tensor: Dice loss (1 - Dice score)
    '''
    def __init__(self, smooth=1e-5, p=2):
        super(BinaryDice, self).__init__()
        self.smooth = smooth
        self.p = p

    def _power_sum(self, tensor):
        # p == 2 keeps the plain product so any input dtype behaves exactly as before.
        if self.p == 2:
            return torch.sum(tensor * tensor)
        return torch.sum(tensor.pow(self.p))

    def forward(self, logits, target):
        assert logits.shape[0] == target.shape[0], "logits & Target batch size don't match"
        intersect = torch.sum(logits * target)
        y_sum = self._power_sum(target)
        z_sum = self._power_sum(logits)
        # Use the configured smoothing term instead of a hard-coded constant.
        dice = (2 * intersect + self.smooth) / (z_sum + y_sum + self.smooth)
        loss = 1 - dice
        return dice, loss
class Dice(nn.Module):
    '''
    Calculate Dice score and Dice loss for multiclass semantic segmentation

    Constructor args:
        num_classes (int): Number of classes
        weight (sequence of number, optional): Per-class factor, indexed by class id,
            applied to the aggregated dice score and dice loss. The per-class lists
            returned by forward stay unweighted.
        softmax (bool): Apply softmax over dim 1 of logits before scoring.
        ignore_index (sequence of int): Class ids left out of the score. Defaults to (0,).

    Forward args:
        logits (torch.Tensor): Predicted tensor of shape (batch_size, num_classes, height, width, (depth))
        target (torch.Tensor): Ground truth tensor with the same shape as logits
            (one channel per class along dim 1)

    Returns:
        tensor: Mean dice score over the evaluated classes
        tensor: Mean dice loss over the evaluated classes
        list: dice score for each evaluated class
        list: dice loss for each evaluated class

    Raises:
        ZeroDivisionError: if every class channel is ignored.
    '''
    def __init__(self, num_classes, weight=None, softmax=True, ignore_index=(0,)):
        super(Dice, self).__init__()
        self.num_classes = num_classes
        self.weight = weight
        self.softmax = softmax
        # Copy so instances never share (or mutate) the default / caller's sequence.
        self.ignore_index = list(ignore_index)
        self.binary_dice = BinaryDice()

    def forward(self, logits, target):
        assert logits.shape == target.shape, 'logits & Target shape do not match'
        if self.softmax:
            logits = F.softmax(logits, dim=1)

        DICE, LOSS = 0.0, 0.0
        CLS_DICE, CLS_LOSS = [], []
        for clx in range(target.shape[1]):
            if clx in self.ignore_index:
                continue
            dice, loss = self.binary_dice(logits[:, clx], target[:, clx])
            CLS_DICE.append(dice.item())
            CLS_LOSS.append(loss.item())
            if self.weight is not None:
                # Out-of-place so the tensors returned by binary_dice are not mutated.
                dice = dice * self.weight[clx]
                loss = loss * self.weight[clx]
            DICE = DICE + dice
            LOSS = LOSS + loss

        # Average over the classes actually evaluated: ignore ids that are not
        # channels of the input (e.g. 255) must not shrink the denominator.
        num_valid_classes = len(CLS_DICE)
        return DICE / num_valid_classes, LOSS / num_valid_classes, CLS_DICE, CLS_LOSS
class WeightedHausdorffDistance(nn.Module):
    '''
    Weighted Hausdorff Distance (WHD) between a probability map and a set of
    ground-truth points, computed per sample and averaged over the batch.
    '''
    def __init__(self, height, width, p=-9, return_2_terms=False, device=torch.device('cuda')):
        '''
        height (int): image height
        width (int): image width
        p (number): negative exponent passed to generalize_mean for the second term
        return_2_terms (bool): Whether to return the 2 terms
        of the WHD instead of their sum.
        device (torch.device): device on which the size vector and the pixel grid are stored
        '''
        super().__init__()
        self.height, self.width = height, width
        self.size = torch.tensor([height, width], dtype=torch.get_default_dtype(), device=device)
        # Length of the image diagonal; used as the penalty distance in forward.
        self.max_dist = math.sqrt(height**2 + width**2)
        self.n_pixels = height * width
        # One row per pixel, built from the cartesian product of row and column indices.
        self.all_img_locations = torch.from_numpy(cartesian([np.arange(height), np.arange(width)]))
        self.all_img_locations = self.all_img_locations.to(device=device, dtype=torch.get_default_dtype())
        self.return_2_terms = return_2_terms
        self.p = p

    def _assert_no_grad(self, variables):
        '''Assert that none of the given tensors requires gradients.'''
        for var in variables:
            assert not var.requires_grad, \
                "nn criterions don't compute the gradient w.r.t. targets - please " \
                "mark these variables as volatile or not requiring gradients"

    def cdist(self, x, y):
        '''
        Compute distance between each pair of the two collections of inputs.
        x: Nxd Tensor
        y: Mxd Tensor
        return: NxM matrix where dist[i,j] is the norm between x[i,:] and y[j,:]
        i.e. dist[i,j] = || x[i,:] - y[j,:] ||
        '''
        difs = x.unsqueeze(1) - y.unsqueeze(0)
        dists = torch.sum(difs**2, -1).sqrt()
        return dists

    def generalize_mean(self, tensor, dim, p=-9, keepdim=False):
        '''
        Generalized (power) mean of `tensor` along `dim` with exponent p < 0.
        A small epsilon is added before exponentiation so zeros do not divide by zero.
        '''
        assert p < 0
        res = torch.mean((tensor + 1e-6)**p, dim, keepdim=keepdim)**(1./p)
        return res

    def forward(self, prob_map, gt, orig_sizes):
        '''
        prob_map: (B x H x W) Tensor of the probability map of the estimation.
        B is batch size, H is height and W is width.
        Values must be between 0 and 1.
        gt: List of Tensors of the Ground Truth points.
        Must be of size B as in prob_map.
        Each element in the list must be a 2D Tensor,
        where each row is the (y, x), i.e, (row, col) of a GT point.
        orig_sizes: Bx2 Tensor containing the size
        of the original images.
        B is batch size.
        The size must be in (height, width) format.
        return: Single-scalar Tensor with the Weighted Hausdorff Distance.
        If self.return_2_terms=True, then return a tuple containing
        the two terms of the Weighted Hausdorff Distance.
        '''
        self._assert_no_grad(gt)
        assert prob_map.dim() == 3, 'The probability map must be (B x H x W)'
        assert prob_map.size()[1:3] == (self.height, self.width), \
            'You must configure the WeightedHausdorffDistance with the height and width of the ' \
            'probability map that you are using, got a probability map of size %s'\
            % str(prob_map.size())
        batch_size = prob_map.shape[0]
        assert batch_size == len(gt)

        terms_1 = []
        terms_2 = []
        for b in range(batch_size):
            # One by one
            prob_map_b = prob_map[b, :, :]
            gt_b = gt[b]
            orig_size_b = orig_sizes[b, :]
            norm_factor = (orig_size_b / self.size).unsqueeze(0)

            # Corner case: no GT points
            # NOTE(review): as written this fires for any 1-D gt tensor whose entries
            # are not all negative; an empty 1-D tensor does NOT match. Confirm the
            # intended "no points" encoding with the data pipeline.
            if gt_b.ndimension() == 1 and (gt_b < 0).all().item() == 0:
                # 0-dim tensors on the map's device, so they stack with the
                # regular per-sample terms below.
                terms_1.append(torch.tensor(0,
                                            dtype=torch.get_default_dtype(),
                                            device=prob_map.device))
                terms_2.append(torch.tensor(self.max_dist,
                                            dtype=torch.get_default_dtype(),
                                            device=prob_map.device))
                continue

            # Pairwise distances between all possible locations and the GTed locations
            n_gt_pts = gt_b.size()[0]
            normalized_x = norm_factor.repeat(self.n_pixels, 1) * self.all_img_locations
            normalized_y = norm_factor.repeat(len(gt_b), 1) * gt_b
            d_matrix = self.cdist(normalized_x, normalized_y)

            # Reshape probability map as a long column vector
            # and prepare it for multiplication
            p = prob_map_b.view(prob_map_b.nelement())
            n_est_pts = p.sum()
            p_replicated = p.view(-1, 1).repeat(1, n_gt_pts)

            # Weighted Hausdorff Distance
            # Term 1: probability-weighted average distance from each pixel to its nearest GT point.
            term_1 = (1 / (n_est_pts + 1e-6)) * torch.sum(p * torch.min(d_matrix, 1)[0])
            # Term 2: for each GT point, generalized mean over pixels of a distance that
            # is pushed towards max_dist where the probability is low.
            weighted_d_matrix = (1 - p_replicated)*self.max_dist + p_replicated*d_matrix
            minn = self.generalize_mean(weighted_d_matrix,
                                        p=self.p,
                                        dim=0, keepdim=False)
            term_2 = torch.mean(minn)

            terms_1.append(term_1)
            terms_2.append(term_2)

        terms_1 = torch.stack(terms_1)
        terms_2 = torch.stack(terms_2)
        if self.return_2_terms:
            res = terms_1.mean(), terms_2.mean()
        else:
            res = terms_1.mean() + terms_2.mean()
        return res
class HD(nn.Module):
    '''
    Batch-averaged Hausdorff distance between predicted and ground-truth label maps.

    Forward args:
        logits (torch.Tensor): Predicted tensor with the class axis at dim 1.
        target (torch.Tensor): Ground truth tensor with the class axis at dim 1.

    Returns:
        Sum of the per-sample `hausdorff_distance(..., distance='euclidean')`
        values divided by the batch size.
    '''
    def __init__(self):
        super().__init__()

    def forward(self, logits, target):
        # Reduce the class axis to the index of the largest entry, then move to NumPy.
        pred_labels = torch.max(logits, dim=1)[1].detach().cpu().numpy()
        true_labels = torch.max(target, dim=1)[1].detach().cpu().numpy()

        batch_size = pred_labels.shape[0]
        total = 0
        for idx, pred_map in enumerate(pred_labels):
            total += hausdorff_distance(pred_map, true_labels[idx], distance='euclidean')
        return total / batch_size