Skip to content

Commit

Permalink
Fix CI error
Browse files Browse the repository at this point in the history
  • Loading branch information
nnaakkaaii committed Jun 25, 2024
1 parent 1286a5b commit 1c5af6e
Show file tree
Hide file tree
Showing 10 changed files with 202 additions and 95 deletions.
2 changes: 1 addition & 1 deletion hrdae/models/losses/mstd.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,4 +21,4 @@ def required_kwargs(self) -> list[str]:
return ["latent"]

def forward(self, input: Tensor, target: Tensor, latent: list[Tensor]) -> Tensor:
return sum([torch.sqrt(torch.mean(v**2)) for v in latent])
return sum([torch.sqrt(torch.mean(v**2)) for v in latent]) # type: ignore
2 changes: 0 additions & 2 deletions hrdae/models/networks/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,11 @@ def upsample_motion_tensor(m: Tensor, c: Tensor) -> Tensor:

def _upsample_motion_tensor2d(m: Tensor, c: Tensor) -> Tensor:
b, c_, h, w = c.size()
m = m.unsqueeze(-1)
m = interpolate(m, size=(h, w), mode="bilinear", align_corners=True)
return m


def _upsample_motion_tensor3d(m: Tensor, c: Tensor) -> Tensor:
b, c_, d, h, w = c.size()
m = m.unsqueeze(-1)
m = interpolate(m, size=(d, h, w), mode="trilinear", align_corners=True)
return m
8 changes: 4 additions & 4 deletions hrdae/models/networks/hr_dae.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,8 +314,8 @@ def forward(
) -> tuple[Tensor, list[Tensor]]:
c, cs = self.content_encoder(x_2d_0)
m = self.motion_encoder(x_1d, x_1d_0)
b, t, c_, h = m.size()
m = m.reshape(b * t, c_, h)
b, t, c_, h, w = m.size()
m = m.reshape(b * t, c_, h, w)
c_exp = c.repeat(t, 1, 1, 1)
cs_exp = [c_.repeat(t, 1, 1, 1) for c_ in cs]
y = self.decoder(m, c_exp, cs_exp[::-1])
Expand Down Expand Up @@ -386,8 +386,8 @@ def forward(
) -> tuple[Tensor, list[Tensor]]:
c, cs = self.content_encoder(x_3d_0)
m = self.motion_encoder(x_2d, x_2d_0)
b, t, c_, h, w = m.size()
m = m.reshape(b * t, c_, h, w)
b, t, c_, d, h, w = m.size()
m = m.reshape(b * t, c_, d, h, w)
c_exp = c.repeat(t, 1, 1, 1, 1)
cs_exp = [c_.repeat(t, 1, 1, 1, 1) for c_ in cs]
y = self.decoder(m, c_exp, cs_exp[::-1])
Expand Down
45 changes: 26 additions & 19 deletions hrdae/models/networks/motion_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def create_motion_encoder1d(
latent_dim,
opt.conv_params,
opt.deconv_params,
create_rnn1d(latent_dim, opt.rnn),
create_rnn1d(opt.hidden_channels, opt.rnn),
debug_show_dim,
)
if (
Expand All @@ -78,7 +78,7 @@ def create_motion_encoder1d(
return MotionConv2dEncoder1d(
opt.in_channels,
opt.hidden_channels,
latent_dim,
opt.hidden_channels,
opt.conv_params,
opt.deconv_params,
debug_show_dim,
Expand All @@ -99,7 +99,7 @@ def create_motion_encoder2d(
latent_dim,
opt.conv_params,
opt.deconv_params,
create_rnn2d(latent_dim, opt.rnn),
create_rnn2d(opt.hidden_channels, opt.rnn),
debug_show_dim,
)
if (
Expand Down Expand Up @@ -207,6 +207,7 @@ def forward(
y = y.unsqueeze(-1)
y = self.tcnn(y)
z = self.bottleneck(y)
z = z.reshape(b, t, *z.size()[1:])
if self.debug_show_dim:
print(f"{self.__class__.__name__}", z.size())
return z
Expand Down Expand Up @@ -260,6 +261,7 @@ def forward(
y = y.unsqueeze(-1)
y = self.tcnn(y)
z = self.bottleneck(y)
z = z.reshape(b, t, *z.size()[1:])
if self.debug_show_dim:
print(f"{self.__class__.__name__}", z.size())
return z
Expand Down Expand Up @@ -318,13 +320,16 @@ def forward(
x: Tensor,
x_0: Tensor | None = None,
) -> Tensor:
b, t, c, h = x.size()
x = x.reshape(b * t, c, h)
b, t = x.size()[:2]
x = x.reshape(b * t, *x.size()[2:])
y = self.cnn(x)
y = y.reshape(b, t, *y.size()[1:])
y, _ = self.rnn(y)
y = y.reshape(b * t, *y.size()[2:])
y = y.unsqueeze(-1)
y = self.tcnn(y)
z = self.bottleneck(y)
z = z.reshape(b, t, *z.size()[1:])
if self.debug_show_dim:
print(f"{self.__class__.__name__}", z.size())
return z
Expand Down Expand Up @@ -373,13 +378,17 @@ def forward(
x: Tensor,
x_0: Tensor | None = None,
) -> Tensor:
b, t, c, d, h = x.size()
x = x.reshape(b * t, c, d, h)
b, t = x.size()[:2]
x = x.reshape(b * t, *x.size()[2:])
y = self.cnn(x)
y = y.reshape(b, t, *y.size()[1:])
print(y.shape)
y, _ = self.rnn(y)
y = y.reshape(b * t, *y.size()[2:])
y = y.unsqueeze(-1)
y = self.tcnn(y)
z = self.bottleneck(y)
z = z.reshape(b, t, *z.size()[1:])
if self.debug_show_dim:
print(f"{self.__class__.__name__}", z.size())
return z
Expand Down Expand Up @@ -439,13 +448,12 @@ def forward(
) -> Tensor:
x = x.permute(0, 2, 1, 3) # (b, c, t, h)
y = self.cnn(x)
y = y.unsqueeze(-1) # (b, c, t, h, 1)
y = self.tcnn(y)
y = y.permute(0, 2, 1, 3, 4) # (b, t, c, h, w)
b, t, _, h, w = y.size()
y = y.reshape(b * t, -1, h, w)
y = y.permute(0, 2, 1, 3) # (b, t, c, h)
b, t, _, h = y.size()
y = y.reshape(b * t, -1, h, 1)
y = self.tcnn(y) # (b * t, c, h, w)
z = self.bottleneck(y)
z = z.reshape(b, t, -1, h, w)
z = z.reshape(b, t, *z.size()[1:])
if self.debug_show_dim:
print(f"{self.__class__.__name__}", z.size())
return z
Expand Down Expand Up @@ -495,13 +503,12 @@ def forward(
) -> Tensor:
x = x.permute(0, 2, 1, 3, 4) # (b, c, t, d, h)
y = self.cnn(x)
y = y.unsqueeze(-1) # (b, c, t, d, h, 1)
y = self.tcnn(y)
y = y.permute(0, 2, 1, 3, 4, 5) # (b, t, c, d, h, w)
b, t, _, d, h, w = y.size()
y = y.reshape(b * t, -1, d, h, w)
y = y.permute(0, 2, 1, 3, 4) # (b, t, c, d, h)
b, t, _, d, h = y.size()
y = y.reshape(b * t, -1, d, h, 1)
y = self.tcnn(y) # (b * t, c, d, h, w)
z = self.bottleneck(y)
z = z.reshape(b, t, -1, d, h, w)
z = z.reshape(b, t, *z.size()[1:])
if self.debug_show_dim:
print(f"{self.__class__.__name__}", z.size())
return z
8 changes: 4 additions & 4 deletions hrdae/models/networks/r_ae.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,8 +109,8 @@ def forward(
x_1d_0: Tensor | None = None,
) -> tuple[Tensor, list[Tensor]]:
m = self.motion_encoder(x_1d, x_1d_0)
b, t, c_, h = m.size()
m = m.reshape(b * t, c_, h, 1)
b, t, c_, h, w = m.size()
m = m.reshape(b * t, c_, h, w)
m = interpolate(m, size=self.upsample_size, mode="bilinear", align_corners=True)
y = self.decoder(m)
_, c_, h, w = y.size()
Expand Down Expand Up @@ -151,8 +151,8 @@ def forward(
x_2d_0: Tensor | None = None,
) -> tuple[Tensor, list[Tensor]]:
m = self.motion_encoder(x_2d, x_2d_0)
b, t, c_, d, h = m.size()
m = m.reshape(b * t, c_, d, h, 1)
b, t, c_, d, h, w = m.size()
m = m.reshape(b * t, c_, d, h, w)
m = interpolate(
m, size=self.upsample_size, mode="trilinear", align_corners=True
)
Expand Down
8 changes: 4 additions & 4 deletions hrdae/models/networks/r_dae.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,8 +111,8 @@ def forward(
) -> tuple[Tensor, list[Tensor]]:
c = self.content_encoder(x_2d_0)
m = self.motion_encoder(x_1d, x_1d_0)
b, t, c_, h_ = m.size()
m = m.reshape(b * t, c_, h_)
b, t, c_, h_, w = m.size()
m = m.reshape(b * t, c_, h_, w)
c_exp = c.repeat(t, 1, 1, 1)
m = upsample_motion_tensor(m, c_exp)
h = self.aggregator((m, c_exp))
Expand Down Expand Up @@ -181,8 +181,8 @@ def forward(
) -> tuple[Tensor, list[Tensor]]:
c = self.content_encoder(x_3d_0)
m = self.motion_encoder(x_2d, x_2d_0)
b, t, c_, d, h_ = m.size()
m = m.reshape(b * t, c_, d, h_)
b, t, c_, d, h_, w = m.size()
m = m.reshape(b * t, c_, d, h_, w)
c_exp = c.repeat(t, 1, 1, 1, 1)
m = upsample_motion_tensor(m, c_exp)
h = self.aggregator((m, c_exp))
Expand Down
58 changes: 56 additions & 2 deletions test/models/networks/test_hr_dae.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def test_hierarchical_decoder2d():
hidden = 16
latent = 4

m = randn((b * n, latent, h // 4))
m = randn((b * n, latent, h // 4, w // 4))
c = randn((b * n, latent, h // 4, w // 4))
cs = [
randn((b * n, hidden, h // 4, w // 4)),
Expand Down Expand Up @@ -111,7 +111,7 @@ def test_hierarchical_decoder3d():
hidden = 16
latent = 4

m = randn((b * n, latent, d, h))
m = randn((b * n, latent, d, h, w))
c = randn((b * n, latent, h // 4, h // 4, w // 4))
cs = [
randn((b * n, hidden, h // 4, h // 4, w // 4)),
Expand Down Expand Up @@ -165,6 +165,15 @@ def test_hrdae2d():
}
]
* 2,
deconv_params=[
{
"kernel_size": [3],
"stride": [1, 2],
"padding": [1],
"output_padding": [0, 1],
}
]
* 2,
),
aggregator="addition",
activation="sigmoid",
Expand Down Expand Up @@ -205,6 +214,15 @@ def test_hrdae2d__concatenation():
}
]
* 2,
deconv_params=[
{
"kernel_size": [3],
"stride": [1, 2],
"padding": [1],
"output_padding": [0, 1],
}
]
* 2,
),
aggregator="concatenation",
activation="sigmoid",
Expand Down Expand Up @@ -245,6 +263,15 @@ def test_hrdae3d():
}
]
* 2,
deconv_params=[
{
"kernel_size": [3],
"stride": [1, 1, 2],
"padding": [1],
"output_padding": [0, 0, 1],
}
]
* 2,
),
aggregator="addition",
activation="sigmoid",
Expand Down Expand Up @@ -285,6 +312,15 @@ def test_hrdae3d__concatenation():
}
]
* 2,
deconv_params=[
{
"kernel_size": [3],
"stride": [1, 1, 2],
"padding": [1],
"output_padding": [0, 0, 1],
}
]
* 2,
),
aggregator="concatenation",
activation="sigmoid",
Expand Down Expand Up @@ -325,6 +361,15 @@ def test_cycle_hrdae2d():
}
]
* 2,
deconv_params=[
{
"kernel_size": [3],
"stride": [1, 2],
"padding": [1],
"output_padding": [0, 1],
}
]
* 2,
),
aggregator="addition",
activation="sigmoid",
Expand Down Expand Up @@ -369,6 +414,15 @@ def test_cycle_hrdae3d():
}
]
* 2,
deconv_params=[
{
"kernel_size": [3],
"stride": [1, 1, 2],
"padding": [1],
"output_padding": [0, 0, 1],
}
]
* 2,
),
aggregator="addition",
activation="sigmoid",
Expand Down
Loading

0 comments on commit 1c5af6e

Please sign in to comment.