From 63f3846d9dae6e05f6db19d9dc9b74ac4bbc982c Mon Sep 17 00:00:00 2001
From: nnaakkaaii
Date: Fri, 28 Jun 2024 12:08:11 +0900
Subject: [PATCH] impl contrastive loss

---
 hrdae/__init__.py                  |   6 +
 hrdae/conf                         |   2 +-
 hrdae/models/losses/__init__.py    |   6 +
 hrdae/models/losses/contrastive.py |  45 +++
 notebook/mmnist.ipynb              | 586 +++++++++++++++++++++++++++++
 5 files changed, 644 insertions(+), 1 deletion(-)
 create mode 100644 hrdae/models/losses/contrastive.py
 create mode 100644 notebook/mmnist.ipynb

diff --git a/hrdae/__init__.py b/hrdae/__init__.py
index 27a7368..457613e 100644
--- a/hrdae/__init__.py
+++ b/hrdae/__init__.py
@@ -25,6 +25,7 @@
     PJC3dLossOption,
     TemporalSimilarityLossOption,
     WeightedMSELossOption,
+    ContrastiveLossOption,
 )
 from .models.networks import (
     AutoEncoder2dNetworkOption,
@@ -139,6 +140,11 @@
     name="tsim",
     node=TemporalSimilarityLossOption,
 )
+cs.store(
+    group="config/experiment/model/loss",
+    name="contrastive",
+    node=ContrastiveLossOption,
+)
 cs.store(
     group="config/experiment/model/network",
     name="autoencoder2d",
diff --git a/hrdae/conf b/hrdae/conf
index 824d92a..176503d 160000
--- a/hrdae/conf
+++ b/hrdae/conf
@@ -1 +1 @@
-Subproject commit 824d92a1bc5bc76b92673c28c273c8efcb10403e
+Subproject commit 176503d22324aecb00cffd7c3b24cd0b3f3beb6b
diff --git a/hrdae/models/losses/__init__.py b/hrdae/models/losses/__init__.py
index 472d4cd..9d942c7 100644
--- a/hrdae/models/losses/__init__.py
+++ b/hrdae/models/losses/__init__.py
@@ -3,6 +3,7 @@
 
 from torch import Tensor, float32, nn, tensor
 
+from .contrastive import ContrastiveLossOption, create_contrastive_loss
 from .mstd import MStdLossOption, create_mstd_loss
 from .option import LossOption
 from .pjc import PJC2dLossOption, PJC3dLossOption, create_pjc2d_loss, create_pjc3d_loss
@@ -41,6 +42,11 @@ def create_loss(opt: LossOption) -> nn.Module:
         return create_tsim_loss()
     if isinstance(opt, MStdLossOption) and type(opt) is MStdLossOption:
         return create_mstd_loss()
+    if (
+        isinstance(opt, ContrastiveLossOption)
+        and type(opt) is ContrastiveLossOption
+    ):
+        return create_contrastive_loss(opt)
     raise NotImplementedError(f"{opt.__class__.__name__} is not implemented")
 
 
diff --git a/hrdae/models/losses/contrastive.py b/hrdae/models/losses/contrastive.py
new file mode 100644
index 0000000..3a7afba
--- /dev/null
+++ b/hrdae/models/losses/contrastive.py
@@ -0,0 +1,45 @@
+from dataclasses import dataclass
+
+import torch
+from torch import Tensor, nn
+
+from .option import LossOption
+
+
+@dataclass
+class ContrastiveLossOption(LossOption):
+    margin: float = 0.1
+
+
+def create_contrastive_loss(opt: ContrastiveLossOption) -> nn.Module:
+    return ContrastiveLoss(opt.margin)
+
+
+class ContrastiveLoss(nn.Module):
+    def __init__(self, margin: float = 0.1) -> None:
+        super().__init__()
+        self.margin = margin
+
+    @property
+    def required_kwargs(self) -> list[str]:
+        return ["latent"]
+
+    def forward(self, input: Tensor, target: Tensor, latent: list[Tensor]) -> Tensor:
+        feature = latent[0]
+        b, t = feature.size()[:2]
+        feature = feature.view(b * t, -1)
+        # pairwise euclidean distances between all b * t latent vectors
+        distances = torch.cdist(feature, feature, p=2)
+
+        # labels: 1 for pairs from different samples, 0 for same-sample pairs
+        labels = 1 - torch.eye(b * t).to(input.device)
+        for i in range(b):
+            labels[i * t : (i + 1) * t, i * t : (i + 1) * t] = 0
+
+        # pull same-sample latents together; push others apart up to the margin
+        positive_loss = (1 - labels) * 0.5 * torch.pow(distances, 2)
+        negative_loss = labels * 0.5 * torch.pow(torch.clamp(self.margin - distances, min=0.0), 2)
+
+        loss = torch.sum(positive_loss + negative_loss) / (b * t * (b * t - 1))
+
+        return loss
diff --git a/notebook/mmnist.ipynb b/notebook/mmnist.ipynb
new file mode 100644
index 0000000..1abcff3
--- /dev/null
+++ b/notebook/mmnist.ipynb
@@ -0,0 +1,586 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import sys\n",
+    "sys.path.append(\"..\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import torch\n",
+    "from torch.utils.data import DataLoader\n",
+    "from torchvision import transforms\n",
+    "from hrdae.models.networks import create_network, RDAE2dOption\n",
+    "from hrdae.models.networks.motion_encoder import MotionRNNEncoder1dOption\n",
+    "from hrdae.models.networks.rnn import TCN1dOption\n",
+    "from hrdae.dataloaders.datasets import create_dataset, MovingMNISTDatasetOption\n",
+    "from hrdae.dataloaders.transforms import create_transform, MinMaxNormalizationOption\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 13,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "net = create_network(\n",
+    "    1,\n",
+    "    opt=RDAE2dOption(\n",
+    "        activation=\"sigmoid\",\n",
+    "        aggregator=\"addition\",\n",
+    "        cycle=False,\n",
+    "        in_channels=1,\n",
+    "        hidden_channels=64,\n",
+    "        latent_dim=8,\n",
+    "        conv_params=[{\"kernel_size\": [3], \"stride\": [2], \"padding\": [1]}] * 3,\n",
+    "        motion_encoder=MotionRNNEncoder1dOption(\n",
+    "            in_channels=5,\n",
+    "            hidden_channels=64,\n",
+    "            conv_params=[{\"kernel_size\": [3], \"stride\": [2], \"padding\": [1]}] * 3,\n",
+    "            deconv_params=[{\"kernel_size\": [3], \"stride\": [1, 2], \"padding\": [1]}] * 3,\n",
+    "            rnn=TCN1dOption(\n",
+    "                num_layers=3,\n",
+    "                image_size=8,\n",
+    "                kernel_size=4,\n",
+    "                dropout=0.1,\n",
+    "            )\n",
+    "        )\n",
+    "    )\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 14,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "<All keys matched successfully>"
+      ]
+     },
+     "execution_count": 14,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "net.load_state_dict(torch.load(\"../results/BasicDataLoaderOption/PVRModelOption/rdae2d/2024-06-27_21-27-00/weights/best_model.pth\"))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 15,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "RDAE2d(\n",
+       "  (content_encoder): Encoder2d(\n",
+       "    (cnn): ConvModule2d(\n",
+       "      (layers): ModuleList(\n",
+       "        (0): Sequential(\n",
+       "          (0): ConvBlock2d(\n",
+       "            (conv): Conv2d(1, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))\n",
+       "          )\n",
+       "          (1): IdenticalConvBlock2d(\n",
+       "            (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+       "          )\n",
+       "        )\n",
+       "        (1-2): 2 x Sequential(\n",
+       "          (0): ConvBlock2d(\n",
+       "            (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))\n",
+       "          )\n",
+       "          (1): IdenticalConvBlock2d(\n",
+       "            (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+       "          )\n",
+       "        )\n",
+       "        (3): ConvBlock2d(\n",
+       "          (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+       "        )\n",
+       "      )\n",
+       "    )\n",
+       "    (bottleneck): PixelWiseConv2d(\n",
+       "      (conv): Conv2d(64, 8, kernel_size=(1, 1), stride=(1, 1))\n",
+       "    )\n",
+       "  )\n",
+       "  (motion_encoder): MotionRNNEncoder1d(\n",
+       "    (cnn): ConvModule1d(\n",
+       "      (layers): ModuleList(\n",
+       "        (0): Sequential(\n",
+       "          (0): ConvBlock1d(\n",
+       "            (conv): Conv1d(5, 64, kernel_size=(3,), stride=(2,), padding=(1,))\n",
+       "          )\n",
+       "          (1): IdenticalConvBlock1d(\n",
+       "            (conv): Conv1d(64, 64, kernel_size=(3,), stride=(1,), padding=(1,))\n",
+       "          )\n",
+       "        )\n",
+       "        (1): Sequential(\n",
+       "          (0): ConvBlock1d(\n",
+       "            (conv): Conv1d(64, 64, kernel_size=(3,), stride=(2,), padding=(1,))\n",
+       "          )\n",
+       "          (1): IdenticalConvBlock1d(\n",
+       "            (conv): Conv1d(64, 64, kernel_size=(3,), stride=(1,), padding=(1,))\n",
+       "          )\n",
+       "        )\n",
+       "        (2): ConvBlock1d(\n",
+       "          (conv): Conv1d(64, 64, kernel_size=(3,), stride=(2,), padding=(1,))\n",
+       "        )\n",
+       "      )\n",
+       "    )\n",
+       "    (rnn): TCN1d(\n",
+       "      (rnn): TCN1d(\n",
+       "        (tcn): TCN(\n",
+       "          (network): ModuleList(\n",
+       "            (0): TemporalBlock(\n",
+       "              (conv1): ParametrizedCausalConv1d(\n",
+       "                512, 512, kernel_size=(4,), stride=(1,)\n",
+       "                (parametrizations): ModuleDict(\n",
+       "                  (weight): ParametrizationList(\n",
+       "                    (0): _WeightNorm()\n",
+       "                  )\n",
+       "                )\n",
+       "              )\n",
+       "              (conv2): ParametrizedCausalConv1d(\n",
+       "                512, 512, kernel_size=(4,), stride=(1,)\n",
+       "                (parametrizations): ModuleDict(\n",
+       "                  (weight): ParametrizationList(\n",
+       "                    (0): _WeightNorm()\n",
+       "                  )\n",
+       "                )\n",
+       "              )\n",
+       "              (activation1): ReLU()\n",
+       "              (activation2): ReLU()\n",
+       "              (activation_final): ReLU()\n",
+       "              (dropout1): Dropout(p=0.1, inplace=False)\n",
+       "              (dropout2): Dropout(p=0.1, inplace=False)\n",
+       "            )\n",
+       "            (1): TemporalBlock(\n",
+       "              (conv1): ParametrizedCausalConv1d(\n",
+       "                512, 512, kernel_size=(4,), stride=(1,), dilation=(2,)\n",
+       "                (parametrizations): ModuleDict(\n",
+       "                  (weight): ParametrizationList(\n",
+       "                    (0): _WeightNorm()\n",
+       "                  )\n",
+       "                )\n",
+       "              )\n",
+       "              (conv2): ParametrizedCausalConv1d(\n",
+       "                512, 512, kernel_size=(4,), stride=(1,), dilation=(2,)\n",
+       "                (parametrizations): ModuleDict(\n",
+       "                  (weight): ParametrizationList(\n",
+       "                    (0): _WeightNorm()\n",
+       "                  )\n",
+       "                )\n",
+       "              )\n",
+       "              (activation1): ReLU()\n",
+       "              (activation2): ReLU()\n",
+       "              (activation_final): ReLU()\n",
+       "              (dropout1): Dropout(p=0.1, inplace=False)\n",
+       "              (dropout2): Dropout(p=0.1, inplace=False)\n",
+       "            )\n",
+       "            (2): TemporalBlock(\n",
+       "              (conv1): ParametrizedCausalConv1d(\n",
+       "                512, 512, kernel_size=(4,), stride=(1,), dilation=(4,)\n",
+       "                (parametrizations): ModuleDict(\n",
+       "                  (weight): ParametrizationList(\n",
+       "                    (0): _WeightNorm()\n",
+       "                  )\n",
+       "                )\n",
+       "              )\n",
+       "              (conv2): ParametrizedCausalConv1d(\n",
+       "                512, 512, kernel_size=(4,), stride=(1,), dilation=(4,)\n",
+       "                (parametrizations): ModuleDict(\n",
+       "                  (weight): ParametrizationList(\n",
+       "                    (0): _WeightNorm()\n",
+       "                  )\n",
+       "                )\n",
+       "              )\n",
+       "              (activation1): ReLU()\n",
+       "              (activation2): ReLU()\n",
+       "              (activation_final): ReLU()\n",
+       "              (dropout1): Dropout(p=0.1, inplace=False)\n",
+       "              (dropout2): Dropout(p=0.1, inplace=False)\n",
+       "            )\n",
+       "          )\n",
+       "        )\n",
+       "      )\n",
+       "    )\n",
+       "    (tcnn): ConvModule2d(\n",
+       "      (layers): ModuleList(\n",
+       "        (0-1): 2 x Sequential(\n",
+       "          (0): ConvBlock2d(\n",
+       "            (conv): ConvTranspose2d(64, 64, kernel_size=(3, 3), stride=(1, 2), padding=(1, 1), output_padding=(0, 1))\n",
+       "          )\n",
+       "          (1): IdenticalConvBlock2d(\n",
+       "            (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+       "          )\n",
+       "        )\n",
+       "        (2): ConvBlock2d(\n",
+       "          (conv): ConvTranspose2d(64, 64, kernel_size=(3, 3), stride=(1, 2), padding=(1, 1), output_padding=(0, 1))\n",
+       "        )\n",
+       "      )\n",
+       "    )\n",
+       "    (bottleneck): PixelWiseConv2d(\n",
+       "      (conv): Conv2d(64, 8, kernel_size=(1, 1), stride=(1, 1))\n",
+       "    )\n",
+       "  )\n",
+       "  (decoder): Decoder2d(\n",
+       "    (bottleneck): PixelWiseConv2d(\n",
+       "      (conv): Conv2d(8, 64, kernel_size=(1, 1), stride=(1, 1))\n",
+       "    )\n",
+       "    (cnn): ConvModule2d(\n",
" (layers): ModuleList(\n", + " (0-1): 2 x Sequential(\n", + " (0): ConvBlock2d(\n", + " (conv): ConvTranspose2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), output_padding=(1, 1))\n", + " )\n", + " (1): IdenticalConvBlock2d(\n", + " (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", + " )\n", + " )\n", + " (2): ConvBlock2d(\n", + " (conv): ConvTranspose2d(64, 1, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), output_padding=(1, 1))\n", + " )\n", + " )\n", + " )\n", + " )\n", + " (activation): Sigmoid()\n", + " (aggregator): AdditionAggregator2d()\n", + ")" + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "net.eval()" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [], + "source": [ + "dataset = create_dataset(\n", + " opt=MovingMNISTDatasetOption(\n", + " sequential=True,\n", + " root=\"../data\",\n", + " slice_index=[16, 24, 32, 40, 48],\n", + " content_phase=\"0\",\n", + " motion_phase=\"0\",\n", + " motion_aggregation=\"none\",\n", + " ),\n", + " transform=transforms.Compose([\n", + " create_transform(MinMaxNormalizationOption()),\n", + " ]),\n", + " is_train=False,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [], + "source": [ + "loader = DataLoader(dataset=dataset, batch_size=10, shuffle=False)" + ] + }, + { + "cell_type": "code", + "execution_count": 52, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0 0 tensor(1.4039e-10, grad_fn=)\n", + "0 1 tensor(2.2091e-10, grad_fn=)\n", + "0 2 tensor(1.8385e-10, grad_fn=)\n", + "0 3 tensor(9.5200e-11, grad_fn=)\n", + "0 4 tensor(3.6294e-10, grad_fn=)\n", + "0 5 tensor(3.1136e-10, grad_fn=)\n", + "0 6 tensor(1.1639e-10, grad_fn=)\n", + "0 7 tensor(7.3829e-11, grad_fn=)\n", + "0 8 tensor(4.8012e-11, grad_fn=)\n", + "0 9 tensor(5.2832e-11, grad_fn=)\n", + "1 0 tensor(2.2091e-10, grad_fn=)\n", + "1 1 tensor(3.0423e-10, grad_fn=)\n", + "1 2 tensor(7.5801e-11, grad_fn=)\n", + "1 3 tensor(3.8858e-10, grad_fn=)\n", + "1 4 tensor(6.2034e-10, grad_fn=)\n", + "1 5 tensor(2.4635e-10, grad_fn=)\n", + "1 6 tensor(1.1119e-10, grad_fn=)\n", + "1 7 tensor(2.3004e-10, grad_fn=)\n", + "1 8 tensor(2.6573e-10, grad_fn=)\n", + "1 9 tensor(3.0289e-10, grad_fn=)\n", + "2 0 tensor(1.8385e-10, grad_fn=)\n", + "2 1 tensor(7.5801e-11, grad_fn=)\n", + "2 2 tensor(3.6883e-10, grad_fn=)\n", + "2 3 tensor(3.2594e-10, grad_fn=)\n", + "2 4 tensor(5.0680e-10, grad_fn=)\n", + "2 5 tensor(1.4440e-10, grad_fn=)\n", + "2 6 tensor(1.3460e-10, grad_fn=)\n", + "2 7 tensor(2.1659e-10, grad_fn=)\n", + "2 8 tensor(1.9823e-10, grad_fn=)\n", + "2 9 tensor(2.4722e-10, grad_fn=)\n", + "3 0 tensor(9.5200e-11, grad_fn=)\n", + "3 1 tensor(3.8858e-10, grad_fn=)\n", + "3 2 tensor(3.2594e-10, grad_fn=)\n", + "3 3 tensor(2.7341e-10, grad_fn=)\n", + "3 4 tensor(2.1864e-10, grad_fn=)\n", + "3 5 tensor(3.4957e-10, grad_fn=)\n", + "3 6 tensor(1.7128e-10, grad_fn=)\n", + "3 7 tensor(8.5589e-11, grad_fn=)\n", + "3 8 tensor(1.2116e-10, grad_fn=)\n", + "3 9 tensor(7.4654e-11, grad_fn=)\n", + "4 0 tensor(3.6294e-10, grad_fn=)\n", + "4 1 tensor(6.2034e-10, grad_fn=)\n", + "4 2 tensor(5.0680e-10, grad_fn=)\n", + "4 3 tensor(2.1864e-10, grad_fn=)\n", + "4 4 tensor(5.7858e-10, grad_fn=)\n", + "4 5 tensor(3.6352e-10, grad_fn=)\n", + "4 6 tensor(3.7786e-10, grad_fn=)\n", + "4 7 tensor(2.8377e-10, grad_fn=)\n", + "4 8 tensor(3.4256e-10, grad_fn=)\n", + "4 9 
+      "5 0 tensor(3.1136e-10, grad_fn=<MeanBackward0>)\n",
+      "5 1 tensor(2.4635e-10, grad_fn=<MeanBackward0>)\n",
+      "5 2 tensor(1.4440e-10, grad_fn=<MeanBackward0>)\n",
+      "5 3 tensor(3.4957e-10, grad_fn=<MeanBackward0>)\n",
+      "5 4 tensor(3.6352e-10, grad_fn=<MeanBackward0>)\n",
+      "5 5 tensor(4.9348e-10, grad_fn=<DivBackward0>)\n",
+      "5 6 tensor(2.1875e-10, grad_fn=<MeanBackward0>)\n",
+      "5 7 tensor(2.6515e-10, grad_fn=<MeanBackward0>)\n",
+      "5 8 tensor(2.7397e-10, grad_fn=<MeanBackward0>)\n",
+      "5 9 tensor(3.0599e-10, grad_fn=<MeanBackward0>)\n",
+      "6 0 tensor(1.1639e-10, grad_fn=<MeanBackward0>)\n",
+      "6 1 tensor(1.1119e-10, grad_fn=<MeanBackward0>)\n",
+      "6 2 tensor(1.3460e-10, grad_fn=<MeanBackward0>)\n",
+      "6 3 tensor(1.7128e-10, grad_fn=<MeanBackward0>)\n",
+      "6 4 tensor(3.7786e-10, grad_fn=<MeanBackward0>)\n",
+      "6 5 tensor(2.1875e-10, grad_fn=<MeanBackward0>)\n",
+      "6 6 tensor(2.5243e-10, grad_fn=<DivBackward0>)\n",
+      "6 7 tensor(8.0260e-11, grad_fn=<MeanBackward0>)\n",
+      "6 8 tensor(1.2736e-10, grad_fn=<MeanBackward0>)\n",
+      "6 9 tensor(1.3085e-10, grad_fn=<MeanBackward0>)\n",
+      "7 0 tensor(7.3829e-11, grad_fn=<MeanBackward0>)\n",
+      "7 1 tensor(2.3004e-10, grad_fn=<MeanBackward0>)\n",
+      "7 2 tensor(2.1659e-10, grad_fn=<MeanBackward0>)\n",
+      "7 3 tensor(8.5589e-11, grad_fn=<MeanBackward0>)\n",
+      "7 4 tensor(2.8377e-10, grad_fn=<MeanBackward0>)\n",
+      "7 5 tensor(2.6515e-10, grad_fn=<MeanBackward0>)\n",
+      "7 6 tensor(8.0260e-11, grad_fn=<MeanBackward0>)\n",
+      "7 7 tensor(2.5253e-10, grad_fn=<DivBackward0>)\n",
+      "7 8 tensor(8.5414e-11, grad_fn=<MeanBackward0>)\n",
+      "7 9 tensor(5.3537e-11, grad_fn=<MeanBackward0>)\n",
+      "8 0 tensor(4.8012e-11, grad_fn=<MeanBackward0>)\n",
+      "8 1 tensor(2.6573e-10, grad_fn=<MeanBackward0>)\n",
+      "8 2 tensor(1.9823e-10, grad_fn=<MeanBackward0>)\n",
+      "8 3 tensor(1.2116e-10, grad_fn=<MeanBackward0>)\n",
+      "8 4 tensor(3.4256e-10, grad_fn=<MeanBackward0>)\n",
+      "8 5 tensor(2.7397e-10, grad_fn=<MeanBackward0>)\n",
+      "8 6 tensor(1.2736e-10, grad_fn=<MeanBackward0>)\n",
+      "8 7 tensor(8.5414e-11, grad_fn=<MeanBackward0>)\n",
+      "8 8 tensor(2.5835e-10, grad_fn=<DivBackward0>)\n",
+      "8 9 tensor(4.4578e-11, grad_fn=<MeanBackward0>)\n",
+      "9 0 tensor(5.2832e-11, grad_fn=<MeanBackward0>)\n",
+      "9 1 tensor(3.0289e-10, grad_fn=<MeanBackward0>)\n",
+      "9 2 tensor(2.4722e-10, grad_fn=<MeanBackward0>)\n",
+      "9 3 tensor(7.4654e-11, grad_fn=<MeanBackward0>)\n",
+      "9 4 tensor(2.5163e-10, grad_fn=<MeanBackward0>)\n",
+      "9 5 tensor(3.0599e-10, grad_fn=<MeanBackward0>)\n",
+      "9 6 tensor(1.3085e-10, grad_fn=<MeanBackward0>)\n",
+      "9 7 tensor(5.3537e-11, grad_fn=<MeanBackward0>)\n",
+      "9 8 tensor(4.4578e-11, grad_fn=<MeanBackward0>)\n",
+      "9 9 tensor(2.1350e-10, grad_fn=<DivBackward0>)\n"
+     ]
+    }
+   ],
+   "source": [
+    "for data in loader:\n",
+    "    xm = data[\"xm\"]\n",
+    "    xp = data[\"xp\"]\n",
+    "    ys, latents = [], []\n",
+    "    for i in range(10):\n",
+    "        y, latent = net(xm, xp[:, i], xm[:, i])\n",
+    "        ys.append(y)\n",
+    "        latents.append(latent[0])\n",
+    "    for i in range(10):\n",
+    "        for j in range(10):\n",
+    "            if i != j:\n",
+    "                print(i, j, ((latents[i][0]-latents[j][0])**2).mean())\n",
+    "            else:\n",
+    "                mse = 0.\n",
+    "                for k in range(10):\n",
+    "                    mse += ((latents[i][0] - latents[i][k])**2).mean()\n",
+    "                print(i, i, mse / 10)\n",
+    "    # for i in range(len(y)):\n",
+    "    #     for j in range(len(y)):\n",
+    "    #         if i == j:\n",
+    "    #             l1 = 0.\n",
+    "    #             for _xp_0 in xp[i]:\n",
+    "    #                 l2 = net.content_encoder(_xp_0.unsqueeze(0)).squeeze()\n",
+    "    #                 l1 += float(((latent[0][i]-l2)**2).mean())\n",
+    "    #                 print(float(((latent[0][i]-l2)**2).mean()))\n",
+    "    #             print(f\"i={i} / j={j}\", l1 / len(xp[i]))\n",
+    "    #             continue\n",
+    "    #         print(f\"i={i} / j={j}\", float(((latent[0][i]-latent[0][j])**2).mean()))\n",
+    "    #         print(\"mse:\", float(((y[i]-xp[i])**2).mean()))\n",
+    "\n",
+    "    break"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 21,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "tensor(0.)"
+      ]
+     },
+     "execution_count": 21,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "((xp_0 - xp[:, 0])**2).mean()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 53,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import torch\n",
+    "import torch.nn.functional as F"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 77,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def contrastive_loss(features, tau=0.1):\n",
+    "    \"\"\"\n",
+    "    Compute contrastive loss for features in shape (t, n, c)\n",
+    "    :param features: Tensor of shape (t=num_classes, n=samples_per_class, c=feature_dim)\n",
+    "    :param tau: Margin separating different-class pairs\n",
+    "    :return: Contrastive loss\n",
+    "    \"\"\"\n",
+    "    t, n, _ = features.size()\n",
+    "    features = features.reshape(t * n, -1)\n",
+    "\n",
+    "    dw = torch.cdist(features, features, p=2).flatten()  # pairwise distances\n",
+    "\n",
+    "    label = torch.block_diag(*[torch.ones(n, n)] * t).flatten()  # 1 if same class else 0\n",
+    "\n",
+    "    square_distance = torch.pow(dw, 2)\n",
+    "    margin_distance = torch.pow(torch.clamp(tau - dw, min=0.0), 2)\n",
+    "\n",
+    "    positive = label * square_distance  # pull same-class pairs together\n",
+    "    negative = (1 - label) * margin_distance  # push different-class pairs apart\n",
+    "\n",
+    "    loss = (positive + negative) / 2\n",
+    "\n",
+    "    return loss.mean()\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 78,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Contrastive Loss: 6.538628101348877\n"
+     ]
+    }
+   ],
+   "source": [
+    "# Example tensor with random data\n",
+    "num_classes = 4\n",
+    "samples_per_class = 3\n",
+    "feature_dim = 10\n",
+    "\n",
+    "# Randomly generate features\n",
+    "features = torch.randn(num_classes, samples_per_class, feature_dim)\n",
+    "\n",
+    "# Compute loss\n",
+    "loss = contrastive_loss(features)\n",
+    "print(\"Contrastive Loss:\", loss.item())"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 79,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "tensor(127.5766)"
+      ]
+     },
+     "execution_count": 79,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "contrastive_loss(torch.randn(4, 1, 1).repeat(1, 3, 10))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": ".venv",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.12.3"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
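
Note on usage: below is a minimal sketch of how the ContrastiveLoss module added in
hrdae/models/losses/contrastive.py behaves on a dummy latent. The (b, t, c, h, w)
latent layout and the zero-filled input/target tensors are illustrative assumptions
based on the b, t = feature.size()[:2] unpacking in the loss, not part of the patch.

    import torch

    from hrdae.models.losses.contrastive import ContrastiveLoss

    # Assumed latent layout: (batch, time, channels, height, width); the loss
    # flattens it to (b * t, -1) and contrasts frames across samples.
    b, t, c, h, w = 2, 5, 8, 8, 8
    latent = torch.randn(b, t, c, h, w, requires_grad=True)
    dummy = torch.zeros(b, t)  # input/target only supply the device here

    loss_fn = ContrastiveLoss(margin=0.1)
    loss = loss_fn(dummy, dummy, latent=[latent])
    loss.backward()

    # Frames of the same sample form positive pairs (pulled together); frames
    # of different samples are negatives (pushed apart up to the margin).
    print(float(loss))

Because required_kwargs returns ["latent"], the training loop is expected to pass
the encoder's latent list through to the loss alongside input and target.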
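
For the notebook's contrastive_loss, the label matrix is the piece most worth
sanity-checking; a small standalone check of the block-diagonal construction used
in the fixed cell (torch.block_diag is the stock PyTorch helper; the shapes here
are arbitrary illustrative values):

    import torch

    t, n = 2, 3  # two classes, three samples each
    # 1 for same-class pairs, 0 for different-class pairs
    label = torch.block_diag(*[torch.ones(n, n)] * t)
    assert label.shape == (t * n, t * n)
    assert label[0, 1] == 1.0  # samples 0 and 1 share class 0
    assert label[0, n] == 0.0  # samples 0 and n are in different classes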