utils.py
import torch
import torch.nn as nn
import torch.nn.functional as F


def _expand_binary_labels(labels, label_weights, label_channels):
    """Expand categorical labels (0 = background, k >= 1 = class k) into
    one-hot binary targets of shape [num_samples, label_channels]."""
    bin_labels = labels.new_full((labels.size(0), label_channels), 0)
    inds = torch.nonzero(labels >= 1).squeeze()
    if inds.numel() > 0:
        bin_labels[inds, labels[inds] - 1] = 1
    # Broadcast each sample's weight across all label channels.
    bin_label_weights = label_weights.view(-1, 1).expand(
        label_weights.size(0), label_channels)
    return bin_labels, bin_label_weights
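
# --- Illustrative sketch (not part of the original file) --------------------
# Assumes 4 foreground classes; label 0 is background and maps to an
# all-zero row. The values below are hypothetical demo inputs.
def _demo_expand_binary_labels():
    labels = torch.tensor([0, 2, 4])
    weights = torch.tensor([1.0, 1.0, 0.0])
    bin_labels, bin_w = _expand_binary_labels(labels, weights, 4)
    # bin_labels -> [[0,0,0,0], [0,1,0,0], [0,0,0,1]]
    # bin_w      -> each sample weight broadcast across the 4 channels
    return bin_labels, bin_w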
class FocalLoss(nn.Module):
    """Focal loss for multi-class classification (Lin et al., 2017):
    FL(pt) = -alpha_t * (1 - pt)**gamma * log(pt)."""

    def __init__(self, gamma=0, alpha=None, reduction=True):
        super(FocalLoss, self).__init__()
        self.gamma = gamma
        self.alpha = alpha
        if isinstance(alpha, (float, int)):
            # Scalar alpha covers the binary case: class 0 gets alpha,
            # class 1 gets 1 - alpha.
            self.alpha = torch.Tensor([alpha, 1 - alpha])
        if isinstance(alpha, list):
            self.alpha = torch.Tensor(alpha)
        self.reduction = reduction

    def forward(self, input, target):
        # cross_entropy returns -log(pt) per sample.
        logpt = F.cross_entropy(input, target, reduction='none')
        pt = torch.exp(-logpt)
        if self.alpha is not None:
            if self.alpha.device != input.device or self.alpha.dtype != input.dtype:
                self.alpha = self.alpha.to(input)
            # Per-sample class weight alpha_t, indexed by the target class.
            at = self.alpha.gather(0, target.view(-1))
            logpt = logpt * at
        # Down-weight easy examples by the modulating factor (1 - pt)**gamma.
        loss = (1 - pt) ** self.gamma * logpt
        if self.reduction:
            return loss.mean()
        return loss
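
# --- Illustrative usage sketch (not part of the original file) --------------
# Plain cross entropy corresponds to gamma=0; larger gamma focuses training
# on hard (low-pt) examples. Shapes follow F.cross_entropy: logits
# [batch, classes], integer targets [batch]. All values are demo inputs.
def _demo_focal_loss():
    torch.manual_seed(0)
    logits = torch.randn(8, 5, requires_grad=True)
    targets = torch.randint(0, 5, (8,))
    criterion = FocalLoss(gamma=2.0, alpha=[0.25, 0.25, 0.25, 0.25, 0.25])
    loss = criterion(logits, targets)
    loss.backward()
    return loss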
class GHMC(nn.Module):
    """GHM classification loss (Li et al., AAAI 2019, "Gradient Harmonized
    Single-stage Detector"). Re-weights samples inversely to the density of
    their gradient norm, so that neither the mass of very easy examples nor
    outliers dominate training."""

    def __init__(self, bins=10, momentum=0, use_sigmoid=True, loss_weight=1.0):
        super(GHMC, self).__init__()
        self.bins = bins
        self.momentum = momentum
        # Bin edges partition the gradient-norm range [0, 1] uniformly.
        self.edges = [float(x) / bins for x in range(bins + 1)]
        self.edges[-1] += 1e-6  # make the last bin right-inclusive
        if momentum > 0:
            # Running (EMA) count of examples per bin.
            self.acc_sum = [0.0 for _ in range(bins)]
        self.use_sigmoid = use_sigmoid
        self.loss_weight = loss_weight
    def forward(self, pred, target, label_weight, *args, **kwargs):
        """Args:
            pred [batch_num, class_num]:
                Raw logits from the classification fc layer.
            target [batch_num, class_num]:
                Binary class target for each sample.
            label_weight [batch_num, class_num]:
                1 if the sample is valid, 0 if it should be ignored.
        """
        if not self.use_sigmoid:
            raise NotImplementedError
        # The target should be a binary class label; expand categorical
        # labels to one-hot form if necessary.
        if pred.dim() != target.dim():
            target, label_weight = _expand_binary_labels(
                target, label_weight, pred.size(-1))
        target, label_weight = target.float(), label_weight.float()
        edges = self.edges
        mmt = self.momentum
        weights = torch.zeros_like(pred)

        # Gradient norm: |sigmoid(pred) - target| is the magnitude of the
        # gradient of the sigmoid BCE loss w.r.t. the logits.
        g = torch.abs(pred.sigmoid().detach() - target)

        valid = label_weight > 0
        tot = max(valid.float().sum().item(), 1.0)
        n = 0  # number of non-empty bins
        for i in range(self.bins):
            inds = (g >= edges[i]) & (g < edges[i + 1]) & valid
            num_in_bin = inds.sum().item()
            if num_in_bin > 0:
                if mmt > 0:
                    # EMA of the bin count for stability across iterations.
                    self.acc_sum[i] = mmt * self.acc_sum[i] \
                        + (1 - mmt) * num_in_bin
                    weights[inds] = tot / self.acc_sum[i]
                else:
                    # Weight inversely proportional to gradient density.
                    weights[inds] = tot / num_in_bin
                n += 1
        if n > 0:
            weights = weights / n
        loss = F.binary_cross_entropy_with_logits(
            pred, target, weights, reduction='sum') / tot
        return loss * self.loss_weight
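
# --- Illustrative smoke test (not part of the original file) ----------------
# pred/target/label_weight shapes follow the docstring above. With momentum
# > 0 the per-bin counts are smoothed across batches via acc_sum; the values
# below are hypothetical demo inputs.
if __name__ == '__main__':
    torch.manual_seed(0)
    pred = torch.randn(16, 4)          # raw logits
    target = torch.randint(0, 2, (16, 4)).float()
    label_weight = torch.ones(16, 4)   # all samples valid
    ghmc = GHMC(bins=10, momentum=0.75)
    print('GHMC loss:', ghmc(pred, target, label_weight).item())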