From 697d6093e2c40a23a0be3b0251327f6c5fc84505 Mon Sep 17 00:00:00 2001 From: nnaakkaaii Date: Sun, 14 Jul 2024 13:38:23 +0900 Subject: [PATCH] impl triplet --- hrdae/__init__.py | 2 ++ hrdae/conf | 2 +- hrdae/dataloaders/transforms/__init__.py | 7 +--- hrdae/models/basic_model.py | 4 +-- hrdae/models/functions.py | 12 +++++-- hrdae/models/losses/__init__.py | 3 ++ hrdae/models/losses/triplet.py | 46 ++++++++++++++++++++++++ hrdae/models/pvr_model.py | 3 ++ hrdae/models/vr_model.py | 24 ++++++++++++- 9 files changed, 91 insertions(+), 12 deletions(-) create mode 100644 hrdae/models/losses/triplet.py diff --git a/hrdae/__init__.py b/hrdae/__init__.py index d2c2131..13bed96 100644 --- a/hrdae/__init__.py +++ b/hrdae/__init__.py @@ -29,6 +29,7 @@ PJC2dLossOption, PJC3dLossOption, TemporalSimilarityLossOption, + TripletLossOption, WeightedMSELossOption, ) from .models.networks import ( @@ -154,6 +155,7 @@ group="config/experiment/model/loss", name="tsim", node=TemporalSimilarityLossOption ) cs.store(group="config/experiment/model/loss", name="mstd", node=MStdLossOption) +cs.store(group="config/experiment/model/loss", name="triplet", node=TripletLossOption) cs.store( group="config/experiment/model/loss_g", name="tsim", diff --git a/hrdae/conf b/hrdae/conf index 24b52f3..e55d5fa 160000 --- a/hrdae/conf +++ b/hrdae/conf @@ -1 +1 @@ -Subproject commit 24b52f3da7afca544525a9723958d4d84a5f70e6 +Subproject commit e55d5fa2311e821c5650065af142e9d45d2ecf46 diff --git a/hrdae/dataloaders/transforms/__init__.py b/hrdae/dataloaders/transforms/__init__.py index 3b2fc5c..bd238bb 100644 --- a/hrdae/dataloaders/transforms/__init__.py +++ b/hrdae/dataloaders/transforms/__init__.py @@ -5,12 +5,7 @@ from .normalization import MinMaxNormalization, MinMaxNormalizationOption from .option import TransformOption -from .pool import ( - Pool2dOption, - Pool3dOption, - create_pool2d, - create_pool3d, -) +from .pool import Pool2dOption, Pool3dOption, create_pool2d, create_pool3d from .random_shift import ( RandomShift2dOption, RandomShift3dOption, diff --git a/hrdae/models/basic_model.py b/hrdae/models/basic_model.py index 3b09270..7c40d7f 100644 --- a/hrdae/models/basic_model.py +++ b/hrdae/models/basic_model.py @@ -89,7 +89,7 @@ def train( y = y.reshape(b, n, *y.size()[1:]) z = z.reshape(b, n, *z.size()[1:]) - loss = self.criterion(t, y, latent=z) + loss = self.criterion(y, t, latent=z) loss.backward() self.optimizer.step() @@ -124,7 +124,7 @@ def train( y = y.reshape(b, n, *y.size()[1:]) z = z.reshape(b, n, *z.size()[1:]) - loss = self.criterion(t, y, latent=z) + loss = self.criterion(y, t, latent=z) total_val_loss += loss.item() avg_val_loss = total_val_loss / len(val_loader) diff --git a/hrdae/models/functions.py b/hrdae/models/functions.py index b007b81..4ae299c 100644 --- a/hrdae/models/functions.py +++ b/hrdae/models/functions.py @@ -3,7 +3,7 @@ import matplotlib.pyplot as plt import numpy as np import torch -from torch import nn +from torch import Tensor, nn def _save_images( @@ -103,5 +103,13 @@ def save_reconstructed_images( def save_model(model: nn.Module, filepath: Path): filepath.parent.mkdir(parents=True, exist_ok=True) - model_to_save = model.module if isinstance(model, torch.nn.DataParallel) else model + model_to_save = model.module if isinstance(model, nn.DataParallel) else model torch.save(model_to_save.state_dict(), filepath) + + +def shuffled_indices(length: int) -> Tensor: + indices = torch.empty(length, dtype=torch.long) + for i in range(length): + choices = torch.cat([torch.arange(0, i), torch.arange(i + 1, length)]) + indices[i] = choices[torch.randint(len(choices), (1,))] + return indices diff --git a/hrdae/models/losses/__init__.py b/hrdae/models/losses/__init__.py index 0c0607b..7bd00da 100644 --- a/hrdae/models/losses/__init__.py +++ b/hrdae/models/losses/__init__.py @@ -9,6 +9,7 @@ from .perceptual import Perceptual2dLossOption, create_perceptual2d_loss from .pjc import PJC2dLossOption, PJC3dLossOption, create_pjc2d_loss, create_pjc3d_loss from .t_sim import TemporalSimilarityLossOption, create_tsim_loss +from .triplet import TripletLossOption, create_triplet_loss from .weighted_mse import WeightedMSELossOption, create_weighted_mse_loss @@ -49,6 +50,8 @@ def create_loss(opt: LossOption) -> nn.Module: return create_mstd_loss() if isinstance(opt, Perceptual2dLossOption) and type(opt) is Perceptual2dLossOption: return create_perceptual2d_loss(opt) + if isinstance(opt, TripletLossOption) and type(opt) is TripletLossOption: + return create_triplet_loss(opt) raise NotImplementedError(f"{opt.__class__.__name__} is not implemented") diff --git a/hrdae/models/losses/triplet.py b/hrdae/models/losses/triplet.py new file mode 100644 index 0000000..beccb1e --- /dev/null +++ b/hrdae/models/losses/triplet.py @@ -0,0 +1,46 @@ +from dataclasses import dataclass + +import torch +import torch.nn.functional as F +from torch import Tensor, nn + +from .option import LossOption + + +@dataclass +class TripletLossOption(LossOption): + margin: float = 0.1 + + +def create_triplet_loss(opt: TripletLossOption) -> nn.Module: + return TripletLoss(opt.margin) + + +class TripletLoss(nn.Module): + def __init__(self, margin: float = 0.1) -> None: + super().__init__() + self.margin = margin + + @property + def required_kwargs(self) -> list[str]: + return ["latent", "positive", "negative"] + + def forward( + self, + input: Tensor, + target: Tensor, + latent: list[Tensor], + positive: list[Tensor], + negative: list[Tensor], + ) -> Tensor: + num_frames = positive[0].size(1) + + anchor = latent[0].unsqueeze(1).expand(-1, num_frames, -1, -1, -1) + pos_dist = torch.norm(anchor - positive[0], p=2, dim=[2, 3, 4]) + neg_dist = torch.norm(anchor - negative[0], p=2, dim=[2, 3, 4]) + + hard_positive_dist = pos_dist.max(dim=1)[0] + hard_negative_dist = neg_dist.min(dim=1)[0] + + losses = F.relu(hard_positive_dist - hard_negative_dist + self.margin) + return losses.mean() diff --git a/hrdae/models/pvr_model.py b/hrdae/models/pvr_model.py index a893b97..cdbe6b1 100644 --- a/hrdae/models/pvr_model.py +++ b/hrdae/models/pvr_model.py @@ -30,6 +30,7 @@ def __init__( optimizer: Optimizer, scheduler: LRScheduler, criterion: nn.Module, + use_triplet: bool, ) -> None: for k, v in network_weight.items(): if not hasattr(network, k): @@ -48,6 +49,7 @@ def __init__( optimizer, scheduler, criterion, + use_triplet, ) @@ -81,4 +83,5 @@ def create_pvr_model( optimizer, scheduler, criterion, + opt.use_triplet, ) diff --git a/hrdae/models/vr_model.py b/hrdae/models/vr_model.py index b7b3ba5..f3da442 100644 --- a/hrdae/models/vr_model.py +++ b/hrdae/models/vr_model.py @@ -8,7 +8,7 @@ from torch.optim.optimizer import Optimizer from torch.utils.data import DataLoader -from .functions import save_model, save_reconstructed_images +from .functions import save_model, save_reconstructed_images, shuffled_indices from .losses import LossMixer, LossOption, create_loss from .networks import NetworkOption, create_network from .optimizers import OptimizerOption, create_optimizer @@ -24,6 +24,7 @@ class VRModelOption(ModelOption): scheduler: SchedulerOption = MISSING loss: dict[str, LossOption] = MISSING loss_coef: dict[str, float] = MISSING + use_triplet: bool = False class VRModel(Model): @@ -33,11 +34,13 @@ def __init__( optimizer: Optimizer, scheduler: LRScheduler, criterion: nn.Module, + use_triplet: bool, ) -> None: self.network = network self.optimizer = optimizer self.scheduler = scheduler self.criterion = criterion + self.use_triplet = use_triplet if torch.cuda.is_available(): print("GPU is enabled") @@ -77,12 +80,21 @@ def train( self.optimizer.zero_grad() y, latent, cycled_latent = self.network(xm, xp_0, xm_0) + # triplet + indices = shuffled_indices(len(xp)) + positive = tensor(0.0) + negative = tensor(0.0) + if self.use_triplet: + _, _, positive = self.network(xm[indices], xp_0, xm_0[indices]) + _, _, negative = self.network(xm, xp_0[indices], xm_0) loss = self.criterion( y, xp, latent=latent, cycled_latent=cycled_latent, + positive=positive, + negative=negative, ) loss.backward() self.optimizer.step() @@ -112,12 +124,21 @@ def train( xp = data["xp"].to(self.device) xp_0 = data["xp_0"].to(self.device) y, latent, cycled_latent = self.network(xm, xp_0, xm_0) + # triplet + indices = shuffled_indices(len(xp)) + positive = tensor(0.0) + negative = tensor(0.0) + if self.use_triplet: + _, positive, _ = self.network(xm[indices], xp_0, xm_0[indices]) + _, negative, _ = self.network(xm, xp_0[indices], xm_0) loss = self.criterion( y, xp, latent=latent, cycled_latent=cycled_latent, + positive=positive, + negative=negative, ) total_val_loss += loss.item() @@ -195,4 +216,5 @@ def create_vr_model( optimizer, scheduler, criterion, + opt.use_triplet, )