Skip to content

Commit

Permalink
update hrdae
Browse files Browse the repository at this point in the history
  • Loading branch information
nnaakkaaii committed Jun 27, 2024
1 parent b16d154 commit c34571c
Show file tree
Hide file tree
Showing 2 changed files with 96 additions and 88 deletions.
168 changes: 90 additions & 78 deletions hrdae/models/networks/hr_dae.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,54 +167,33 @@ def __init__(
hidden_channels: int,
latent_dim: int,
conv_params: list[dict[str, list[int]]],
aggregator: str,
debug_show_dim: bool = False,
) -> None:
super().__init__()

dec_hidden_channels = hidden_channels
if aggregator == "concatenation":
dec_hidden_channels += latent_dim
self.aggregator = create_aggregator2d(aggregator, latent_dim, latent_dim)
self.bottleneck = PixelWiseConv2d(
2 * latent_dim if aggregator == "concatenation" else latent_dim,
dec_hidden_channels,
latent_dim,
hidden_channels,
act_norm=True,
)
self.dec = HierarchicalConvDecoder2d(
dec_hidden_channels,
self.cnn = HierarchicalConvDecoder2d(
hidden_channels,
out_channels,
dec_hidden_channels,
hidden_channels,
conv_params,
debug_show_dim,
)
# motion guided connection
# (Mutual Suppression Network for Video Prediction using Disentangled Features)
self.mgc = nn.ModuleList()
for _ in conv_params:
self.mgc.append(
nn.Sequential(
create_aggregator2d(aggregator, hidden_channels, latent_dim),
ResNetBranch(
IdenticalConvBlock2d(dec_hidden_channels, dec_hidden_channels),
IdenticalConvBlock2d(
dec_hidden_channels, dec_hidden_channels, act_norm=False
),
),
nn.GroupNorm(2, dec_hidden_channels),
nn.LeakyReLU(0.2, inplace=True),
)
)
self.debug_show_dim = debug_show_dim

def forward(self, m: Tensor, c: Tensor, cs: list[Tensor]) -> Tensor:
assert len(self.mgc) == len(cs)
def forward(self, z: Tensor, cs: list[Tensor]) -> Tensor:
y = self.bottleneck(z)
x = self.cnn(y, cs)

x = self.aggregator((c, upsample_motion_tensor(m, c)))
for i, mgc in enumerate(self.mgc):
cs[i] = mgc((cs[i], upsample_motion_tensor(m, cs[i])))
if self.debug_show_dim:
print("Latent", z.size())
print("Output", y.size())
print("Input", x.size())

x = self.bottleneck(x)
return self.dec(x, cs)
return x


class HierarchicalDecoder3d(nn.Module):
Expand All @@ -224,54 +203,33 @@ def __init__(
hidden_channels: int,
latent_dim: int,
conv_params: list[dict[str, list[int]]],
aggregator: str,
debug_show_dim: bool = False,
) -> None:
super().__init__()

dec_hidden_channels = hidden_channels
if aggregator == "concatenation":
dec_hidden_channels += latent_dim
self.aggregator = create_aggregator3d(aggregator, latent_dim, latent_dim)
self.bottleneck = PixelWiseConv3d(
2 * latent_dim if aggregator == "concatenation" else latent_dim,
dec_hidden_channels,
latent_dim,
hidden_channels,
act_norm=True,
)
self.dec = HierarchicalConvDecoder3d(
dec_hidden_channels,
self.cnn = HierarchicalConvDecoder3d(
hidden_channels,
out_channels,
dec_hidden_channels,
hidden_channels,
conv_params,
debug_show_dim,
)
# motion guided connection
# (Mutual Suppression Network for Video Prediction using Disentangled Features)
self.mgc = nn.ModuleList()
for _ in conv_params:
self.mgc.append(
nn.Sequential(
create_aggregator3d(aggregator, hidden_channels, latent_dim),
ResNetBranch(
IdenticalConvBlock3d(dec_hidden_channels, dec_hidden_channels),
IdenticalConvBlock3d(
dec_hidden_channels, dec_hidden_channels, act_norm=False
),
),
nn.GroupNorm(2, dec_hidden_channels),
nn.LeakyReLU(0.2, inplace=True),
)
)
self.debug_show_dim = debug_show_dim

def forward(self, m: Tensor, c: Tensor, cs: list[Tensor]) -> Tensor:
assert len(self.mgc) == len(cs)
def forward(self, z: Tensor, cs: list[Tensor]) -> Tensor:
y = self.bottleneck(z)
x = self.cnn(y, cs)

x = self.aggregator((c, upsample_motion_tensor(m, c)))
for i, mgc in enumerate(self.mgc):
cs[i] = mgc((cs[i], upsample_motion_tensor(m, cs[i])))
if self.debug_show_dim:
print("Latent", z.size())
print("Output", y.size())
print("Input", x.size())

x = self.bottleneck(x)
return self.dec(x, cs)
return x


class HRDAE2d(nn.Module):
Expand All @@ -288,6 +246,10 @@ def __init__(
debug_show_dim: bool = False,
) -> None:
super().__init__()

dec_hidden_channels = hidden_channels
if aggregator == "concatenation":
dec_hidden_channels += latent_dim
self.content_encoder = HierarchicalEncoder2d(
in_channels,
hidden_channels,
Expand All @@ -298,12 +260,29 @@ def __init__(
self.motion_encoder = motion_encoder
self.decoder = HierarchicalDecoder2d(
out_channels,
hidden_channels,
latent_dim,
dec_hidden_channels,
2 * latent_dim if aggregator == "concatenation" else latent_dim,
conv_params[::-1],
aggregator,
debug_show_dim,
)
self.aggregator = create_aggregator2d(aggregator, latent_dim, latent_dim)
# motion guided connection
# (Mutual Suppression Network for Video Prediction using Disentangled Features)
self.mgc = nn.ModuleList()
for _ in conv_params:
self.mgc.append(
nn.Sequential(
create_aggregator2d(aggregator, hidden_channels, latent_dim),
ResNetBranch(
IdenticalConvBlock2d(dec_hidden_channels, dec_hidden_channels),
IdenticalConvBlock2d(
dec_hidden_channels, dec_hidden_channels, act_norm=False
),
),
nn.GroupNorm(2, dec_hidden_channels),
nn.LeakyReLU(0.2, inplace=True),
)
)
self.activation = create_activation(activation)

def forward(
Expand All @@ -318,7 +297,13 @@ def forward(
m = m.reshape(b * t, c_, h, w)
c_exp = c.repeat(t, 1, 1, 1)
cs_exp = [c_.repeat(t, 1, 1, 1) for c_ in cs]
y = self.decoder(m, c_exp, cs_exp[::-1])

assert len(self.mgc) == len(cs_exp)
z = self.aggregator((c_exp, upsample_motion_tensor(m, c_exp)))
for i, mgc in enumerate(self.mgc):
cs_exp[i] = mgc((cs_exp[i], upsample_motion_tensor(m, cs_exp[i])))
y = self.decoder(z, cs_exp[::-1])

_, c_, h, w = y.size()
y = y.reshape(b, t, c_, h, w)
if self.activation is not None:
Expand Down Expand Up @@ -360,6 +345,10 @@ def __init__(
debug_show_dim: bool = False,
) -> None:
super().__init__()

dec_hidden_channels = hidden_channels
if aggregator == "concatenation":
dec_hidden_channels += latent_dim
self.content_encoder = HierarchicalEncoder3d(
in_channels,
hidden_channels,
Expand All @@ -370,12 +359,29 @@ def __init__(
self.motion_encoder = motion_encoder
self.decoder = HierarchicalDecoder3d(
out_channels,
hidden_channels,
latent_dim,
dec_hidden_channels,
2 * latent_dim if aggregator == "concatenation" else latent_dim,
conv_params[::-1],
aggregator,
debug_show_dim,
)
self.aggregator = create_aggregator2d(aggregator, latent_dim, latent_dim)
# motion guided connection
# (Mutual Suppression Network for Video Prediction using Disentangled Features)
self.mgc = nn.ModuleList()
for _ in conv_params:
self.mgc.append(
nn.Sequential(
create_aggregator3d(aggregator, hidden_channels, latent_dim),
ResNetBranch(
IdenticalConvBlock3d(dec_hidden_channels, dec_hidden_channels),
IdenticalConvBlock3d(
dec_hidden_channels, dec_hidden_channels, act_norm=False
),
),
nn.GroupNorm(2, dec_hidden_channels),
nn.LeakyReLU(0.2, inplace=True),
)
)
self.activation = create_activation(activation)

def forward(
Expand All @@ -390,7 +396,13 @@ def forward(
m = m.reshape(b * t, c_, d, h, w)
c_exp = c.repeat(t, 1, 1, 1, 1)
cs_exp = [c_.repeat(t, 1, 1, 1, 1) for c_ in cs]
y = self.decoder(m, c_exp, cs_exp[::-1])

assert len(self.mgc) == len(cs_exp)
z = self.aggregator((c_exp, upsample_motion_tensor(m, c_exp)))
for i, mgc in enumerate(self.mgc):
cs_exp[i] = mgc((cs_exp[i], upsample_motion_tensor(m, cs_exp[i])))
y = self.decoder(z, cs_exp[::-1])

_, c_, d, h, w = y.size()
y = y.reshape(b, t, c_, d, h, w)
if self.activation is not None:
Expand Down
16 changes: 6 additions & 10 deletions test/models/networks/test_hr_dae.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,7 @@ def test_hierarchical_decoder2d():
hidden = 16
latent = 4

m = randn((b * n, latent, h // 4, w // 4))
c = randn((b * n, latent, h // 4, w // 4))
x = randn((b * n, latent, h // 4, w // 4))
cs = [
randn((b * n, hidden, h // 4, w // 4)),
randn((b * n, hidden, h // 2, w // 2)),
Expand All @@ -99,10 +98,9 @@ def test_hierarchical_decoder2d():
}
]
* 2,
aggregator="addition",
debug_show_dim=False,
)
x = net(m, c, cs)
x = net(x, cs)
assert x.size() == (b * n, c_, h, w)


Expand All @@ -111,11 +109,10 @@ def test_hierarchical_decoder3d():
hidden = 16
latent = 4

m = randn((b * n, latent, d, h, w))
c = randn((b * n, latent, h // 4, h // 4, w // 4))
x = randn((b * n, latent, d // 4, h // 4, w // 4))
cs = [
randn((b * n, hidden, h // 4, h // 4, w // 4)),
randn((b * n, hidden, h // 2, h // 2, w // 2)),
randn((b * n, hidden, d // 4, h // 4, w // 4)),
randn((b * n, hidden, d // 2, h // 2, w // 2)),
]
net = HierarchicalDecoder3d(
c_,
Expand All @@ -130,10 +127,9 @@ def test_hierarchical_decoder3d():
}
]
* 2,
aggregator="addition",
debug_show_dim=False,
)
x = net(m, c, cs)
x = net(x, cs)
assert x.size() == (b * n, c_, d, h, w)


Expand Down

0 comments on commit c34571c

Please sign in to comment.