Commit f27a8bf

Merge pull request #577 from KevinMusgrave/dev: v1.7.3
2 parents 7536b2f + 9bd2276

File tree

5 files changed: +103 additions, -16 deletions
Lines changed: 1 addition & 1 deletion

@@ -1 +1 @@
-__version__ = "1.7.2"
+__version__ = "1.7.3"

src/pytorch_metric_learning/losses/cross_batch_memory.py

Lines changed: 6 additions & 5 deletions

@@ -18,6 +18,8 @@ def __init__(self, loss, embedding_size, memory_size=1024, miner=None, **kwargs)
         )

     def forward(self, embeddings, labels, indices_tuple=None, enqueue_idx=None):
+        if indices_tuple is not None and enqueue_idx is not None:
+            raise ValueError("indices_tuple and enqueue_idx are mutually exclusive")
         if enqueue_idx is not None:
             assert len(enqueue_idx) <= len(self.embedding_memory)
             assert len(enqueue_idx) < len(embeddings)
@@ -46,7 +48,6 @@ def forward(self, embeddings, labels, indices_tuple=None, enqueue_idx=None):
             labels_for_queue = labels
             do_remove_self_comparisons = True

-        batch_size = len(embeddings)
         queue_batch_size = len(emb_for_queue)
         self.add_to_memory(emb_for_queue, labels_for_queue, queue_batch_size)

@@ -58,7 +59,6 @@ def forward(self, embeddings, labels, indices_tuple=None, enqueue_idx=None):
         L_mem = self.label_memory

         indices_tuple = self.create_indices_tuple(
-            batch_size,
             embeddings,
             labels,
             E_mem,
@@ -85,7 +85,6 @@ def add_to_memory(self, embeddings, labels, batch_size):

     def create_indices_tuple(
         self,
-        batch_size,
         embeddings,
         labels,
         E_mem,
@@ -117,7 +116,9 @@ def create_indices_tuple(
         return indices_tuple

     def reset_queue(self):
-        self.embedding_memory = torch.zeros(self.memory_size, self.embedding_size)
-        self.label_memory = torch.zeros(self.memory_size).long()
+        self.register_buffer(
+            "embedding_memory", torch.zeros(self.memory_size, self.embedding_size)
+        )
+        self.register_buffer("label_memory", torch.zeros(self.memory_size).long())
         self.has_been_filled = False
         self.queue_idx = 0
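
For context (not part of the diff), a minimal usage sketch of the behavior changed above, assuming pytorch-metric-learning is installed: forward now rejects passing indices_tuple together with enqueue_idx, and reset_queue re-registers the memory tensors as module buffers.

import torch
from pytorch_metric_learning.losses import ContrastiveLoss, CrossBatchMemory

loss_fn = CrossBatchMemory(loss=ContrastiveLoss(), embedding_size=128, memory_size=1024)
embeddings = torch.randn(32, 128)
labels = torch.randint(0, 5, size=(32,))

# Only the first 16 embeddings are pushed into the memory queue.
enqueue_idx = torch.arange(16)
loss = loss_fn(embeddings, labels, enqueue_idx=enqueue_idx)

# Passing both arguments together now fails fast with a ValueError:
# loss_fn(embeddings, labels, indices_tuple=some_tuple, enqueue_idx=enqueue_idx)

# reset_queue() zeroes the queue and keeps it registered as module buffers.
loss_fn.reset_queue()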

src/pytorch_metric_learning/utils/distributed.py

Lines changed: 32 additions & 5 deletions

@@ -66,6 +66,19 @@ def get_indices_tuple(labels, ref_labels, embeddings=None, ref_emb=None, miner=N
     return lmu.remove_self_comparisons(indices_tuple, curr_batch_idx, len(ref_labels))


+def get_corrected_enqueue_idx(enqueue_idx, emb):
+    if enqueue_idx is None:
+        return enqueue_idx
+    enqueue_idx = c_f.to_device(enqueue_idx, device=emb.device)
+    bs = len(emb)
+    e_len = len(enqueue_idx)
+    world_size = torch.distributed.get_world_size()
+    enqueue_idx = torch.cat([enqueue_idx, all_gather(enqueue_idx)], dim=0)
+    for i in range(e_len, e_len * world_size, e_len):
+        enqueue_idx[i:] += bs
+    return enqueue_idx
+
+
 def select_ref_or_regular(regular, ref):
     return regular if ref is None else ref

@@ -85,12 +98,18 @@ def __init__(self, loss, efficient=False):
         self.efficient = efficient

     def forward(
-        self, emb, labels=None, indices_tuple=None, ref_emb=None, ref_labels=None
+        self,
+        emb,
+        labels=None,
+        indices_tuple=None,
+        ref_emb=None,
+        ref_labels=None,
+        enqueue_idx=None,
     ):
         world_size = torch.distributed.get_world_size()
         common_args = [emb, labels, indices_tuple, ref_emb, ref_labels, world_size]
         if isinstance(self.loss, CrossBatchMemory):
-            return self.forward_cross_batch(*common_args)
+            return self.forward_cross_batch(*common_args, enqueue_idx)
         return self.forward_regular_loss(*common_args)

     def forward_regular_loss(
@@ -118,20 +137,28 @@ def forward_regular_loss(
         return loss * world_size

     def forward_cross_batch(
-        self, emb, labels, indices_tuple, ref_emb, ref_labels, world_size
+        self,
+        emb,
+        labels,
+        indices_tuple,
+        ref_emb,
+        ref_labels,
+        world_size,
+        enqueue_idx,
     ):
         if ref_emb is not None or ref_labels is not None:
             raise ValueError(
                 "CrossBatchMemory is not compatible with ref_emb and ref_labels"
             )

         if world_size <= 1:
-            return self.loss(emb, labels, indices_tuple)
+            return self.loss(emb, labels, indices_tuple, enqueue_idx)

         all_emb, all_labels, _, _, _ = gather_emb_and_ref(
             emb, labels, ref_emb, ref_labels
         )
-        loss = self.loss(all_emb, all_labels, indices_tuple)
+        enqueue_idx = get_corrected_enqueue_idx(enqueue_idx, emb)
+        loss = self.loss(all_emb, all_labels, indices_tuple, enqueue_idx)
         return loss * world_size
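
The new get_corrected_enqueue_idx shifts each rank's local enqueue_idx so it indexes into the concatenation of embeddings gathered from all ranks. A standalone sketch of that arithmetic only (illustrative; the real helper uses torch.distributed.get_world_size and this module's all_gather, and the helper name offset_enqueue_idx below is hypothetical):

import torch

def offset_enqueue_idx(local_idx_per_rank, batch_size):
    # Rank i's embeddings occupy rows [i * batch_size, (i + 1) * batch_size)
    # of the gathered tensor, so its local indices are shifted by i * batch_size.
    return torch.cat(
        [idx + rank * batch_size for rank, idx in enumerate(local_idx_per_rank)]
    )

per_rank = [torch.tensor([0, 2]), torch.tensor([1, 3])]  # world_size = 2, batch_size = 4
print(offset_enqueue_idx(per_rank, batch_size=4))  # tensor([0, 2, 5, 7])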

tests/losses/test_cross_batch_memory.py

Lines changed: 20 additions & 1 deletion

@@ -501,7 +501,6 @@ def test_input_indices_tuple(self):
             indices_tuple, self.loss.curr_batch_idx, self.loss.memory_size
         )
         a1, p, a2, n = self.loss.create_indices_tuple(
-            batch_size,
             embeddings,
             labels,
             self.loss.embedding_memory,
@@ -520,6 +519,26 @@ def test_input_indices_tuple(self):
         self.assertTrue(torch.all(a2 == torch.cat([a2i, a2ii])))
         self.assertTrue(torch.all(n == torch.cat([ni, nii])))

+    def test_reset_queue(self):
+        self.loss = CrossBatchMemory(
+            loss=ContrastiveLoss(),
+            embedding_size=self.embedding_size,
+            memory_size=self.memory_size,
+        )
+
+        init_emb = torch.zeros(self.memory_size, self.embedding_size)
+        init_label = torch.zeros(self.memory_size).long()
+        self.assertTrue(torch.equal(self.loss.embedding_memory, init_emb))
+        self.assertTrue(torch.equal(self.loss.label_memory, init_label))
+
+        self.loss(torch.randn(32, 128), torch.randint(0, 2, size=(32,)))
+        self.assertTrue(not torch.equal(self.loss.embedding_memory, init_emb))
+        self.assertTrue(not torch.equal(self.loss.label_memory, init_label))
+
+        self.loss.reset_queue()
+        self.assertTrue(torch.equal(self.loss.embedding_memory, init_emb))
+        self.assertTrue(torch.equal(self.loss.label_memory, init_label))
+

 if __name__ == "__main__":
     unittest.main()
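
A side note on the reset_queue change that this new test exercises (one plausible motivation, an assumption not stated in the diff): register_buffer keeps the freshly zeroed tensors tracked by the module, so they continue to appear in state_dict() and to follow Module.to(device). A small sketch of that buffer behavior with a hypothetical Queue module:

import torch

class Queue(torch.nn.Module):
    def __init__(self, size=4, dim=8):
        super().__init__()
        self.register_buffer("memory", torch.zeros(size, dim))

    def reset(self):
        # Re-registering keeps the new tensor tracked as a buffer.
        self.register_buffer("memory", torch.zeros(*self.memory.shape))

q = Queue()
print("memory" in q.state_dict())  # True: buffers are serialized with the module
q.memory += 1.0
q.reset()
print(q.memory.sum().item())       # 0.0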

tests/utils/test_distributed.py

Lines changed: 44 additions & 4 deletions

@@ -65,6 +65,8 @@ def single_process_function(
     original_model,
     efficient,
     pass_labels_to_loss_fn,
+    use_xbm_enqueue_idx,
+    enqueue_idx,
 ):
     setup(rank, world_size)
     if TEST_DEVICE == torch.device("cpu"):
@@ -105,6 +107,10 @@ def single_process_function(
         indices_tuple = miner_fn(outputs, curr_labels, ref_outputs, curr_ref_labels)
     if miner_fn and not pass_labels_to_loss_fn:
         loss = loss_fn(outputs, indices_tuple=indices_tuple, ref_emb=ref_outputs)
+    elif use_xbm_enqueue_idx and isinstance(loss_fn.loss, CrossBatchMemory):
+        loss = loss_fn(
+            outputs, curr_labels, indices_tuple, enqueue_idx=enqueue_idx[rank]
+        )
     else:
         loss = loss_fn(
             outputs, curr_labels, indices_tuple, ref_outputs, curr_ref_labels
@@ -149,6 +155,22 @@ def create_labels(batch_size, world_size, iterations):
     ]


+def create_enqueue_idx(batch_size, world_size):
+    # enqueue every other embedding
+    local_enqueue_idx = [
+        (torch.randint(0, batch_size, size=(batch_size // 4,))).long()
+        for _ in range(world_size)
+    ]
+    global_enqueue_idx = []
+    for i, x in enumerate(local_enqueue_idx):
+        if i == 0:
+            global_enqueue_idx.append(x)
+        else:
+            global_enqueue_idx.append(x + batch_size)
+    global_enqueue_idx = torch.cat(global_enqueue_idx, dim=0)
+    return local_enqueue_idx, global_enqueue_idx
+
+
 def get_all_outputs_and_labels(inputs, labels, model, iteration):
     all_inputs = torch.cat(inputs[iteration], dim=0).to(TEST_DEVICE)
     all_labels = torch.cat(labels[iteration], dim=0).to(TEST_DEVICE)
@@ -167,6 +189,7 @@ def loss_and_miner_tester(
         loss_kwargs=None,
         miner_kwargs=None,
         pass_labels_to_loss_fn=True,
+        use_xbm_enqueue_idx=False,
     ):
         torch.manual_seed(75210)
         loss_kwargs = {} if loss_kwargs is None else loss_kwargs
@@ -222,6 +245,10 @@ def loss_and_miner_tester(
             )
             ref_labels = create_labels(batch_size, world_size, iterations)

+            local_enqueue_idx, global_enqueue_idx = create_enqueue_idx(
+                batch_size, world_size
+            )
+
             for aaa in range(iterations):
                 optimizer.zero_grad()
                 all_outputs, all_labels = get_all_outputs_and_labels(
@@ -269,8 +296,11 @@ def loss_and_miner_tester(
                         all_outputs, all_labels, all_ref_outputs, all_ref_labels
                     )
                 if xbm:
+                    enqueue_idx = (
+                        global_enqueue_idx if use_xbm_enqueue_idx else None
+                    )
                     loss = original_loss_fn(
-                        all_outputs, all_labels, indices_tuple
+                        all_outputs, all_labels, indices_tuple, enqueue_idx
                     )
                 else:
                     loss = original_loss_fn(
@@ -300,6 +330,8 @@ def loss_and_miner_tester(
                     original_model,
                     efficient,
                     pass_labels_to_loss_fn,
+                    use_xbm_enqueue_idx,
+                    local_enqueue_idx,
                 ),
                 nprocs=world_size,
                 join=True,
@@ -308,9 +340,17 @@ def loss_and_miner_tester(
     def test_distributed_tuple_loss(self):
         for xbm in [False, True]:
             for use_ref in [False, True]:
-                if xbm and use_ref:
-                    continue
-                self.loss_and_miner_tester(ContrastiveLoss, None, False, xbm, use_ref)
+                for use_xbm_enqueue_idx in [False, True]:
+                    if xbm and use_ref:
+                        continue
+                    self.loss_and_miner_tester(
+                        ContrastiveLoss,
+                        None,
+                        False,
+                        xbm,
+                        use_ref,
+                        use_xbm_enqueue_idx=use_xbm_enqueue_idx,
+                    )

     def test_distributed_tuple_loss_and_miner(self):
         for xbm in [False, True]:
