-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathrwightman_sigmoid_loss.py
124 lines (107 loc) · 4.8 KB
/
rwightman_sigmoid_loss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import torch
import torch.nn as nn
import torch.nn.functional as F
from distributed_utils import neighbour_exchange_bidir_with_grad, neighbour_exchange_with_grad
# The following code is borrowed from Ross Wightman implementation at
# https://github.com/mlfoundations/open_clip/blob/a5ba05f7cab5ddab7c9967bfb8bbef303be6f3aa/src/open_clip/loss.py
# The code is borrowed for the purpose of testing the correctness of the Sigmoid Loss
class SigLipLoss(nn.Module):
    """Sigmoid Loss for Language Image Pre-Training (SigLIP) - https://arxiv.org/abs/2303.15343

    @article{zhai2023sigmoid,
      title={Sigmoid loss for language image pre-training},
      author={Zhai, Xiaohua and Mustafa, Basil and Kolesnikov, Alexander and Beyer, Lucas},
      journal={arXiv preprint arXiv:2303.15343},
      year={2023}
    }

    Every (image, text) pair is scored independently through a sigmoid, so the
    total loss is a sum over blocks of the logits matrix:

    * the local block (this rank's images vs. this rank's texts) has the
      positive pairs on its diagonal and negatives everywhere else;
    * every block against another rank's texts is negatives only.

    Each block's loss is divided by the local image batch size.

    With ``world_size > 1`` the ranks form a ring (``rank +/- 1`` modulo
    ``world_size``). Text features are exchanged with neighbours for a total of
    ``world_size - 1`` received blocks. The intent is that every other rank's
    text block is seen once. That relies on the ``distributed_utils`` exchange
    helpers sending/receiving in the direction their argument order suggests;
    TODO confirm against that module.

    Args:
        cache_labels: stored on the instance but not used by this implementation.
        rank: this process's position in the ring.
        world_size: number of processes in the ring; ``1`` disables any exchange.
        bidir: if True, exchange with both neighbours per round, which takes
            ``ceil((world_size - 1) / 2)`` rounds. Otherwise text features are
            passed one hop to the right per round, which takes
            ``world_size - 1`` rounds.
        use_horovod: unsupported; must be False.
    """

    def __init__(
        self,
        cache_labels=False,
        rank=0,
        world_size=1,
        bidir=True,
        use_horovod=False,
    ):
        super().__init__()
        self.cache_labels = cache_labels
        self.rank = rank
        self.world_size = world_size
        # FIXME need to look at hvd ops for ring transfers
        assert not use_horovod, "use_horovod=True is not supported by SigLipLoss"
        self.use_horovod = use_horovod
        self.bidir = bidir

        # cache state FIXME cache not currently used, worthwhile?
        self.prev_num_logits = 0
        self.labels = {}

    def get_ground_truth(
        self, device, dtype, num_logits, negative_only=False, num_cols=None
    ) -> torch.Tensor:
        """Build the +1 / -1 target matrix for one block of logits.

        Args:
            device: device of the returned tensor.
            dtype: dtype of the returned tensor.
            num_logits: number of rows (image features in the block).
            negative_only: if True, every entry is -1 (a block against another
                rank's texts). Otherwise the main diagonal is +1 (matching
                pairs) and every other entry is -1.
            num_cols: number of columns (text features in the block). Defaults
                to ``num_logits``, which gives the original square matrix.

        Returns:
            Tensor of shape ``(num_logits, num_cols)``.
        """
        if num_cols is None:
            num_cols = num_logits
        labels = -torch.ones((num_logits, num_cols), device=device, dtype=dtype)
        if not negative_only:
            labels = 2 * torch.eye(num_logits, num_cols, device=device, dtype=dtype) + labels
        return labels

    def get_logits(self, image_features, text_features, logit_scale, logit_bias=None):
        """Return ``exp(logit_scale) * image_features @ text_features.T`` plus optional bias.

        ``logit_scale`` is exponentiated before use, i.e. it is a log-scale.
        For 2-D feature matrices the result has shape
        ``(image_features.shape[0], text_features.shape[0])``.
        """
        logits = logit_scale.exp() * image_features @ text_features.T
        if logit_bias is not None:
            # in-place add is fine: `logits` was freshly created on the line above
            logits += logit_bias
        return logits

    def _loss(
        self, image_features, text_features, logit_scale, logit_bias=None, negative_only=False
    ):
        """Sigmoid loss of one (image block, text block) pair.

        The sum of ``-log(sigmoid(label * logit))`` over the block is divided
        by the number of image features. The label matrix is sized from the
        logits, so a text block whose length differs from the image block
        (e.g. a neighbour's uneven last batch) no longer breaks broadcasting.
        """
        logits = self.get_logits(image_features, text_features, logit_scale, logit_bias)
        labels = self.get_ground_truth(
            image_features.device,
            image_features.dtype,
            logits.shape[0],
            negative_only=negative_only,
            num_cols=logits.shape[1],
        )
        # label +1 pushes the logit up, label -1 pushes it down
        loss = -F.logsigmoid(labels * logits).sum() / image_features.shape[0]
        return loss

    def forward(self, image_features, text_features, logit_scale, logit_bias=None, output_dict=False):
        """Compute the SigLIP loss for this rank.

        Args:
            image_features: this rank's image embeddings.
            text_features: this rank's text embeddings, paired row-by-row with
                ``image_features``.
            logit_scale: log of the logit temperature scale (exponentiated in
                ``get_logits``).
            logit_bias: optional additive logit bias; ``None`` means no bias,
                matching ``get_logits`` / ``_loss``.
            output_dict: if True, return ``{"contrastive_loss": loss}``
                instead of the bare tensor.

        Returns:
            The summed loss tensor, or a dict wrapping it when ``output_dict``.
        """
        loss = self._loss(image_features, text_features, logit_scale, logit_bias)

        if self.world_size > 1:
            # exchange text features w/ neighbour world_size - 1 times
            right_rank = (self.rank + 1) % self.world_size
            left_rank = (self.rank - 1 + self.world_size) % self.world_size
            if self.bidir:
                text_features_to_right = text_features_to_left = text_features
                num_bidir, remainder = divmod(self.world_size - 1, 2)
                for _ in range(num_bidir):
                    text_features_recv = neighbour_exchange_bidir_with_grad(
                        left_rank,
                        right_rank,
                        text_features_to_left,
                        text_features_to_right,
                    )
                    for f in text_features_recv:
                        loss += self._loss(
                            image_features,
                            f,
                            logit_scale,
                            logit_bias,
                            negative_only=True,
                        )
                    # forward what was just received one more hop in the same direction
                    text_features_to_left, text_features_to_right = text_features_recv

                if remainder:
                    # odd number of remaining blocks: one final single-direction hop
                    text_features_recv = neighbour_exchange_with_grad(
                        left_rank, right_rank, text_features_to_right
                    )
                    loss += self._loss(
                        image_features,
                        text_features_recv,
                        logit_scale,
                        logit_bias,
                        negative_only=True,
                    )
            else:
                text_features_to_right = text_features
                for _ in range(self.world_size - 1):
                    text_features_from_left = neighbour_exchange_with_grad(
                        left_rank, right_rank, text_features_to_right
                    )
                    loss += self._loss(
                        image_features,
                        text_features_from_left,
                        logit_scale,
                        logit_bias,
                        negative_only=True,
                    )
                    text_features_to_right = text_features_from_left

        return {"contrastive_loss": loss} if output_dict else loss