Skip to content

Commit

Permalink
impl test for triplet loss
Browse files Browse the repository at this point in the history
  • Loading branch information
nnaakkaaii committed Jul 14, 2024
1 parent 37abb19 commit 6553a71
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 1 deletion.
23 changes: 23 additions & 0 deletions test/models/losses/test_triplet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from torch import randn

from hrdae.models.losses import create_loss, TripletLossOption


def test_triplet_loss():
triplet_loss = create_loss(TripletLossOption())
b, n, c, w, h = 8, 10, 8, 4, 4

input = randn(b, n, w, h)
target = randn(b, n, w, h)
latent = randn(b, 1, c, w, h)
positive = randn(b, n, c, w, h)
negative = randn(b, n, c, w, h)

loss = triplet_loss(
input,
target,
latent=[latent],
positive=[positive],
negative=[negative],
)
assert loss.size() == ()

0 comments on commit 6553a71

Please sign in to comment.