Skip to content

Commit

Permalink
impl contrastive loss
Browse files Browse the repository at this point in the history
  • Loading branch information
nnaakkaaii committed Jun 28, 2024
1 parent 35177d1 commit 63f3846
Show file tree
Hide file tree
Showing 5 changed files with 643 additions and 1 deletion.
6 changes: 6 additions & 0 deletions hrdae/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
PJC3dLossOption,
TemporalSimilarityLossOption,
WeightedMSELossOption,
ContrastiveLossOption,
)
from .models.networks import (
AutoEncoder2dNetworkOption,
Expand Down Expand Up @@ -139,6 +140,11 @@
name="tsim",
node=TemporalSimilarityLossOption,
)
cs.store(
group="config/experiment/model/loss",
name="contrastive",
node=ContrastiveLossOption,
)
cs.store(
group="config/experiment/model/network",
name="autoencoder2d",
Expand Down
8 changes: 8 additions & 0 deletions hrdae/models/losses/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from torch import Tensor, float32, nn, tensor

from .contrastive import ContrastiveLossOption, create_contrastive_loss
from .mstd import MStdLossOption, create_mstd_loss
from .option import LossOption
from .pjc import PJC2dLossOption, PJC3dLossOption, create_pjc2d_loss, create_pjc3d_loss
Expand Down Expand Up @@ -41,6 +42,13 @@ def create_loss(opt: LossOption) -> nn.Module:
return create_tsim_loss()
if isinstance(opt, MStdLossOption) and type(opt) is MStdLossOption:
return create_mstd_loss()
if (
isinstance(opt, ContrastiveLossOption)
and type(opt) is ContrastiveLossOption
):
return create_contrastive_loss(opt)
if isinstance(opt, MStdLossOption) and type(opt) is MStdLossOption:
return create_mstd_loss()
raise NotImplementedError(f"{opt.__class__.__name__} is not implemented")


Expand Down
42 changes: 42 additions & 0 deletions hrdae/models/losses/contrastive.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
from dataclasses import dataclass

import torch
from torch import Tensor, nn

from .option import LossOption


@dataclass
class ContrastiveLossOption(LossOption):
margin: float = 0.1


def create_contrastive_loss(opt: ContrastiveLossOption) -> nn.Module:
return MStdLoss(opt.margin)


class MStdLoss(nn.Module):
def __init__(self, margin: float = 0.1) -> None:
super().__init__()
self.margin = margin

@property
def required_kwargs(self) -> list[str]:
return ["latent"]

def forward(self, input: Tensor, target: Tensor, latent: list[Tensor]) -> Tensor:
feature = latent[0]
b, t = feature.size()[:2]
feature = feature.view(b * t, -1)
square_distances = torch.cdist(feature, feature, p=2)

labels = 1 - torch.eye(b*t).to(input.device)
for i in range(b):
labels[i*t:(i+1)*t, i*t:(i+1)*t] = 0

positive_loss = (1 - labels) * 0.5 * torch.pow(square_distances, 2)
negative_loss = labels * 0.5 * torch.pow(torch.clamp(self.margin - square_distances, min=0.0), 2)

loss = torch.sum(positive_loss + negative_loss) / (b * t * (b * t - 1))

return loss
Loading

0 comments on commit 63f3846

Please sign in to comment.