Skip to content

Commit a17055f

Browse files
committed
Don't allow both indices_tuple and enqueue_idx
1 parent c03947a commit a17055f

File tree

2 files changed

+11
-11
lines changed

2 files changed

+11
-11
lines changed

src/pytorch_metric_learning/losses/cross_batch_memory.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@ def __init__(self, loss, embedding_size, memory_size=1024, miner=None, **kwargs)
1818
)
1919

2020
def forward(self, embeddings, labels, indices_tuple=None, enqueue_idx=None):
21+
if indices_tuple is not None and enqueue_idx is not None:
22+
raise ValueError("indices_tuple and enqueue_idx are mutually exclusive")
2123
if enqueue_idx is not None:
2224
assert len(enqueue_idx) <= len(self.embedding_memory)
2325
assert len(enqueue_idx) < len(embeddings)

tests/utils/test_distributed.py

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -358,17 +358,15 @@ def test_distributed_tuple_loss_and_miner(self):
358358
for pass_labels_to_loss_fn in [False, True]:
359359
if xbm and use_ref or xbm and not pass_labels_to_loss_fn:
360360
continue
361-
for use_xbm_enqueue_idx in [False, True]:
362-
self.loss_and_miner_tester(
363-
ContrastiveLoss,
364-
PairMarginMiner,
365-
False,
366-
xbm,
367-
use_ref,
368-
miner_kwargs={"pos_margin": 0.5, "neg_margin": 0.5},
369-
pass_labels_to_loss_fn=pass_labels_to_loss_fn,
370-
use_xbm_enqueue_idx=use_xbm_enqueue_idx,
371-
)
361+
self.loss_and_miner_tester(
362+
ContrastiveLoss,
363+
PairMarginMiner,
364+
False,
365+
xbm,
366+
use_ref,
367+
miner_kwargs={"pos_margin": 0.5, "neg_margin": 0.5},
368+
pass_labels_to_loss_fn=pass_labels_to_loss_fn,
369+
)
372370

373371
def test_distributed_tuple_loss_efficient(self):
374372
for use_ref in [False, True]:

0 commit comments

Comments
 (0)