-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathloss.py
119 lines (92 loc) · 5.06 KB
/
loss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import torch
import torch.nn as nn
# Borrowed from https://github.com/alfonmedela/triplet-loss-pytorch/blob/master/loss_functions/triplet_loss.py
def pairwise_distance_torch(embeddings, device=None, squared=True):
    """Compute the matrix of pairwise distances between embedding rows.

    By default ``output[i, j] = ||embeddings[i] - embeddings[j]||_2 ** 2``,
    i.e. the *squared* Euclidean distance (this is what the function has
    always returned). Pass ``squared=False`` to get the plain Euclidean
    distance, computed with a gradient-safe square root. The computation is
    carried out in float32, small negative values caused by floating-point
    cancellation are clipped to zero, and the diagonal is forced to zero.

    Args:
        embeddings: 2-D tensor of shape [num_samples, feature_dim].
        device: Optional device to move the result to. All intermediate
            tensors are created on ``embeddings.device``.
        squared: If True (default), return squared distances; otherwise
            return Euclidean distances.

    Returns:
        2-D float32 tensor of shape [num_samples, num_samples].
    """
    precise = embeddings.to(dtype=torch.float32)
    sq_norms = precise.pow(2).sum(dim=-1)
    # ||a - b||^2 = ||a||^2 + ||b||^2 - 2 <a, b>
    dist_sq = (
        sq_norms.unsqueeze(1)
        + sq_norms.unsqueeze(0)
        - 2.0 * (precise @ precise.transpose(0, 1))
    )
    # Floating-point cancellation can leave small negatives; clip them to zero.
    dist_sq = torch.max(dist_sq, dist_sq.new_zeros(1))

    if squared:
        distances = dist_sq
    else:
        zero_mask = (dist_sq <= 0.0).to(dist_sq.dtype)
        # The derivative of sqrt is infinite at 0: add a tiny epsilon to exact
        # zeros before the root, then zero those entries again so that
        # gradients stay finite.
        distances = torch.sqrt(dist_sq + zero_mask * 1e-16) * (1.0 - zero_mask)

    # Explicitly zero the diagonal (each sample's distance to itself).
    n = distances.shape[0]
    off_diagonal = 1.0 - torch.eye(n, dtype=distances.dtype, device=distances.device)
    distances = distances * off_diagonal

    if device is not None:
        distances = distances.to(device)
    return distances
def TripletSemiHardLoss(y_true, y_pred, device, margin=1.0):
    """Compute the triplet loss with semi-hard negative mining.

    Port of the TensorFlow Addons ``triplet_semihard_loss``. For every
    anchor/positive pair (two distinct samples sharing a label) the loss is
    ``max(0, margin + D_ap - D_an)``. Here ``D_an`` is the semi-hard negative
    distance: the smallest anchor-negative distance that is strictly greater
    than ``D_ap``. When no such negative exists for a pair, the largest
    anchor-negative distance of that anchor is used instead. The per-pair
    losses are averaged over all anchor/positive pairs.
    See: https://arxiv.org/abs/1503.03832.

    Distances come from ``pairwise_distance_torch(embeddings, device)``.

    Args:
        y_true: Label tensor with ``batch_size`` elements (reshaped to
            ``[batch_size, 1]``); presumably integer class ids.
        y_pred: 2-D float tensor of shape [batch_size, feature_dim]; expected
            to hold L2-normalized embeddings.
        device: Device passed to ``pairwise_distance_torch``.
        margin: Margin term of the loss. Defaults to 1.0.

    Returns:
        Scalar tensor with the dtype of ``y_pred``. It is 0 (rather than NaN)
        when the batch contains no anchor/positive pair.
    """
    embeddings = y_pred
    pdist_matrix = pairwise_distance_torch(embeddings, device)
    dev = pdist_matrix.device
    # Labels as a column vector on the distances' device, so the label masks
    # can be combined with distance comparisons below.
    labels = torch.reshape(y_true, [y_true.shape[0], 1]).to(dev)
    batch_size = labels.shape[0]

    # adjacency[i, j]: samples i and j share a label.
    adjacency = torch.eq(labels, labels.transpose(0, 1))
    adjacency_not = adjacency.logical_not()

    # Row r = p * batch_size + a of the tiled tensors describes anchor a paired
    # with candidate positive p: pdist_matrix_tile[r] is row a of the distance
    # matrix and pos_dist[r] is D_ap.
    pdist_matrix_tile = pdist_matrix.repeat(batch_size, 1)
    pos_dist = pdist_matrix.transpose(0, 1).reshape(-1, 1)
    # mask[r, n]: n is a negative of anchor a that is farther than positive p.
    mask = adjacency_not.repeat(batch_size, 1) & (pdist_matrix_tile > pos_dist)

    # mask_final[a, p]: at least one semi-hard negative exists for (a, p).
    mask_final = mask.any(dim=1).reshape(batch_size, batch_size).transpose(0, 1)

    mask = mask.to(dtype=torch.float32)
    adjacency_not = adjacency_not.to(dtype=torch.float32)

    # negatives_outside[a, p]: smallest D_an with D_an > D_ap (masked minimum:
    # shifting by the row maximum makes masked-out entries 0, the largest value).
    axis_maximums = torch.max(pdist_matrix_tile, dim=1, keepdim=True)[0]
    negatives_outside = (
        torch.min((pdist_matrix_tile - axis_maximums) * mask, dim=1, keepdim=True)[0]
        + axis_maximums
    )
    negatives_outside = negatives_outside.reshape(batch_size, batch_size).transpose(0, 1)

    # negatives_inside[a, :]: largest D_an over all negatives of anchor a
    # (masked maximum, symmetric trick with the row minimum).
    axis_minimums = torch.min(pdist_matrix, dim=1, keepdim=True)[0]
    negatives_inside = (
        torch.max((pdist_matrix - axis_minimums) * adjacency_not, dim=1, keepdim=True)[0]
        + axis_minimums
    ).repeat(1, batch_size)

    semi_hard_negatives = torch.where(mask_final, negatives_outside, negatives_inside)
    loss_mat = margin + pdist_matrix - semi_hard_negatives

    # Anchor/positive pairs: same label, excluding each sample paired with itself.
    mask_positives = adjacency.to(dtype=torch.float32) - torch.eye(batch_size, device=dev)
    num_positives = mask_positives.sum()
    masked_loss = loss_mat * mask_positives
    hinge = torch.max(masked_loss, masked_loss.new_zeros(1))
    # With no anchor/positive pair the sum is 0; dividing by at least 1 yields
    # a loss of 0 instead of 0 / 0 = NaN, which would poison training.
    triplet_loss = hinge.sum() / num_positives.clamp(min=1.0)
    return triplet_loss.to(dtype=embeddings.dtype)
class TripletLoss(nn.Module):
    """``nn.Module`` wrapper around ``TripletSemiHardLoss``.

    Args:
        device: Device forwarded to ``TripletSemiHardLoss``.
        margin: Margin of the triplet loss. Defaults to 1.0, the same default
            as ``TripletSemiHardLoss``.
    """

    def __init__(self, device, margin=1.0):
        super().__init__()
        self.device = device
        self.margin = margin

    def forward(self, input, target, **kwargs):
        """Return the semi-hard triplet loss for embeddings and labels.

        Args:
            input: Embedding tensor (passed to ``TripletSemiHardLoss`` as ``y_pred``).
            target: Label tensor (passed as ``y_true``).
            **kwargs: Accepted for interface compatibility and ignored.
        """
        return TripletSemiHardLoss(target, input, self.device, margin=self.margin)