Skip to content

Commit

Permalink
impl triplet
Browse files Browse the repository at this point in the history
  • Loading branch information
nnaakkaaii committed Jul 14, 2024
1 parent 472a4e6 commit 697d609
Show file tree
Hide file tree
Showing 9 changed files with 91 additions and 12 deletions.
2 changes: 2 additions & 0 deletions hrdae/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
PJC2dLossOption,
PJC3dLossOption,
TemporalSimilarityLossOption,
TripletLossOption,
WeightedMSELossOption,
)
from .models.networks import (
Expand Down Expand Up @@ -154,6 +155,7 @@
group="config/experiment/model/loss", name="tsim", node=TemporalSimilarityLossOption
)
cs.store(group="config/experiment/model/loss", name="mstd", node=MStdLossOption)
cs.store(group="config/experiment/model/loss", name="triplet", node=TripletLossOption)
cs.store(
group="config/experiment/model/loss_g",
name="tsim",
Expand Down
7 changes: 1 addition & 6 deletions hrdae/dataloaders/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,7 @@

from .normalization import MinMaxNormalization, MinMaxNormalizationOption
from .option import TransformOption
from .pool import (
Pool2dOption,
Pool3dOption,
create_pool2d,
create_pool3d,
)
from .pool import Pool2dOption, Pool3dOption, create_pool2d, create_pool3d
from .random_shift import (
RandomShift2dOption,
RandomShift3dOption,
Expand Down
4 changes: 2 additions & 2 deletions hrdae/models/basic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def train(
y = y.reshape(b, n, *y.size()[1:])
z = z.reshape(b, n, *z.size()[1:])

loss = self.criterion(t, y, latent=z)
loss = self.criterion(y, t, latent=z)
loss.backward()
self.optimizer.step()

Expand Down Expand Up @@ -124,7 +124,7 @@ def train(
y = y.reshape(b, n, *y.size()[1:])
z = z.reshape(b, n, *z.size()[1:])

loss = self.criterion(t, y, latent=z)
loss = self.criterion(y, t, latent=z)
total_val_loss += loss.item()

avg_val_loss = total_val_loss / len(val_loader)
Expand Down
12 changes: 10 additions & 2 deletions hrdae/models/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch import nn
from torch import Tensor, nn


def _save_images(
Expand Down Expand Up @@ -103,5 +103,13 @@ def save_reconstructed_images(

def save_model(model: nn.Module, filepath: Path):
filepath.parent.mkdir(parents=True, exist_ok=True)
model_to_save = model.module if isinstance(model, torch.nn.DataParallel) else model
model_to_save = model.module if isinstance(model, nn.DataParallel) else model
torch.save(model_to_save.state_dict(), filepath)


def shuffled_indices(length: int) -> Tensor:
indices = torch.empty(length, dtype=torch.long)
for i in range(length):
choices = torch.cat([torch.arange(0, i), torch.arange(i + 1, length)])
indices[i] = choices[torch.randint(len(choices), (1,))]
return indices
3 changes: 3 additions & 0 deletions hrdae/models/losses/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from .perceptual import Perceptual2dLossOption, create_perceptual2d_loss
from .pjc import PJC2dLossOption, PJC3dLossOption, create_pjc2d_loss, create_pjc3d_loss
from .t_sim import TemporalSimilarityLossOption, create_tsim_loss
from .triplet import TripletLossOption, create_triplet_loss
from .weighted_mse import WeightedMSELossOption, create_weighted_mse_loss


Expand Down Expand Up @@ -49,6 +50,8 @@ def create_loss(opt: LossOption) -> nn.Module:
return create_mstd_loss()
if isinstance(opt, Perceptual2dLossOption) and type(opt) is Perceptual2dLossOption:
return create_perceptual2d_loss(opt)
if isinstance(opt, TripletLossOption) and type(opt) is TripletLossOption:
return create_triplet_loss(opt)
raise NotImplementedError(f"{opt.__class__.__name__} is not implemented")


Expand Down
46 changes: 46 additions & 0 deletions hrdae/models/losses/triplet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
from dataclasses import dataclass

import torch
import torch.nn.functional as F
from torch import Tensor, nn

from .option import LossOption


@dataclass
class TripletLossOption(LossOption):
margin: float = 0.1


def create_triplet_loss(opt: TripletLossOption) -> nn.Module:
return TripletLoss(opt.margin)


class TripletLoss(nn.Module):
def __init__(self, margin: float = 0.1) -> None:
super().__init__()
self.margin = margin

@property
def required_kwargs(self) -> list[str]:
return ["latent", "positive", "negative"]

def forward(
self,
input: Tensor,
target: Tensor,
latent: list[Tensor],
positive: list[Tensor],
negative: list[Tensor],
) -> Tensor:
num_frames = positive[0].size(1)

anchor = latent[0].unsqueeze(1).expand(-1, num_frames, -1, -1, -1)
pos_dist = torch.norm(anchor - positive[0], p=2, dim=[2, 3, 4])
neg_dist = torch.norm(anchor - negative[0], p=2, dim=[2, 3, 4])

hard_positive_dist = pos_dist.max(dim=1)[0]
hard_negative_dist = neg_dist.min(dim=1)[0]

losses = F.relu(hard_positive_dist - hard_negative_dist + self.margin)
return losses.mean()
3 changes: 3 additions & 0 deletions hrdae/models/pvr_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ def __init__(
optimizer: Optimizer,
scheduler: LRScheduler,
criterion: nn.Module,
use_triplet: bool,
) -> None:
for k, v in network_weight.items():
if not hasattr(network, k):
Expand All @@ -48,6 +49,7 @@ def __init__(
optimizer,
scheduler,
criterion,
use_triplet,
)


Expand Down Expand Up @@ -81,4 +83,5 @@ def create_pvr_model(
optimizer,
scheduler,
criterion,
opt.use_triplet,
)
24 changes: 23 additions & 1 deletion hrdae/models/vr_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from torch.optim.optimizer import Optimizer
from torch.utils.data import DataLoader

from .functions import save_model, save_reconstructed_images
from .functions import save_model, save_reconstructed_images, shuffled_indices
from .losses import LossMixer, LossOption, create_loss
from .networks import NetworkOption, create_network
from .optimizers import OptimizerOption, create_optimizer
Expand All @@ -24,6 +24,7 @@ class VRModelOption(ModelOption):
scheduler: SchedulerOption = MISSING
loss: dict[str, LossOption] = MISSING
loss_coef: dict[str, float] = MISSING
use_triplet: bool = False


class VRModel(Model):
Expand All @@ -33,11 +34,13 @@ def __init__(
optimizer: Optimizer,
scheduler: LRScheduler,
criterion: nn.Module,
use_triplet: bool,
) -> None:
self.network = network
self.optimizer = optimizer
self.scheduler = scheduler
self.criterion = criterion
self.use_triplet = use_triplet

if torch.cuda.is_available():
print("GPU is enabled")
Expand Down Expand Up @@ -77,12 +80,21 @@ def train(

self.optimizer.zero_grad()
y, latent, cycled_latent = self.network(xm, xp_0, xm_0)
# triplet
indices = shuffled_indices(len(xp))
positive = tensor(0.0)
negative = tensor(0.0)
if self.use_triplet:
_, _, positive = self.network(xm[indices], xp_0, xm_0[indices])
_, _, negative = self.network(xm, xp_0[indices], xm_0)

loss = self.criterion(
y,
xp,
latent=latent,
cycled_latent=cycled_latent,
positive=positive,
negative=negative,
)
loss.backward()
self.optimizer.step()
Expand Down Expand Up @@ -112,12 +124,21 @@ def train(
xp = data["xp"].to(self.device)
xp_0 = data["xp_0"].to(self.device)
y, latent, cycled_latent = self.network(xm, xp_0, xm_0)
# triplet
indices = shuffled_indices(len(xp))
positive = tensor(0.0)
negative = tensor(0.0)
if self.use_triplet:
_, positive, _ = self.network(xm[indices], xp_0, xm_0[indices])
_, negative, _ = self.network(xm, xp_0[indices], xm_0)

loss = self.criterion(
y,
xp,
latent=latent,
cycled_latent=cycled_latent,
positive=positive,
negative=negative,
)
total_val_loss += loss.item()

Expand Down Expand Up @@ -195,4 +216,5 @@ def create_vr_model(
optimizer,
scheduler,
criterion,
opt.use_triplet,
)

0 comments on commit 697d609

Please sign in to comment.