Skip to content

Commit 6448387

Browse files
committed
Moved ndim check into BaseDistance
1 parent f1c1d9f commit 6448387

File tree

4 files changed

+31
-5
lines changed

4 files changed

+31
-5
lines changed
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = "1.7.1"
1+
__version__ = "1.7.2"

src/pytorch_metric_learning/distances/base_distance.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ def __init__(
1616

1717
def forward(self, query_emb, ref_emb=None):
1818
self.reset_stats()
19+
self.check_embeddings_ndim(query_emb, ref_emb)
1920
query_emb_normalized = self.maybe_normalize(query_emb)
2021
if ref_emb is None:
2122
ref_emb = query_emb
@@ -87,3 +88,9 @@ def set_stats(self, stats_dict):
8788
for k, v in stats_dict.items():
8889
self.add_to_recordable_attributes(name=k, is_stat=True)
8990
setattr(self, k, v)
91+
92+
def check_embeddings_ndim(self, query_emb, ref_emb):
    """Validate that embeddings are 2D of shape (batch_size, embedding_size).

    Args:
        query_emb: query embedding tensor; must be 2D.
        ref_emb: optional reference embedding tensor; must be 2D when given.

    Raises:
        ValueError: if either tensor is present and not 2-dimensional.
    """
    # De Morgan form of the check: everything supplied must be rank 2.
    shapes_ok = query_emb.ndim == 2 and (ref_emb is None or ref_emb.ndim == 2)
    if not shapes_ok:
        raise ValueError(
            "embeddings must be a 2D tensor of shape (batch_size, embedding_size)"
        )

src/pytorch_metric_learning/utils/common_functions.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -412,10 +412,6 @@ def return_input(x):
412412
def check_shapes(embeddings, labels):
    """Validate that labels, when provided, line up with the embeddings.

    Args:
        embeddings: tensor whose first dimension is the batch size.
        labels: optional 1D tensor of shape (batch_size,), or None.

    Raises:
        ValueError: if the label count differs from the embedding count,
            or if labels are not 1-dimensional.
    """
    # Nothing to validate when labels were not supplied.
    if labels is None:
        return
    if embeddings.shape[0] != labels.shape[0]:
        raise ValueError("Number of embeddings must equal number of labels")
    if labels.ndim != 1:
        raise ValueError("labels must be a 1D tensor of shape (batch_size,)")
421417

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
import unittest
2+
3+
import torch
4+
5+
from pytorch_metric_learning.distances import BaseDistance, LpDistance
6+
7+
8+
class CustomDistance(BaseDistance):
    """Minimal BaseDistance subclass used as a test fixture.

    It stubs out the distance-matrix computation and, crucially, overrides
    check_embeddings_ndim so the 2D-embeddings validation becomes a no-op.
    """

    def compute_mat(self, query_emb, ref_emb):
        """Stub: performs no computation and yields None."""
        return None

    def check_embeddings_ndim(self, query_emb, ref_emb):
        """Override that deliberately skips the ndim validation."""
        return None
14+
15+
16+
class TestCustomEmbeddingNdim(unittest.TestCase):
    """Tests the ndim validation performed by BaseDistance.forward.

    A stock distance (LpDistance) must reject embeddings that are not 2D,
    while a subclass that overrides check_embeddings_ndim can opt out.
    """

    def test_custom_embedding_ndim(self):
        # Deliberately 3D: the valid shape is (batch_size, embedding_size).
        embeddings = torch.randn(32, 3, 128)

        dist_fn = LpDistance()
        with self.assertRaises(ValueError):
            dist_fn(embeddings)

        # CustomDistance overrides check_embeddings_ndim to a no-op, so the
        # same 3D input must pass its check without raising. (Previously the
        # fixture class was defined but never exercised by any assertion.)
        self.assertIsNone(
            CustomDistance().check_embeddings_ndim(embeddings, None)
        )

0 commit comments

Comments
 (0)