Skip to content

Commit

Permalink
fix view to reshape
Browse files Browse the repository at this point in the history
  • Loading branch information
nnaakkaaii committed Jun 15, 2024
1 parent 04988f7 commit 9c2ad2e
Show file tree
Hide file tree
Showing 6 changed files with 24 additions and 24 deletions.
14 changes: 7 additions & 7 deletions hrdae/models/networks/fb_autoencoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,21 +116,21 @@ def forward(
b, n, s, h = x_1d.size()
_, _, c, _, w = x_2d_0.size()

x_1d = x_1d.view(b * n, s, h)
x_2d_0 = x_2d_0.view(b * 2, c, h, w)
x_1d_0 = x_1d_0.view(b * 2, s, h)
x_1d = x_1d.reshape(b * n, s, h)
x_2d_0 = x_2d_0.reshape(b * 2, c, h, w)
x_1d_0 = x_1d_0.reshape(b * 2, s, h)

# encode
x_1d = self.encoder_1d(x_1d)
_, c_, h_ = x_1d.size()
x_1d = x_1d.view(b * n, 1 * c_, h_, 1)
x_1d = x_1d.reshape(b * n, 1 * c_, h_, 1)

x_2d_0 = self.encoder_2d(x_2d_0)
_, _, _, w_ = x_2d_0.size()
x_2d_0 = x_2d_0.view(b, 2 * c_, h_, w_)
x_2d_0 = x_2d_0.reshape(b, 2 * c_, h_, w_)

x_1d_0 = self.encoder_1d(x_1d_0)
x_1d_0 = x_1d_0.view(b, 2 * c_, h_, 1)
x_1d_0 = x_1d_0.reshape(b, 2 * c_, h_, 1)

# expand
# (b * n, 1 * c_, h_, 1) -> (b * n, 1 * c_, h_, w_)
Expand All @@ -145,7 +145,7 @@ def forward(

# decode
out = self.decoder_2d(latent)
out = out.view(b, n, c, h, w)
out = out.reshape(b, n, c, h, w)

if self.activation is not None:
out = self.activation(out)
Expand Down
8 changes: 4 additions & 4 deletions hrdae/models/networks/hr_dae.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,12 +244,12 @@ def forward(
c, cs = self.content_encoder(x_2d_0)
m = self.motion_encoder(x_1d, x_1d_0)
b, t, c_, h = m.size()
m = m.view(b * t, c_, h)
m = m.reshape(b * t, c_, h)
c = c.repeat(t, 1, 1, 1)
cs = [c_.repeat(t, 1, 1, 1) for c_ in cs]
y = self.decoder(m, c, cs[::-1])
_, c_, h, w = y.size()
y = y.view(b, t, c_, h, w)
y = y.reshape(b, t, c_, h, w)
if self.activation is not None:
y = self.activation(y)
return y
Expand Down Expand Up @@ -293,12 +293,12 @@ def forward(
c, cs = self.content_encoder(x_3d_0)
m = self.motion_encoder(x_2d, x_2d_0)
b, t, c_, h, w = m.size()
m = m.view(b * t, c_, h, w)
m = m.reshape(b * t, c_, h, w)
c = c.repeat(t, 1, 1, 1, 1)
cs = [c_.repeat(t, 1, 1, 1, 1) for c_ in cs]
y = self.decoder(m, c, cs[::-1])
_, c_, d, h, w = y.size()
y = y.view(b, t, c_, d, h, w)
y = y.reshape(b, t, c_, d, h, w)
if self.activation is not None:
y = self.activation(y)
return y
8 changes: 4 additions & 4 deletions hrdae/models/networks/modules/gru.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,9 @@ def forward(
self, x: Tensor, last_states: Tensor | None = None
) -> tuple[Tensor, Tensor]:
b, t, c, h = x.size()
x = x.view(b, t, c * h)
x = x.reshape(b, t, c * h)
y, _last_states = self.rnn(x, last_states)
y = y.view(b, t, self.c, h)
y = y.reshape(b, t, self.c, h)
return y, _last_states


Expand Down Expand Up @@ -57,7 +57,7 @@ def forward(
self, x: Tensor, last_states: Tensor | None = None
) -> tuple[Tensor, Tensor]:
b, t, c, h, w = x.size()
x = x.view(b, t, c * h * w)
x = x.reshape(b, t, c * h * w)
y, _last_states = self.rnn(x, last_states)
y = y.view(b, t, self.c, h, w)
y = y.reshape(b, t, self.c, h, w)
return y, _last_states
8 changes: 4 additions & 4 deletions hrdae/models/networks/motion_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,10 +166,10 @@ def forward(
x_0: Tensor | None = None,
) -> Tensor:
b, t, c, h = x.size()
x = x.view(b * t, c, h)
x = x.reshape(b * t, c, h)
x = self.enc(x)
_, c, h = x.size()
x = x.view(b, t, c, h)
x = x.reshape(b, t, c, h)
if self.debug_show_dim:
print(f"{self.__class__.__name__}", x.size())
return x
Expand Down Expand Up @@ -201,10 +201,10 @@ def forward(
x_0: Tensor | None = None,
) -> Tensor:
b, t, c, d, h = x.size()
x = x.view(b * t, c, d, h)
x = x.reshape(b * t, c, d, h)
x = self.enc(x)
_, c, d, h = x.size()
x = x.view(b, t, c, d, h)
x = x.reshape(b, t, c, d, h)
if self.debug_show_dim:
print(f"{self.__class__.__name__}", x.size())
return x
Expand Down
4 changes: 2 additions & 2 deletions hrdae/models/networks/r_ae.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,11 +149,11 @@ def forward(
) -> Tensor:
m = self.motion_encoder(x_1d, x_1d_0)
b, t, c_, h = m.size()
m = m.view(b * t, c_, h, 1)
m = m.reshape(b * t, c_, h, 1)
m = interpolate(m, size=self.upsample_size, mode="bilinear", align_corners=True)
y = self.decoder(m)
_, c_, h, w = y.size()
y = y.view(b, t, c_, h, w)
y = y.reshape(b, t, c_, h, w)
if self.activation is not None:
y = self.activation(y)
return y
Expand Down
6 changes: 3 additions & 3 deletions hrdae/models/networks/r_dae.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,11 +216,11 @@ def forward(
c = self.content_encoder(x_2d_0)
m = self.motion_encoder(x_1d, x_1d_0)
b, t, c_, h = m.size()
m = m.view(b * t, c_, h)
m = m.reshape(b * t, c_, h)
c = c.repeat(t, 1, 1, 1)
y = self.decoder(m, c)
_, c_, h, w = y.size()
y = y.view(b, t, c_, h, w)
y = y.reshape(b, t, c_, h, w)
if self.activation is not None:
y = self.activation(y)
return y
Expand Down Expand Up @@ -264,7 +264,7 @@ def forward(
c = self.content_encoder(x_3d_0)
m = self.motion_encoder(x_2d, x_2d_0)
b, t, c_, d, h = m.size()
m = m.view(b * t, c_, d, h)
m = m.reshape(b * t, c_, d, h)
c = c.repeat(t, 1, 1, 1, 1)
y = self.decoder(m, c)
_, c_, d, h, w = y.size()
Expand Down

0 comments on commit 9c2ad2e

Please sign in to comment.