
Commit 03d5192

Merge pull request #605 from KevinMusgrave/dev
v2.1.0
2 parents 691a635 + d94576c commit 03d5192

File tree

7 files changed (+240, -1 lines changed)


CONTENTS.md

Lines changed: 1 addition & 0 deletions
@@ -27,6 +27,7 @@
 | [**NormalizedSoftmaxLoss**](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#normalizedsoftmaxloss) | - [NormFace: L2 Hypersphere Embedding for Face Verification](https://arxiv.org/pdf/1704.06369.pdf) <br/> - [Classification is a Strong Baseline for Deep Metric Learning](https://arxiv.org/pdf/1811.12649.pdf)
 | [**NPairsLoss**](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#npairsloss) | [Improved Deep Metric Learning with Multi-class N-pair Loss Objective](http://www.nec-labs.com/uploads/images/Department-Images/MediaAnalytics/papers/nips16_npairmetriclearning.pdf)
 | [**NTXentLoss**](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#ntxentloss) | - [Representation Learning with Contrastive Predictive Coding](https://arxiv.org/pdf/1807.03748.pdf) <br/> - [Momentum Contrast for Unsupervised Visual Representation Learning](https://arxiv.org/pdf/1911.05722.pdf) <br/> - [A Simple Framework for Contrastive Learning of Visual Representations](https://arxiv.org/abs/2002.05709)
+| [**PNPLoss**](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#pnploss) | [Rethinking the Optimization of Average Precision: Only Penalizing Negative Instances before Positive Ones is Enough](https://arxiv.org/pdf/2102.04640.pdf)
 | [**ProxyAnchorLoss**](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#proxyanchorloss) | [Proxy Anchor Loss for Deep Metric Learning](https://arxiv.org/pdf/2003.13911.pdf)
 | [**ProxyNCALoss**](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#proxyncaloss) | [No Fuss Distance Metric Learning using Proxies](https://arxiv.org/pdf/1703.07464.pdf)
 | [**SignalToNoiseRatioContrastiveLoss**](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#signaltonoiseratiocontrastiveloss) | [Signal-to-Noise Ratio: A Robust Distance Metric for Deep Metric Learning](http://openaccess.thecvf.com/content_CVPR_2019/papers/Yuan_Signal-To-Noise_Ratio_A_Robust_Distance_Metric_for_Deep_Metric_Learning_CVPR_2019_paper.pdf)

README.md

Lines changed: 1 addition & 0 deletions
@@ -232,6 +232,7 @@ Thanks to the contributors who made pull requests!
 | [elias-ramzi](https://github.com/elias-ramzi) | [HierarchicalSampler](https://kevinmusgrave.github.io/pytorch-metric-learning/samplers/#hierarchicalsampler) |
 | [fjsj](https://github.com/fjsj) | [SupConLoss](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#supconloss) |
 | [AlenUbuntu](https://github.com/AlenUbuntu) | [CircleLoss](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#circleloss) |
+| [interestingzhuo](https://github.com/interestingzhuo) | [**PNPLoss**](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#pnploss) |
 | [wconnell](https://github.com/wconnell) | [Learning a scRNAseq Metric Embedding](https://github.com/KevinMusgrave/pytorch-metric-learning/blob/master/examples/notebooks/scRNAseq_MetricEmbedding.ipynb) |
 | [AlexSchuy](https://github.com/AlexSchuy) | optimized ```utils.loss_and_miner_utils.get_random_triplet_indices``` |
 | [JohnGiorgi](https://github.com/JohnGiorgi) | ```all_gather``` in [utils.distributed](https://kevinmusgrave.github.io/pytorch-metric-learning/distributed) |

docs/losses.md

Lines changed: 8 additions & 0 deletions
@@ -760,6 +760,14 @@ losses.NTXentLoss(temperature=0.07, **kwargs)
 
 * **loss**: The loss per positive pair in the batch. Reduction type is ```"pos_pair"```.
 
+
+## PNPLoss
+[Rethinking the Optimization of Average Precision: Only Penalizing Negative Instances before Positive Ones is Enough](https://arxiv.org/pdf/2102.04640.pdf){target=_blank}
+```python
+losses.PNPLoss(b=2, alpha=1, anneal=0.01, variant="O", **kwargs)
+```
+
+
 ## ProxyAnchorLoss
 [Proxy Anchor Loss for Deep Metric Learning](https://arxiv.org/pdf/2003.13911.pdf){target=_blank}
 ```python
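
The PNPLoss entry added to the docs above shows only the constructor signature. Below is a minimal call sketch using those documented defaults; the toy embeddings and labels are illustrative, not part of the commit. Per the implementation added later in this commit, `variant` must be one of `"Ds"`, `"Dq"`, `"Iu"`, `"Ib"`, `"O"`, and the loss requires a `CosineSimilarity` distance (its default).

```python
import torch

from pytorch_metric_learning import losses

loss_fn = losses.PNPLoss(b=2, alpha=1, anneal=0.01, variant="O")

# toy batch: two classes, four samples each
embeddings = torch.randn(8, 128, requires_grad=True)
labels = torch.tensor([0, 0, 0, 0, 1, 1, 1, 1])

loss = loss_fn(embeddings, labels)  # already-reduced scalar
loss.backward()
```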

src/pytorch_metric_learning/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
-__version__ = "2.0.1"
+__version__ = "2.1.0"

src/pytorch_metric_learning/losses/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -20,6 +20,7 @@
 from .nca_loss import NCALoss
 from .normalized_softmax_loss import NormalizedSoftmaxLoss
 from .ntxent_loss import NTXentLoss
+from .pnp_loss import PNPLoss
 from .proxy_anchor_loss import ProxyAnchorLoss
 from .proxy_losses import ProxyNCALoss
 from .self_supervised_loss import SelfSupervisedLoss
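
With PNPLoss exported from `losses/__init__.py`, it can be imported from the top-level `losses` module like any other loss. A minimal training-step sketch follows; the trunk network, embedding size, and optimizer are illustrative assumptions, not part of this commit.

```python
import torch

from pytorch_metric_learning import losses

# hypothetical embedder: any module that maps inputs to embedding vectors works
trunk = torch.nn.Sequential(
    torch.nn.Linear(784, 256), torch.nn.ReLU(), torch.nn.Linear(256, 64)
)
optimizer = torch.optim.Adam(trunk.parameters(), lr=1e-4)
loss_fn = losses.PNPLoss(variant="Dq")  # any of "Ds", "Dq", "Iu", "Ib", "O"

data = torch.randn(32, 784)                    # toy inputs
labels = torch.arange(8).repeat_interleave(4)  # 8 classes, 4 samples each

optimizer.zero_grad()
embeddings = trunk(data)
loss = loss_fn(embeddings, labels)
loss.backward()
optimizer.step()
```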

src/pytorch_metric_learning/losses/pnp_loss.py

Lines changed: 95 additions & 0 deletions
@@ -0,0 +1,95 @@
import torch

from ..distances import CosineSimilarity
from ..utils import common_functions as c_f
from ..utils import loss_and_miner_utils as lmu
from .base_metric_loss_function import BaseMetricLossFunction


class PNPLoss(BaseMetricLossFunction):
    VARIANTS = ["Ds", "Dq", "Iu", "Ib", "O"]

    def __init__(self, b=2, alpha=1, anneal=0.01, variant="O", **kwargs):
        super().__init__(**kwargs)
        c_f.assert_distance_type(self, CosineSimilarity)
        self.b = b
        self.alpha = alpha
        self.anneal = anneal
        self.variant = variant
        if self.variant not in self.VARIANTS:
            raise ValueError(f"variant={variant} but must be one of {self.VARIANTS}")

    """
    Adapted from https://github.com/interestingzhuo/PNPloss
    """

    def compute_loss(self, embeddings, labels, indices_tuple, ref_emb, ref_labels):
        c_f.indices_tuple_not_supported(indices_tuple)
        c_f.labels_required(labels)
        c_f.ref_not_supported(embeddings, labels, ref_emb, ref_labels)
        dtype, device = embeddings.dtype, embeddings.device

        N = labels.size(0)
        a1_idx, p_idx, a2_idx, n_idx = lmu.get_all_pairs_indices(labels)
        I_pos = torch.zeros(N, N, dtype=dtype, device=device)
        I_neg = torch.zeros(N, N, dtype=dtype, device=device)
        I_pos[a1_idx, p_idx] = 1
        I_pos[a1_idx, a1_idx] = 1
        I_neg[a2_idx, n_idx] = 1
        N_pos = torch.sum(I_pos, dim=1)
        safe_N = N_pos > 0
        if torch.sum(safe_N) == 0:
            return self.zero_losses()
        sim_all = self.distance(embeddings)

        mask = I_neg.unsqueeze(dim=0).repeat(N, 1, 1)

        sim_all_repeat = sim_all.unsqueeze(dim=1).repeat(1, N, 1)
        # compute the difference matrix
        sim_diff = sim_all_repeat - sim_all_repeat.permute(0, 2, 1)
        # pass through the sigmoid and ignore the relevance score of the query to itself
        sim_sg = self.sigmoid(sim_diff, temp=self.anneal) * mask
        # compute the number of negatives ranked before each positive
        sim_all_rk = torch.sum(sim_sg, dim=-1)

        if self.variant == "Ds":
            sim_all_rk = torch.log(1 + sim_all_rk)
        elif self.variant == "Dq":
            sim_all_rk = 1 / (1 + sim_all_rk) ** (self.alpha)

        elif self.variant == "Iu":
            sim_all_rk = (1 + sim_all_rk) * torch.log(1 + sim_all_rk)

        elif self.variant == "Ib":
            b = self.b
            sim_all_rk = 1 / b**2 * (b * sim_all_rk - torch.log(1 + b * sim_all_rk))
        elif self.variant == "O":
            pass
        else:
            raise Exception(f"variant <{self.variant}> not available!")

        loss = torch.sum(sim_all_rk * I_pos, dim=-1) / N_pos.reshape(-1)
        loss = torch.sum(loss) / N
        if self.variant == "Dq":
            loss = 1 - loss

        return {
            "loss": {
                "losses": loss,
                "indices": torch.where(safe_N)[0],
                "reduction_type": "already_reduced",
            }
        }

    def sigmoid(self, tensor, temp=1.0):
        """temperature controlled sigmoid
        takes as input a torch tensor (tensor) and passes it through a sigmoid, controlled by temperature: temp
        """
        exponent = -tensor / temp
        # clamp the input tensor for stability
        exponent = torch.clamp(exponent, min=-50, max=50)
        y = 1.0 / (1.0 + torch.exp(exponent))
        return y

    def get_default_distance(self):
        return CosineSimilarity()
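
The heart of `compute_loss` is a sigmoid-smoothed count, for every (anchor, positive) pair, of how many of the anchor's negatives are ranked above that positive. The standalone sketch below mirrors the `"O"` variant in plain PyTorch on a toy batch; the helper name `pnp_o_reference` is ours, and it matches the class above in the usual case where every class in the batch has at least two samples.

```python
import torch
import torch.nn.functional as F


def pnp_o_reference(embeddings, labels, anneal=0.01):
    """Average, over (anchor, positive) pairs, of the smoothed number of
    negatives the anchor ranks above that positive (the "O" variant)."""
    emb = F.normalize(embeddings, dim=1)
    sim = emb @ emb.t()                                # (N, N) cosine similarities
    same = labels.unsqueeze(0) == labels.unsqueeze(1)  # (N, N) same-label mask
    I_pos = same.float()                               # positives, diagonal included
    I_neg = (~same).float()                            # negatives
    # diff[i, j, k] = sim(i, k) - sim(i, j): margin of sample k over positive j for anchor i
    diff = sim.unsqueeze(1) - sim.unsqueeze(2)
    # smoothed indicator that negative k is ranked above positive j
    ranked_above = torch.sigmoid(diff / anneal) * I_neg.unsqueeze(0)
    n_above = ranked_above.sum(dim=-1)                 # (N, N)
    per_anchor = (n_above * I_pos).sum(dim=-1) / I_pos.sum(dim=-1)
    return per_anchor.mean()


embeddings = torch.randn(8, 16)
labels = torch.tensor([0, 0, 0, 0, 1, 1, 1, 1])
print(pnp_o_reference(embeddings, labels))
```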

tests/losses/test_pnp_loss.py

Lines changed: 133 additions & 0 deletions
@@ -0,0 +1,133 @@
import unittest

import torch
import torch.nn.functional as F

from pytorch_metric_learning.losses import PNPLoss

from .. import TEST_DEVICE, TEST_DTYPES


class OriginalImplementationPNP(torch.nn.Module):
    def __init__(self, b, alpha, anneal, variant, bs, classes):
        super(OriginalImplementationPNP, self).__init__()
        self.b = b
        self.alpha = alpha
        self.anneal = anneal
        self.variant = variant
        self.batch_size = bs
        self.num_id = classes
        self.samples_per_class = int(bs / classes)

        mask = 1.0 - torch.eye(self.batch_size)
        for i in range(self.num_id):
            mask[
                i * (self.samples_per_class) : (i + 1) * (self.samples_per_class),
                i * (self.samples_per_class) : (i + 1) * (self.samples_per_class),
            ] = 0

        self.mask = mask.unsqueeze(dim=0).repeat(self.batch_size, 1, 1)

    def forward(self, batch):

        dtype, device = batch.dtype, batch.device
        self.mask = self.mask.type(dtype).to(device)
        # compute the relevance scores via cosine similarity of the CNN-produced embedding vectors

        sim_all = self.compute_aff(batch)

        sim_all_repeat = sim_all.unsqueeze(dim=1).repeat(1, self.batch_size, 1)
        # compute the difference matrix
        sim_diff = sim_all_repeat - sim_all_repeat.permute(0, 2, 1)
        # pass through the sigmoid and ignore the relevance score of the query to itself
        sim_sg = self.sigmoid(sim_diff, temp=self.anneal) * self.mask
        # compute the rankings for the whole batch
        sim_all_rk = torch.sum(sim_sg, dim=-1)
        if self.variant == "PNP-D_s":
            sim_all_rk = torch.log(1 + sim_all_rk)
        elif self.variant == "PNP-D_q":
            sim_all_rk = 1 / (1 + sim_all_rk) ** (self.alpha)

        elif self.variant == "PNP-I_u":
            sim_all_rk = (1 + sim_all_rk) * torch.log(1 + sim_all_rk)

        elif self.variant == "PNP-I_b":
            b = self.b
            sim_all_rk = 1 / b**2 * (b * sim_all_rk - torch.log(1 + b * sim_all_rk))
        elif self.variant == "PNP-O":
            pass
        else:
            raise Exception("variant <{}> not available!".format(self.variant))

        # sum the values of the Smooth-AP for all instances in the mini-batch
        loss = torch.zeros(1).type(dtype).to(device)
        group = int(self.batch_size / self.num_id)

        for ind in range(self.num_id):
            neg_divide = torch.sum(
                sim_all_rk[
                    (ind * group) : ((ind + 1) * group),
                    (ind * group) : ((ind + 1) * group),
                ]
                / group
            )
            loss = loss + (neg_divide / self.batch_size)
        if self.variant == "PNP-D_q":
            return 1 - loss
        else:
            return loss

    def sigmoid(self, tensor, temp=1.0):
        """temperature controlled sigmoid
        takes as input a torch tensor (tensor) and passes it through a sigmoid, controlled by temperature: temp
        """
        exponent = -tensor / temp
        # clamp the input tensor for stability
        exponent = torch.clamp(exponent, min=-50, max=50)
        y = 1.0 / (1.0 + torch.exp(exponent))
        return y

    def compute_aff(self, x):
        """computes the affinity matrix between an input vector and itself"""
        return torch.mm(x, x.t())


class TestPNPLoss(unittest.TestCase):
    def test_pnp_loss(self):
        torch.manual_seed(30293)
        bs = 180
        classes = 30
        for variant in PNPLoss.VARIANTS:
            original_variant = {
                "Ds": "PNP-D_s",
                "Dq": "PNP-D_q",
                "Iu": "PNP-I_u",
                "Ib": "PNP-I_b",
                "O": "PNP-O",
            }[variant]
            b, alpha, anneal = 2, 4, 0.01
            loss_func = PNPLoss(b, alpha, anneal, variant)
            original_loss_func = OriginalImplementationPNP(
                b, alpha, anneal, original_variant, bs, classes
            ).to(TEST_DEVICE)

            for dtype in TEST_DTYPES:
                embeddings = torch.randn(
                    bs, 32, dtype=dtype, device=TEST_DEVICE, requires_grad=True
                )
                labels = (
                    torch.tensor([[i] * (int(bs / classes)) for i in range(classes)])
                    .reshape(-1)
                    .to(TEST_DEVICE)
                )
                loss = loss_func(embeddings, labels)
                loss.backward()
                correct_loss = original_loss_func(F.normalize(embeddings, dim=-1))

                rtol = 1e-2 if dtype == torch.float16 else 1e-5
                self.assertTrue(torch.isclose(loss, correct_loss[0], rtol=rtol))

        with self.assertRaises(ValueError):
            PNPLoss(b, alpha, anneal, "PNP")
