
Commit 368c686

Merge pull request #574 from KevinMusgrave/dev
v1.7.2
2 parents 1b6cdef + bdc5996 commit 368c686

File tree

5 files changed (+36, -6 lines)

Lines changed: 1 addition & 1 deletion

@@ -1 +1 @@
-__version__ = "1.7.1"
+__version__ = "1.7.2"

src/pytorch_metric_learning/distances/base_distance.py

Lines changed: 7 additions & 0 deletions

@@ -16,6 +16,7 @@ def __init__(
 
     def forward(self, query_emb, ref_emb=None):
         self.reset_stats()
+        self.check_shapes(query_emb, ref_emb)
         query_emb_normalized = self.maybe_normalize(query_emb)
         if ref_emb is None:
             ref_emb = query_emb
@@ -87,3 +88,9 @@ def set_stats(self, stats_dict):
         for k, v in stats_dict.items():
             self.add_to_recordable_attributes(name=k, is_stat=True)
             setattr(self, k, v)
+
+    def check_shapes(self, query_emb, ref_emb):
+        if query_emb.ndim != 2 or (ref_emb is not None and ref_emb.ndim != 2):
+            raise ValueError(
+                "embeddings must be a 2D tensor of shape (batch_size, embedding_size)"
+            )
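With this hunk, the 2D shape check runs inside BaseDistance.forward rather than in the shared helper. A minimal sketch of the resulting behavior, using LpDistance and illustrative shapes of our own choosing (the try/except and print calls are not part of the commit):

import torch

from pytorch_metric_learning.distances import LpDistance

dist = LpDistance()

# 2D embeddings of shape (batch_size, embedding_size) are accepted as before.
mat = dist(torch.randn(32, 128))
print(mat.shape)  # torch.Size([32, 32]): pairwise distances within the batch

# Anything that is not 2D is now rejected by BaseDistance.check_shapes.
try:
    dist(torch.randn(32, 3, 128))
except ValueError as err:
    print(err)  # embeddings must be a 2D tensor of shape (batch_size, embedding_size)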

src/pytorch_metric_learning/utils/common_functions.py

Lines changed: 0 additions & 4 deletions

@@ -412,10 +412,6 @@ def return_input(x):
 def check_shapes(embeddings, labels):
     if labels is not None and embeddings.shape[0] != labels.shape[0]:
         raise ValueError("Number of embeddings must equal number of labels")
-    if embeddings.ndim != 2:
-        raise ValueError(
-            "embeddings must be a 2D tensor of shape (batch_size, embedding_size)"
-        )
     if labels is not None and labels.ndim != 1:
         raise ValueError("labels must be a 1D tensor of shape (batch_size,)")
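Note that the ndim check removed here is not dropped: it now lives in BaseDistance.check_shapes (previous file), so it still runs whenever a distance's forward is called, and a custom distance can override it, as the new test below demonstrates.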

Lines changed: 27 additions & 0 deletions

@@ -0,0 +1,27 @@
+import unittest
+
+import torch
+
+from pytorch_metric_learning.distances import BaseDistance, LpDistance
+from pytorch_metric_learning.losses import TripletMarginLoss
+
+
+class CustomDistance(BaseDistance):
+    def compute_mat(self, query_emb, ref_emb):
+        return torch.randn(query_emb.shape[0], ref_emb.shape[0])
+
+    def check_shapes(self, query_emb, ref_emb):
+        pass
+
+
+class TestCustomCheckShape(unittest.TestCase):
+    def test_custom_embedding_ndim(self):
+        embeddings = torch.randn(32, 3, 128)
+        labels = torch.randint(0, 10, size=(32,))
+        loss_fn = TripletMarginLoss(distance=LpDistance())
+
+        with self.assertRaises(ValueError):
+            loss_fn(embeddings, labels)
+
+        loss_fn = TripletMarginLoss(distance=CustomDistance())
+        loss_fn(embeddings, labels)
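This new test exercises both sides of the change: the stock LpDistance rejects the 3D (32, 3, 128) embeddings with a ValueError, while CustomDistance, whose check_shapes override is a no-op, accepts them.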

tests/utils/test_common_functions.py

Lines changed: 1 addition & 1 deletion

@@ -69,7 +69,7 @@ def test_collect_stats_flag(self):
 
     def test_check_shapes(self):
         embeddings = torch.randn(32, 512, 3)
-        labels = torch.randn(32)
+        labels = torch.randint(0, 10, size=(32,))
         loss_fn = TripletMarginLoss()
 
         # embeddings is 3-dimensional
