Skip to content

Commit bdc5996

Browse files
committed
Fixed check_shape test
1 parent 6448387 commit bdc5996

File tree

4 files changed

+30
-26
lines changed

4 files changed

+30
-26
lines changed

src/pytorch_metric_learning/distances/base_distance.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ def __init__(
1616

1717
def forward(self, query_emb, ref_emb=None):
1818
self.reset_stats()
19-
self.check_embeddings_ndim(query_emb, ref_emb)
19+
self.check_shapes(query_emb, ref_emb)
2020
query_emb_normalized = self.maybe_normalize(query_emb)
2121
if ref_emb is None:
2222
ref_emb = query_emb
@@ -89,7 +89,7 @@ def set_stats(self, stats_dict):
8989
self.add_to_recordable_attributes(name=k, is_stat=True)
9090
setattr(self, k, v)
9191

92-
def check_embeddings_ndim(self, query_emb, ref_emb):
92+
def check_shapes(self, query_emb, ref_emb):
9393
if query_emb.ndim != 2 or (ref_emb is not None and ref_emb.ndim != 2):
9494
raise ValueError(
9595
"embeddings must be a 2D tensor of shape (batch_size, embedding_size)"
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
import unittest
2+
3+
import torch
4+
5+
from pytorch_metric_learning.distances import BaseDistance, LpDistance
6+
from pytorch_metric_learning.losses import TripletMarginLoss
7+
8+
9+
class CustomDistance(BaseDistance):
10+
def compute_mat(self, query_emb, ref_emb):
11+
return torch.randn(query_emb.shape[0], ref_emb.shape[0])
12+
13+
def check_shapes(self, query_emb, ref_emb):
14+
pass
15+
16+
17+
class TestCustomCheckShape(unittest.TestCase):
18+
def test_custom_embedding_ndim(self):
19+
embeddings = torch.randn(32, 3, 128)
20+
labels = torch.randint(0, 10, size=(32,))
21+
loss_fn = TripletMarginLoss(distance=LpDistance())
22+
23+
with self.assertRaises(ValueError):
24+
loss_fn(embeddings, labels)
25+
26+
loss_fn = TripletMarginLoss(distance=CustomDistance())
27+
loss_fn(embeddings, labels)

tests/distances/test_custom_embedding_ndim.py

Lines changed: 0 additions & 23 deletions
This file was deleted.

tests/utils/test_common_functions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ def test_collect_stats_flag(self):
6969

7070
def test_check_shapes(self):
7171
embeddings = torch.randn(32, 512, 3)
72-
labels = torch.randn(32)
72+
labels = torch.randint(0, 10, size=(32,))
7373
loss_fn = TripletMarginLoss()
7474

7575
# embeddings is 3-dimensional

0 commit comments

Comments
 (0)