From 34c99b853f20cbd0513f02247566bd97786e5a28 Mon Sep 17 00:00:00 2001
From: nnaakkaaii
Date: Mon, 1 Jul 2024 08:51:42 +0900
Subject: [PATCH] fix perceptual

---
 hrdae/__init__.py                    |  6 +--
 hrdae/conf                           |  2 +-
 hrdae/models/losses/__init__.py      |  6 +--
 hrdae/models/losses/perceptual.py    | 33 +++++++++++-----
 hrdae/models/networks/__init__.py    | 16 ++++++++
 hrdae/models/networks/autoencoder.py | 57 +++++++++++++++++++++++-----
 hrdae/models/networks/r_ae.py        |  6 +--
 hrdae/models/networks/r_dae.py       | 18 ++++-----
 test/models/networks/test_r_ae.py    |  8 ++--
 9 files changed, 110 insertions(+), 42 deletions(-)

diff --git a/hrdae/__init__.py b/hrdae/__init__.py
index 6614886..aafc55e 100644
--- a/hrdae/__init__.py
+++ b/hrdae/__init__.py
@@ -23,11 +23,11 @@
     ContrastiveLossOption,
     MSELossOption,
     MStdLossOption,
+    Perceptual2dLossOption,
     PJC2dLossOption,
     PJC3dLossOption,
     TemporalSimilarityLossOption,
     WeightedMSELossOption,
-    PerceptualLossOption,
 )
 from .models.networks import (
     AutoEncoder2dNetworkOption,
@@ -154,8 +154,8 @@
     )
     cs.store(
         group="config/experiment/model/loss",
-        name="perceptual",
-        node=PerceptualLossOption,
+        name="perceptual2d",
+        node=Perceptual2dLossOption,
     )
     cs.store(
         group="config/experiment/model/network",
diff --git a/hrdae/conf b/hrdae/conf
index aa2f02f..4647b2b 160000
--- a/hrdae/conf
+++ b/hrdae/conf
@@ -1 +1 @@
-Subproject commit aa2f02f7b6cb7b8664e2452fa7233cd6f4e677d2
+Subproject commit 4647b2bf1dfae807571504993d686139fde081e3
diff --git a/hrdae/models/losses/__init__.py b/hrdae/models/losses/__init__.py
index afa091d..0c0607b 100644
--- a/hrdae/models/losses/__init__.py
+++ b/hrdae/models/losses/__init__.py
@@ -6,10 +6,10 @@
 from .contrastive import ContrastiveLossOption, create_contrastive_loss
 from .mstd import MStdLossOption, create_mstd_loss
 from .option import LossOption
+from .perceptual import Perceptual2dLossOption, create_perceptual2d_loss
 from .pjc import PJC2dLossOption, PJC3dLossOption, create_pjc2d_loss, create_pjc3d_loss
 from .t_sim import TemporalSimilarityLossOption, create_tsim_loss
 from .weighted_mse import WeightedMSELossOption, create_weighted_mse_loss
-from .perceptual import PerceptualLossOption, create_perceptual_loss
 
 
 @dataclass
@@ -47,8 +47,8 @@ def create_loss(opt: LossOption) -> nn.Module:
         return create_contrastive_loss(opt)
     if isinstance(opt, MStdLossOption) and type(opt) is MStdLossOption:
         return create_mstd_loss()
-    if isinstance(opt, PerceptualLossOption) and type(opt) is PerceptualLossOption:
-        return create_perceptual_loss(opt)
+    if isinstance(opt, Perceptual2dLossOption) and type(opt) is Perceptual2dLossOption:
+        return create_perceptual2d_loss(opt)
     raise NotImplementedError(f"{opt.__class__.__name__} is not implemented")
 
 
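
With the config-store entry renamed from "perceptual" to "perceptual2d", the
loss is still built through the same create_loss dispatch. A minimal sketch of
the intended wiring; the concrete field values below are placeholders for
illustration, not taken from this repository's configs:

    from pathlib import Path

    from hrdae.models.losses import Perceptual2dLossOption, create_loss

    # Placeholder values; real runs populate these from the Hydra group
    # config/experiment/model/loss under the name "perceptual2d".
    opt = Perceptual2dLossOption(
        activation="sigmoid",
        in_channels=1,
        hidden_channels=64,
        latent_dim=4,
        conv_params=[{"stride": [2]}, {"stride": [2]}],
        weight=Path("pretrained/autoencoder2d.pth"),
    )
    loss_fn = create_loss(opt)  # dispatches to create_perceptual2d_loss
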
diff --git a/hrdae/models/losses/perceptual.py b/hrdae/models/losses/perceptual.py
index 757544d..92e3bea 100644
--- a/hrdae/models/losses/perceptual.py
+++ b/hrdae/models/losses/perceptual.py
@@ -2,24 +2,37 @@
 from pathlib import Path
 
 import torch
-from torch import Tensor, nn
 from omegaconf import MISSING
+from torch import Tensor, nn
 
+from ..networks import AEEncoder2dNetworkOption, NetworkOption, create_network
 from .option import LossOption
-from ..networks import create_network, NetworkOption
 
 
 @dataclass
-class PerceptualLossOption(LossOption):
+class Perceptual2dLossOption(LossOption):
+    activation: str = MISSING
+    in_channels: int = MISSING
+    hidden_channels: int = MISSING
+    latent_dim: int = MISSING
+    conv_params: list[dict[str, list[int]]] = MISSING
     weight: Path = MISSING
-    network: NetworkOption = MISSING
 
 
-def create_perceptual_loss(opt: PerceptualLossOption) -> nn.Module:
-    return PerceptualLoss(opt.weight, opt.network)
+def create_perceptual2d_loss(opt: Perceptual2dLossOption) -> nn.Module:
+    return Perceptual2dLoss(
+        opt.weight,
+        AEEncoder2dNetworkOption(
+            activation=opt.activation,
+            in_channels=opt.in_channels,
+            hidden_channels=opt.hidden_channels,
+            latent_dim=opt.latent_dim,
+            conv_params=opt.conv_params,
+        ),
+    )
 
 
-class PerceptualLoss(nn.Module):
+class Perceptual2dLoss(nn.Module):
     def __init__(
         self,
         weight: Path,
@@ -33,8 +46,10 @@ def __init__(
         self.network = nn.DataParallel(self.network)
 
     def forward(self, input: Tensor, target: Tensor) -> Tensor:
-        _, z_input = self.network(input)
-        _, z_target = self.network(target)
+        b, t = input.size()[:2]
+        size = input.size()[2:]
+        _, z_input = self.network(input.reshape(b * t, *size))
+        _, z_target = self.network(target.reshape(b * t, *size))
         return sum(
             [  # type: ignore
                 torch.sqrt(torch.mean((z_input[i] - z_target[i]) ** 2))
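
The substantive fix in Perceptual2dLoss.forward is the reshape: video batches
of shape (b, t, c, h, w) are flattened to (b * t, c, h, w) so the pretrained
2d encoder can score every frame, and the loss sums per-stage RMSEs. The
reshape and the RMSE term can be checked in isolation with plain PyTorch:

    import torch

    b, t, c, h, w = 2, 10, 1, 64, 64
    video = torch.randn(b, t, c, h, w)

    # Same flattening as Perceptual2dLoss.forward: fold time into the batch
    # axis so a 2d network sees b * t independent frames.
    frames = video.reshape(b * t, *video.size()[2:])
    assert frames.size() == (b * t, c, h, w)

    # RMSE between corresponding feature maps, as summed over encoder stages.
    z_input = torch.randn(b * t, 16, 8, 8)
    z_target = torch.randn(b * t, 16, 8, 8)
    rmse = torch.sqrt(torch.mean((z_input - z_target) ** 2))
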
diff --git a/hrdae/models/networks/__init__.py b/hrdae/models/networks/__init__.py
index 03465c1..21c2c49 100644
--- a/hrdae/models/networks/__init__.py
+++ b/hrdae/models/networks/__init__.py
@@ -1,8 +1,12 @@
 from torch import nn
 
 from .autoencoder import (
+    AEEncoder2dNetworkOption,
+    AEEncoder3dNetworkOption,
     AutoEncoder2dNetworkOption,
     AutoEncoder3dNetworkOption,
+    create_ae_encoder2d,
+    create_ae_encoder3d,
     create_autoencoder2d,
     create_autoencoder3d,
 )
@@ -23,6 +27,16 @@ def create_network(out_channels: int, opt: NetworkOption) -> nn.Module:
         return create_discriminator2d(out_channels, opt)
     if isinstance(opt, Discriminator3dOption) and type(opt) is Discriminator3dOption:
         return create_discriminator3d(out_channels, opt)
+    if (
+        isinstance(opt, AEEncoder2dNetworkOption)
+        and type(opt) is AEEncoder2dNetworkOption
+    ):
+        return create_ae_encoder2d(opt)
+    if (
+        isinstance(opt, AEEncoder3dNetworkOption)
+        and type(opt) is AEEncoder3dNetworkOption
+    ):
+        return create_ae_encoder3d(opt)
     if (
         isinstance(opt, AutoEncoder2dNetworkOption)
         and type(opt) is AutoEncoder2dNetworkOption
@@ -51,6 +65,8 @@ def create_network(out_channels: int, opt: NetworkOption) -> nn.Module:
 __all__ = [
     "Discriminator2dOption",
     "Discriminator3dOption",
+    "AEEncoder2dNetworkOption",
+    "AEEncoder3dNetworkOption",
     "AutoEncoder2dNetworkOption",
     "AutoEncoder3dNetworkOption",
     "HRDAE2dOption",
diff --git a/hrdae/models/networks/autoencoder.py b/hrdae/models/networks/autoencoder.py
index 6b6cc39..b73b680 100644
--- a/hrdae/models/networks/autoencoder.py
+++ b/hrdae/models/networks/autoencoder.py
@@ -1,6 +1,7 @@
 import sys
 from dataclasses import dataclass, field
 
+from omegaconf import MISSING
 from torch import Tensor, nn
 
 from .modules import (
@@ -41,7 +42,25 @@ def create_autoencoder2d(
     )
 
 
-class Encoder2d(nn.Module):
+@dataclass
+class AEEncoder2dNetworkOption(NetworkOption):
+    in_channels: int = MISSING
+    hidden_channels: int = MISSING
+    latent_dim: int = MISSING
+    conv_params: list[dict[str, list[int]]] = MISSING
+
+
+def create_ae_encoder2d(opt: AEEncoder2dNetworkOption) -> nn.Module:
+    return AEEncoder2d(
+        opt.in_channels,
+        opt.hidden_channels,
+        opt.latent_dim,
+        opt.conv_params,
+        debug_show_dim=False,
+    )
+
+
+class AEEncoder2d(nn.Module):
     def __init__(
         self,
         in_channels: int,
@@ -79,7 +98,7 @@ def forward(self, x: Tensor) -> tuple[Tensor, list[Tensor]]:
         return z, latent
 
 
-class Decoder2d(nn.Module):
+class AEDecoder2d(nn.Module):
     def __init__(
         self,
         out_channels: int,
@@ -130,14 +149,14 @@ def __init__(
     ) -> None:
         super().__init__()
 
-        self.encoder = Encoder2d(
+        self.encoder = AEEncoder2d(
             in_channels,
             hidden_channels,
             latent_dim,
             conv_params + [IdenticalConvBlockConvParams],
             debug_show_dim=debug_show_dim,
         )
-        self.decoder = Decoder2d(
+        self.decoder = AEDecoder2d(
             in_channels,
             hidden_channels,
             latent_dim,
@@ -149,7 +168,7 @@ def __init__(
         self.debug_show_dim = debug_show_dim
 
     def forward(self, x: Tensor) -> tuple[Tensor, Tensor]:
-        z = self.encoder(x)
+        z, _ = self.encoder(x)
         y = self.decoder(z)
         if self.activation is not None:
             y = self.activation(y)
@@ -188,7 +207,25 @@ def create_autoencoder3d(
     )
 
 
-class Encoder3d(nn.Module):
+@dataclass
+class AEEncoder3dNetworkOption(NetworkOption):
+    in_channels: int = MISSING
+    hidden_channels: int = MISSING
+    latent_dim: int = MISSING
+    conv_params: list[dict[str, list[int]]] = MISSING
+
+
+def create_ae_encoder3d(opt: AEEncoder3dNetworkOption) -> nn.Module:
+    return AEEncoder3d(
+        opt.in_channels,
+        opt.hidden_channels,
+        opt.latent_dim,
+        opt.conv_params,
+        debug_show_dim=False,
+    )
+
+
+class AEEncoder3d(nn.Module):
     def __init__(
         self,
         in_channels: int,
@@ -226,7 +263,7 @@ def forward(self, x: Tensor) -> tuple[Tensor, list[Tensor]]:
         return z, latent
 
 
-class Decoder3d(nn.Module):
+class AEDecoder3d(nn.Module):
     def __init__(
         self,
         out_channels: int,
@@ -277,14 +314,14 @@ def __init__(
     ) -> None:
         super().__init__()
 
-        self.encoder = Encoder3d(
+        self.encoder = AEEncoder3d(
             in_channels,
             hidden_channels,
             latent_dim,
             conv_params + [IdenticalConvBlockConvParams],
             debug_show_dim=debug_show_dim,
         )
-        self.decoder = Decoder3d(
+        self.decoder = AEDecoder3d(
             in_channels,
             hidden_channels,
             latent_dim,
@@ -296,7 +333,7 @@ def __init__(
         self.debug_show_dim = debug_show_dim
 
     def forward(self, x: Tensor) -> tuple[Tensor, Tensor]:
-        z = self.encoder(x)
+        z, _ = self.encoder(x)
         y = self.decoder(z)
         if self.activation is not None:
             y = self.activation(y)
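
AEEncoder2d and AEEncoder3d return a (z, latent) tuple: the bottleneck code
plus the intermediate feature maps consumed by the perceptual loss, which is
why plain call sites now unpack and discard the second element. A toy module
with the same calling convention (illustrative only, not the hrdae
implementation):

    import torch
    from torch import Tensor, nn

    class ToyEncoder2d(nn.Module):
        """forward returns (z, [intermediate feature maps])."""

        def __init__(self) -> None:
            super().__init__()
            self.conv1 = nn.Conv2d(1, 8, 3, stride=2, padding=1)
            self.conv2 = nn.Conv2d(8, 4, 3, stride=2, padding=1)

        def forward(self, x: Tensor) -> tuple[Tensor, list[Tensor]]:
            h = self.conv1(x)
            z = self.conv2(h)
            return z, [h, z]

    z, _ = ToyEncoder2d()(torch.randn(2, 1, 64, 64))  # drop the latents
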
diff --git a/hrdae/models/networks/r_ae.py b/hrdae/models/networks/r_ae.py
index 9b16a09..189621c 100644
--- a/hrdae/models/networks/r_ae.py
+++ b/hrdae/models/networks/r_ae.py
@@ -7,7 +7,7 @@
 from torch import Tensor, nn
 from torch.nn.functional import interpolate
 
-from .autoencoder import Decoder2d, Decoder3d
+from .autoencoder import AEDecoder2d, AEDecoder3d
 from .modules import create_activation
 from .motion_encoder import (
     MotionEncoder1d,
@@ -92,7 +92,7 @@ def __init__(
     ) -> None:
         super().__init__()
         self.motion_encoder = motion_encoder
-        self.decoder = Decoder2d(
+        self.decoder = AEDecoder2d(
             out_channels,
             hidden_channels,
             latent_dim,
@@ -134,7 +134,7 @@ def __init__(
     ) -> None:
         super().__init__()
         self.motion_encoder = motion_encoder
-        self.decoder = Decoder3d(
+        self.decoder = AEDecoder3d(
             out_channels,
             hidden_channels,
             latent_dim,
diff --git a/hrdae/models/networks/r_dae.py b/hrdae/models/networks/r_dae.py
index 7bfae65..7890f69 100644
--- a/hrdae/models/networks/r_dae.py
+++ b/hrdae/models/networks/r_dae.py
@@ -5,7 +5,7 @@
 
 from torch import Tensor, nn
 
-from .autoencoder import Decoder2d, Decoder3d, Encoder2d, Encoder3d
+from .autoencoder import AEDecoder2d, AEDecoder3d, AEEncoder2d, AEEncoder3d
 from .functions import upsample_motion_tensor
 from .modules import (
     IdenticalConvBlockConvParams,
@@ -109,7 +109,7 @@ def __init__(
     ) -> None:
         super().__init__()
 
-        self.content_encoder = Encoder2d(
+        self.content_encoder = AEEncoder2d(
             in_channels,
             hidden_channels,
             latent_dim,
@@ -117,7 +117,7 @@ def __init__(
             debug_show_dim,
         )
         self.motion_encoder = motion_encoder
-        self.decoder = Decoder2d(
+        self.decoder = AEDecoder2d(
             out_channels,
             hidden_channels,
             2 * latent_dim if aggregator == "concatenation" else latent_dim,
@@ -133,7 +133,7 @@ def forward(
         x_2d_0: Tensor,
         x_1d_0: Tensor | None = None,
     ) -> tuple[Tensor, list[Tensor]]:
-        c = self.content_encoder(x_2d_0)
+        c, _ = self.content_encoder(x_2d_0)
         m = self.motion_encoder(x_1d, x_1d_0)
         b, t, c_, h_, w = m.size()
         m = m.reshape(b * t, c_, h_, w)
@@ -158,7 +158,7 @@ def forward(
         y, cs = super().forward(x_1d, x_2d_0, x_1d_0)
         b, t, c, h, w = y.size()
         y_seq = y.reshape(b * t, c, h, w)
-        d = self.content_encoder(y_seq)
+        d, _ = self.content_encoder(y_seq)
         assert len(cs) == 1
         assert d.size(0) == b * t
         d = d.reshape(b, t, *d.size()[1:]) - cs[0].unsqueeze(1)
@@ -179,7 +179,7 @@ def __init__(
         debug_show_dim: bool = False,
     ) -> None:
         super().__init__()
-        self.content_encoder = Encoder3d(
+        self.content_encoder = AEEncoder3d(
             in_channels,
             hidden_channels,
             latent_dim,
@@ -187,7 +187,7 @@ def __init__(
             debug_show_dim,
         )
         self.motion_encoder = motion_encoder
-        self.decoder = Decoder3d(
+        self.decoder = AEDecoder3d(
             out_channels,
             hidden_channels,
             2 * latent_dim if aggregator == "concatenation" else latent_dim,
@@ -203,7 +203,7 @@ def forward(
         x_3d_0: Tensor,
         x_2d_0: Tensor | None = None,
     ) -> tuple[Tensor, list[Tensor]]:
-        c = self.content_encoder(x_3d_0)
+        c, _ = self.content_encoder(x_3d_0)
         m = self.motion_encoder(x_2d, x_2d_0)
         b, t, c_, d, h_, w = m.size()
         m = m.reshape(b * t, c_, d, h_, w)
@@ -228,7 +228,7 @@ def forward(
         y, cs = super().forward(x_2d, x_3d_0, x_2d_0)
         b, t, c, d_, h, w = y.size()
         y_seq = y.reshape(b * t, c, d_, h, w)
-        d = self.content_encoder(y_seq)
+        d, _ = self.content_encoder(y_seq)
         assert len(cs) == 1
         assert d.size(0) == b * t
         d = d.reshape(b, t, *d.size()[1:]) - cs[0].unsqueeze(1)
diff --git a/test/models/networks/test_r_ae.py b/test/models/networks/test_r_ae.py
index f620f25..5829f07 100644
--- a/test/models/networks/test_r_ae.py
+++ b/test/models/networks/test_r_ae.py
@@ -1,8 +1,8 @@
 from torch import randn
 
 from hrdae.models.networks.r_ae import (
-    Decoder2d,
-    Decoder3d,
+    AEDecoder2d,
+    AEDecoder3d,
     RAE2d,
     RAE3d,
 )
@@ -18,7 +18,7 @@ def test_decoder2d():
     latent = 4
 
     m = randn((b * n, latent, h // 4, w // 4))
-    net = Decoder2d(
+    net = AEDecoder2d(
         c_,
         hidden,
         latent,
@@ -43,7 +43,7 @@ def test_decoder3d():
     latent = 4
 
     m = randn((b * n, latent, d // 4, h // 4, w // 4))
-    net = Decoder3d(
+    net = AEDecoder3d(
         c_,
         hidden,
         latent,
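
Note on the r_dae call sites: the decoder width still depends on the
aggregator, since concatenating the content and motion codes on the channel
axis doubles the latent dimension. A pure-PyTorch sketch of that sizing rule
(the addition branch is illustrative; only the "concatenation" case is taken
verbatim from the diff):

    import torch

    latent_dim, aggregator = 4, "concatenation"

    c = torch.randn(8, latent_dim, 16, 16)  # content code
    m = torch.randn(8, latent_dim, 16, 16)  # motion code, same spatial size

    z = torch.cat([c, m], dim=1) if aggregator == "concatenation" else c + m
    decoder_latent = 2 * latent_dim if aggregator == "concatenation" else latent_dim
    assert z.size(1) == decoder_latent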