
Commit 557a7ca

Merge pull request #572 from KevinMusgrave/dev
v1.7.1
2 parents bad1a31 + f1c1d9f commit 557a7ca

File tree

8 files changed (+66, -30 lines)

docs/reducers.md

Lines changed: 7 additions & 0 deletions

````diff
@@ -122,6 +122,13 @@ reducers.PerAnchorReducer(reducer=None,
 * ```num_per_row``` is a size N array which indicates how many non-zero losses there are per-row of ```x```.
 
 
+## SumReducer
+This will return the sum of the losses.
+```python
+reducers.SumReducer(**kwargs)
+```
+
+
 ## ThresholdReducer
 This computes the average loss, using only the losses that fall within a specified range.
 
````
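As a quick illustration of the API documented above, here is a minimal usage sketch. The choice of loss function, batch size, and label values is arbitrary and not part of this commit; only `reducers.SumReducer()` is new.

```python
# Minimal sketch: passing the new SumReducer to a loss function,
# the same way the library's other reducers are used.
import torch
from pytorch_metric_learning import losses, reducers

loss_fn = losses.ContrastiveLoss(reducer=reducers.SumReducer())
embeddings = torch.randn(32, 128)     # batch of 32 embeddings (illustrative)
labels = torch.randint(0, 4, (32,))   # 4 arbitrary classes (illustrative)
loss = loss_fn(embeddings, labels)    # scalar: sum (not mean) of per-element losses
```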
setup.py

Lines changed: 0 additions & 1 deletion

```diff
@@ -42,7 +42,6 @@
         "scikit-learn",
         "tqdm",
         "torch >= 1.6.0",
-        "torchvision",
     ],
     extras_require={
         "with-hooks": extras_require_with_hooks,
```
src/pytorch_metric_learning/__init__.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -1 +1 @@
-__version__ = "1.7.0"
+__version__ = "1.7.1"
```

src/pytorch_metric_learning/reducers/__init__.py

Lines changed: 1 addition & 0 deletions

```diff
@@ -6,4 +6,5 @@
 from .mean_reducer import MeanReducer
 from .multiple_reducers import MultipleReducers
 from .per_anchor_reducer import PerAnchorReducer
+from .sum_reducer import SumReducer
 from .threshold_reducer import ThresholdReducer
```
src/pytorch_metric_learning/reducers/sum_reducer.py

Lines changed: 8 additions & 0 deletions

```diff
@@ -0,0 +1,8 @@
+import torch
+
+from pytorch_metric_learning.reducers import MeanReducer
+
+
+class SumReducer(MeanReducer):
+    def element_reduction(self, losses, *_):
+        return torch.sum(losses)
```
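Because `SumReducer` only overrides `element_reduction`, it inherits the rest of the reducer interface from `MeanReducer`. For illustration, a hedged sketch of calling it directly with the library's nested loss-dict format (the values below are made up):

```python
# Illustrative direct call, following the loss_dict format the
# library's reducers consume: {name: {"losses", "indices", "reduction_type"}}.
import torch
from pytorch_metric_learning.reducers import SumReducer

reducer = SumReducer()
losses = torch.tensor([0.1, 0.9, 0.4])
loss_dict = {
    "loss": {
        "losses": losses,
        "indices": torch.arange(3),
        "reduction_type": "element",
    }
}
embeddings = torch.randn(3, 8)
labels = torch.tensor([0, 1, 0])
total = reducer(loss_dict, embeddings, labels)  # tensor(1.4000) == losses.sum()
```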

src/pytorch_metric_learning/utils/distributed.py

Lines changed: 12 additions & 6 deletions

```diff
@@ -30,24 +30,27 @@ def all_gather_embeddings_and_labels(emb, labels)
     if not is_distributed():
         return None, None
     ref_emb = all_gather(emb)
-    ref_labels = all_gather(labels)
+    ref_labels = all_gather(labels) if labels is not None else None
     return ref_emb, ref_labels
 
 
 def gather(emb, labels):
     device = emb.device
-    labels = c_f.to_device(labels, device=device)
+    if labels is not None:
+        labels = c_f.to_device(labels, device=device)
     dist_emb, dist_labels = all_gather_embeddings_and_labels(emb, labels)
     all_emb = torch.cat([emb, dist_emb], dim=0)
-    all_labels = torch.cat([labels, dist_labels], dim=0)
+    all_labels = (
+        torch.cat([labels, dist_labels], dim=0) if dist_labels is not None else None
+    )
     return all_emb, all_labels, labels
 
 
 def gather_emb_and_ref(emb, labels, ref_emb=None, ref_labels=None):
     all_emb, all_labels, labels = gather(emb, labels)
     all_ref_emb, all_ref_labels = None, None
 
-    if ref_emb is not None and ref_labels is not None:
+    if ref_emb is not None:
         all_ref_emb, all_ref_labels, _ = gather(ref_emb, ref_labels)
 
     return all_emb, all_labels, all_ref_emb, all_ref_labels, labels
@@ -81,7 +84,9 @@ def __init__(self, loss, efficient=False):
         self.loss = loss
         self.efficient = efficient
 
-    def forward(self, emb, labels, indices_tuple=None, ref_emb=None, ref_labels=None):
+    def forward(
+        self, emb, labels=None, indices_tuple=None, ref_emb=None, ref_labels=None
+    ):
         world_size = torch.distributed.get_world_size()
         common_args = [emb, labels, indices_tuple, ref_emb, ref_labels, world_size]
         if isinstance(self.loss, CrossBatchMemory):
@@ -99,7 +104,8 @@ def forward_regular_loss(
         )
 
         if self.efficient:
-            all_labels = select_ref_or_regular(all_labels, all_ref_labels)
+            if all_labels is not None:
+                all_labels = select_ref_or_regular(all_labels, all_ref_labels)
             all_emb = select_ref_or_regular(all_emb, all_ref_emb)
             if indices_tuple is None:
                 indices_tuple = get_indices_tuple(labels, all_labels)
```
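The upshot of these changes is that `labels` becomes optional in `DistributedLossWrapper.forward` when an `indices_tuple` is supplied. A hedged sketch of the call pattern this enables (it assumes `torch.distributed` is already initialized; `model`, `data`, and `labels` are placeholders from a typical training loop, not part of this commit):

```python
# Sketch of the call pattern enabled by making labels optional.
from pytorch_metric_learning import losses, miners
from pytorch_metric_learning.utils import distributed as pml_dist

loss_fn = pml_dist.DistributedLossWrapper(losses.ContrastiveLoss())
miner = miners.PairMarginMiner(pos_margin=0.5, neg_margin=0.5)

# On each rank, inside the training loop:
#   embeddings = model(data)
#   indices_tuple = miner(embeddings, labels)
#   loss = loss_fn(embeddings, indices_tuple=indices_tuple)  # labels may now be omitted
```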

tests/reducers/test_mean_reducer.py

Lines changed: 6 additions & 1 deletion

```diff
@@ -2,14 +2,15 @@
 
 import torch
 
-from pytorch_metric_learning.reducers import MeanReducer
+from pytorch_metric_learning.reducers import MeanReducer, SumReducer
 
 from .. import TEST_DEVICE, TEST_DTYPES
 
 
 class TestMeanReducer(unittest.TestCase):
     def test_mean_reducer(self):
         reducer = MeanReducer()
+        sum_reducer = SumReducer()
         batch_size = 100
         embedding_size = 64
         for dtype in TEST_DTYPES:
@@ -42,3 +43,7 @@ def test_mean_reducer(self):
             output = reducer(loss_dict, embeddings, labels)
             correct_output = torch.mean(losses)
             self.assertTrue(output == correct_output)
+
+            output = sum_reducer(loss_dict, embeddings, labels)
+            correct_output = torch.sum(losses)
+            self.assertTrue(output == correct_output)
```

tests/utils/test_distributed.py

Lines changed: 31 additions & 21 deletions

```diff
@@ -64,6 +64,7 @@ def single_process_function(
     miner_fn,
     original_model,
     efficient,
+    pass_labels_to_loss_fn,
 ):
     setup(rank, world_size)
     if TEST_DEVICE == torch.device("cpu"):
@@ -102,9 +103,12 @@ def single_process_function(
     indices_tuple = None
     if miner_fn:
         indices_tuple = miner_fn(outputs, curr_labels, ref_outputs, curr_ref_labels)
-    loss = loss_fn(
-        outputs, curr_labels, indices_tuple, ref_outputs, curr_ref_labels
-    )
+    if miner_fn and not pass_labels_to_loss_fn:
+        loss = loss_fn(outputs, indices_tuple=indices_tuple, ref_emb=ref_outputs)
+    else:
+        loss = loss_fn(
+            outputs, curr_labels, indices_tuple, ref_outputs, curr_ref_labels
+        )
 
     dist.barrier()
     loss.backward()
@@ -162,6 +166,7 @@ def loss_and_miner_tester(
         use_ref,
         loss_kwargs=None,
         miner_kwargs=None,
+        pass_labels_to_loss_fn=True,
     ):
         torch.manual_seed(75210)
         loss_kwargs = {} if loss_kwargs is None else loss_kwargs
@@ -294,6 +299,7 @@ def loss_and_miner_tester(
                 miner_fn,
                 original_model,
                 efficient,
+                pass_labels_to_loss_fn,
             ),
             nprocs=world_size,
             join=True,
@@ -309,31 +315,35 @@ def test_distributed_tuple_loss(self):
     def test_distributed_tuple_loss_and_miner(self):
         for xbm in [False, True]:
             for use_ref in [False, True]:
-                if xbm and use_ref:
-                    continue
-                self.loss_and_miner_tester(
-                    ContrastiveLoss,
-                    PairMarginMiner,
-                    False,
-                    xbm,
-                    use_ref,
-                    miner_kwargs={"pos_margin": 0.5, "neg_margin": 0.5},
-                )
+                for pass_labels_to_loss_fn in [False, True]:
+                    if xbm and use_ref or xbm and not pass_labels_to_loss_fn:
+                        continue
+                    self.loss_and_miner_tester(
+                        ContrastiveLoss,
+                        PairMarginMiner,
+                        False,
+                        xbm,
+                        use_ref,
+                        miner_kwargs={"pos_margin": 0.5, "neg_margin": 0.5},
+                        pass_labels_to_loss_fn=pass_labels_to_loss_fn,
+                    )
 
     def test_distributed_tuple_loss_efficient(self):
         for use_ref in [False, True]:
             self.loss_and_miner_tester(ContrastiveLoss, None, True, False, use_ref)
 
     def test_distributed_tuple_loss_and_miner_efficient(self):
         for use_ref in [False, True]:
-            self.loss_and_miner_tester(
-                ContrastiveLoss,
-                PairMarginMiner,
-                True,
-                False,
-                use_ref,
-                miner_kwargs={"pos_margin": 0.5, "neg_margin": 0.5},
-            )
+            for pass_labels_to_loss_fn in [False, True]:
+                self.loss_and_miner_tester(
+                    ContrastiveLoss,
+                    PairMarginMiner,
+                    True,
+                    False,
+                    use_ref,
+                    miner_kwargs={"pos_margin": 0.5, "neg_margin": 0.5},
+                    pass_labels_to_loss_fn=pass_labels_to_loss_fn,
+                )
 
     def test_single_proc(self):
         setup(0, 1)
```
