Skip to content

Commit

Permalink
fix perceptual
Browse files Browse the repository at this point in the history
  • Loading branch information
nnaakkaaii committed Jun 30, 2024
1 parent f9274aa commit 34c99b8
Show file tree
Hide file tree
Showing 9 changed files with 110 additions and 42 deletions.
6 changes: 3 additions & 3 deletions hrdae/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,11 @@
ContrastiveLossOption,
MSELossOption,
MStdLossOption,
Perceptual2dLossOption,
PJC2dLossOption,
PJC3dLossOption,
TemporalSimilarityLossOption,
WeightedMSELossOption,
PerceptualLossOption,
)
from .models.networks import (
AutoEncoder2dNetworkOption,
Expand Down Expand Up @@ -154,8 +154,8 @@
)
cs.store(
group="config/experiment/model/loss",
name="perceptual",
node=PerceptualLossOption,
name="perceptual2d",
node=Perceptual2dLossOption,
)
cs.store(
group="config/experiment/model/network",
Expand Down
6 changes: 3 additions & 3 deletions hrdae/models/losses/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@
from .contrastive import ContrastiveLossOption, create_contrastive_loss
from .mstd import MStdLossOption, create_mstd_loss
from .option import LossOption
from .perceptual import Perceptual2dLossOption, create_perceptual2d_loss
from .pjc import PJC2dLossOption, PJC3dLossOption, create_pjc2d_loss, create_pjc3d_loss
from .t_sim import TemporalSimilarityLossOption, create_tsim_loss
from .weighted_mse import WeightedMSELossOption, create_weighted_mse_loss
from .perceptual import PerceptualLossOption, create_perceptual_loss


@dataclass
Expand Down Expand Up @@ -47,8 +47,8 @@ def create_loss(opt: LossOption) -> nn.Module:
return create_contrastive_loss(opt)
if isinstance(opt, MStdLossOption) and type(opt) is MStdLossOption:
return create_mstd_loss()
if isinstance(opt, PerceptualLossOption) and type(opt) is PerceptualLossOption:
return create_perceptual_loss(opt)
if isinstance(opt, Perceptual2dLossOption) and type(opt) is Perceptual2dLossOption:
return create_perceptual2d_loss(opt)
raise NotImplementedError(f"{opt.__class__.__name__} is not implemented")


Expand Down
33 changes: 24 additions & 9 deletions hrdae/models/losses/perceptual.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,24 +2,37 @@
from pathlib import Path

import torch
from torch import Tensor, nn
from omegaconf import MISSING
from torch import Tensor, nn

from ..networks import AEEncoder2dNetworkOption, NetworkOption, create_network
from .option import LossOption
from ..networks import create_network, NetworkOption


@dataclass
class PerceptualLossOption(LossOption):
class Perceptual2dLossOption(LossOption):
activation: str = MISSING
in_channels: int = MISSING
hidden_channels: int = MISSING
latent_dim: int = MISSING
conv_params: list[dict[str, list[int]]] = MISSING
weight: Path = MISSING
network: NetworkOption = MISSING


def create_perceptual_loss(opt: PerceptualLossOption) -> nn.Module:
return PerceptualLoss(opt.weight, opt.network)
def create_perceptual2d_loss(opt: Perceptual2dLossOption) -> nn.Module:
return Perceptual2dLoss(
opt.weight,
AEEncoder2dNetworkOption(
activation=opt.activation,
in_channels=opt.in_channels,
hidden_channels=opt.hidden_channels,
latent_dim=opt.latent_dim,
conv_params=opt.conv_params,
),
)


class PerceptualLoss(nn.Module):
class Perceptual2dLoss(nn.Module):
def __init__(
self,
weight: Path,
Expand All @@ -33,8 +46,10 @@ def __init__(
self.network = nn.DataParallel(self.network)

def forward(self, input: Tensor, target: Tensor) -> Tensor:
_, z_input = self.network(input)
_, z_target = self.network(target)
b, t = input.size()[:2]
size = input.size()[2:]
_, z_input = self.network(input.reshape(b * t, *size))
_, z_target = self.network(target.reshape(b * t, *size))
return sum(
[ # type: ignore
torch.sqrt(torch.mean((z_input[i] - z_target[i]) ** 2))
Expand Down
16 changes: 16 additions & 0 deletions hrdae/models/networks/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
from torch import nn

from .autoencoder import (
AEEncoder2dNetworkOption,
AEEncoder3dNetworkOption,
AutoEncoder2dNetworkOption,
AutoEncoder3dNetworkOption,
create_ae_encoder2d,
create_ae_encoder3d,
create_autoencoder2d,
create_autoencoder3d,
)
Expand All @@ -23,6 +27,16 @@ def create_network(out_channels: int, opt: NetworkOption) -> nn.Module:
return create_discriminator2d(out_channels, opt)
if isinstance(opt, Discriminator3dOption) and type(opt) is Discriminator3dOption:
return create_discriminator3d(out_channels, opt)
if (
isinstance(opt, AEEncoder2dNetworkOption)
and type(opt) is AEEncoder2dNetworkOption
):
return create_ae_encoder2d(opt)
if (
isinstance(opt, AEEncoder3dNetworkOption)
and type(opt) is AEEncoder3dNetworkOption
):
return create_ae_encoder3d(opt)
if (
isinstance(opt, AutoEncoder2dNetworkOption)
and type(opt) is AutoEncoder2dNetworkOption
Expand Down Expand Up @@ -51,6 +65,8 @@ def create_network(out_channels: int, opt: NetworkOption) -> nn.Module:
__all__ = [
"Discriminator2dOption",
"Discriminator3dOption",
"AEEncoder2dNetworkOption",
"AEEncoder3dNetworkOption",
"AutoEncoder2dNetworkOption",
"AutoEncoder3dNetworkOption",
"HRDAE2dOption",
Expand Down
57 changes: 47 additions & 10 deletions hrdae/models/networks/autoencoder.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import sys
from dataclasses import dataclass, field

from omegaconf import MISSING
from torch import Tensor, nn

from .modules import (
Expand Down Expand Up @@ -41,7 +42,25 @@ def create_autoencoder2d(
)


class Encoder2d(nn.Module):
@dataclass
class AEEncoder2dNetworkOption(NetworkOption):
in_channels: int = MISSING
hidden_channels: int = MISSING
latent_dim: int = MISSING
conv_params: list[dict[str, list[int]]] = MISSING


def create_ae_encoder2d(opt: AEEncoder2dNetworkOption) -> nn.Module:
return AEEncoder2d(
opt.in_channels,
opt.hidden_channels,
opt.latent_dim,
opt.conv_params,
debug_show_dim=False,
)


class AEEncoder2d(nn.Module):
def __init__(
self,
in_channels: int,
Expand Down Expand Up @@ -79,7 +98,7 @@ def forward(self, x: Tensor) -> tuple[Tensor, list[Tensor]]:
return z, latent


class Decoder2d(nn.Module):
class AEDecoder2d(nn.Module):
def __init__(
self,
out_channels: int,
Expand Down Expand Up @@ -130,14 +149,14 @@ def __init__(
) -> None:
super().__init__()

self.encoder = Encoder2d(
self.encoder = AEEncoder2d(
in_channels,
hidden_channels,
latent_dim,
conv_params + [IdenticalConvBlockConvParams],
debug_show_dim=debug_show_dim,
)
self.decoder = Decoder2d(
self.decoder = AEDecoder2d(
in_channels,
hidden_channels,
latent_dim,
Expand All @@ -149,7 +168,7 @@ def __init__(
self.debug_show_dim = debug_show_dim

def forward(self, x: Tensor) -> tuple[Tensor, Tensor]:
z = self.encoder(x)
z, _ = self.encoder(x)
y = self.decoder(z)
if self.activation is not None:
y = self.activation(y)
Expand Down Expand Up @@ -188,7 +207,25 @@ def create_autoencoder3d(
)


class Encoder3d(nn.Module):
@dataclass
class AEEncoder3dNetworkOption(NetworkOption):
in_channels: int = MISSING
hidden_channels: int = MISSING
latent_dim: int = MISSING
conv_params: list[dict[str, list[int]]] = MISSING


def create_ae_encoder3d(opt: AEEncoder3dNetworkOption) -> nn.Module:
return AEEncoder3d(
opt.in_channels,
opt.hidden_channels,
opt.latent_dim,
opt.conv_params,
debug_show_dim=False,
)


class AEEncoder3d(nn.Module):
def __init__(
self,
in_channels: int,
Expand Down Expand Up @@ -226,7 +263,7 @@ def forward(self, x: Tensor) -> tuple[Tensor, list[Tensor]]:
return z, latent


class Decoder3d(nn.Module):
class AEDecoder3d(nn.Module):
def __init__(
self,
out_channels: int,
Expand Down Expand Up @@ -277,14 +314,14 @@ def __init__(
) -> None:
super().__init__()

self.encoder = Encoder3d(
self.encoder = AEEncoder3d(
in_channels,
hidden_channels,
latent_dim,
conv_params + [IdenticalConvBlockConvParams],
debug_show_dim=debug_show_dim,
)
self.decoder = Decoder3d(
self.decoder = AEDecoder3d(
in_channels,
hidden_channels,
latent_dim,
Expand All @@ -296,7 +333,7 @@ def __init__(
self.debug_show_dim = debug_show_dim

def forward(self, x: Tensor) -> tuple[Tensor, Tensor]:
z = self.encoder(x)
z, _ = self.encoder(x)
y = self.decoder(z)
if self.activation is not None:
y = self.activation(y)
Expand Down
6 changes: 3 additions & 3 deletions hrdae/models/networks/r_ae.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from torch import Tensor, nn
from torch.nn.functional import interpolate

from .autoencoder import Decoder2d, Decoder3d
from .autoencoder import AEDecoder2d, AEDecoder3d
from .modules import create_activation
from .motion_encoder import (
MotionEncoder1d,
Expand Down Expand Up @@ -92,7 +92,7 @@ def __init__(
) -> None:
super().__init__()
self.motion_encoder = motion_encoder
self.decoder = Decoder2d(
self.decoder = AEDecoder2d(
out_channels,
hidden_channels,
latent_dim,
Expand Down Expand Up @@ -134,7 +134,7 @@ def __init__(
) -> None:
super().__init__()
self.motion_encoder = motion_encoder
self.decoder = Decoder3d(
self.decoder = AEDecoder3d(
out_channels,
hidden_channels,
latent_dim,
Expand Down
18 changes: 9 additions & 9 deletions hrdae/models/networks/r_dae.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from torch import Tensor, nn

from .autoencoder import Decoder2d, Decoder3d, Encoder2d, Encoder3d
from .autoencoder import AEDecoder2d, AEDecoder3d, AEEncoder2d, AEEncoder3d
from .functions import upsample_motion_tensor
from .modules import (
IdenticalConvBlockConvParams,
Expand Down Expand Up @@ -109,15 +109,15 @@ def __init__(
) -> None:
super().__init__()

self.content_encoder = Encoder2d(
self.content_encoder = AEEncoder2d(
in_channels,
hidden_channels,
latent_dim,
conv_params + [IdenticalConvBlockConvParams],
debug_show_dim,
)
self.motion_encoder = motion_encoder
self.decoder = Decoder2d(
self.decoder = AEDecoder2d(
out_channels,
hidden_channels,
2 * latent_dim if aggregator == "concatenation" else latent_dim,
Expand All @@ -133,7 +133,7 @@ def forward(
x_2d_0: Tensor,
x_1d_0: Tensor | None = None,
) -> tuple[Tensor, list[Tensor]]:
c = self.content_encoder(x_2d_0)
c, _ = self.content_encoder(x_2d_0)
m = self.motion_encoder(x_1d, x_1d_0)
b, t, c_, h_, w = m.size()
m = m.reshape(b * t, c_, h_, w)
Expand All @@ -158,7 +158,7 @@ def forward(
y, cs = super().forward(x_1d, x_2d_0, x_1d_0)
b, t, c, h, w = y.size()
y_seq = y.reshape(b * t, c, h, w)
d = self.content_encoder(y_seq)
d, _ = self.content_encoder(y_seq)
assert len(cs) == 1
assert d.size(0) == b * t
d = d.reshape(b, t, *d.size()[1:]) - cs[0].unsqueeze(1)
Expand All @@ -179,15 +179,15 @@ def __init__(
debug_show_dim: bool = False,
) -> None:
super().__init__()
self.content_encoder = Encoder3d(
self.content_encoder = AEEncoder3d(
in_channels,
hidden_channels,
latent_dim,
conv_params + [IdenticalConvBlockConvParams],
debug_show_dim,
)
self.motion_encoder = motion_encoder
self.decoder = Decoder3d(
self.decoder = AEDecoder3d(
out_channels,
hidden_channels,
2 * latent_dim if aggregator == "concatenation" else latent_dim,
Expand All @@ -203,7 +203,7 @@ def forward(
x_3d_0: Tensor,
x_2d_0: Tensor | None = None,
) -> tuple[Tensor, list[Tensor]]:
c = self.content_encoder(x_3d_0)
c, _ = self.content_encoder(x_3d_0)
m = self.motion_encoder(x_2d, x_2d_0)
b, t, c_, d, h_, w = m.size()
m = m.reshape(b * t, c_, d, h_, w)
Expand All @@ -228,7 +228,7 @@ def forward(
y, cs = super().forward(x_2d, x_3d_0, x_2d_0)
b, t, c, d_, h, w = y.size()
y_seq = y.reshape(b * t, c, d_, h, w)
d = self.content_encoder(y_seq)
d, _ = self.content_encoder(y_seq)
assert len(cs) == 1
assert d.size(0) == b * t
d = d.reshape(b, t, *d.size()[1:]) - cs[0].unsqueeze(1)
Expand Down
Loading

0 comments on commit 34c99b8

Please sign in to comment.