Skip to content

Commit

Permalink
impl cycle dae
Browse files Browse the repository at this point in the history
  • Loading branch information
nnaakkaaii committed Jun 24, 2024
1 parent f57e244 commit 5cfc272
Show file tree
Hide file tree
Showing 10 changed files with 321 additions and 41 deletions.
12 changes: 9 additions & 3 deletions hrdae/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,16 +122,22 @@
cs.store(group="config/experiment/model", name="pvr", node=PVRModelOption)
cs.store(group="config/experiment/model", name="gan", node=GANModelOption)
cs.store(group="config/experiment/model/loss", name="mse", node=MSELossOption)
cs.store(group="config/experiment/model/loss_g", name="bce", node=BCEWithLogitsLossOption)
cs.store(group="config/experiment/model/loss_d", name="bce", node=BCEWithLogitsLossOption)
cs.store(
group="config/experiment/model/loss_g", name="bce", node=BCEWithLogitsLossOption
)
cs.store(
group="config/experiment/model/loss_d", name="bce", node=BCEWithLogitsLossOption
)
cs.store(group="config/experiment/model/loss", name="pjc2d", node=PJC2dLossOption)
cs.store(group="config/experiment/model/loss", name="pjc3d", node=PJC3dLossOption)
cs.store(group="config/experiment/model/loss", name="wmse", node=WeightedMSELossOption)
cs.store(
group="config/experiment/model/loss", name="tsim", node=TemporalSimilarityLossOption
)
cs.store(
group="config/experiment/model/loss_g", name="tsim", node=TemporalSimilarityLossOption
group="config/experiment/model/loss_g",
name="tsim",
node=TemporalSimilarityLossOption,
)
cs.store(
group="config/experiment/model/network",
Expand Down
2 changes: 1 addition & 1 deletion hrdae/conf
84 changes: 74 additions & 10 deletions hrdae/models/networks/hr_dae.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,18 @@ def create_hrdae2d(out_channels: int, opt: HRDAE2dOption) -> nn.Module:
motion_encoder = create_motion_encoder1d(
opt.latent_dim, opt.debug_show_dim, opt.motion_encoder
)
if opt.cycle:
return CycleHRDAE2d(
opt.in_channels,
out_channels,
opt.hidden_channels,
opt.latent_dim,
opt.conv_params,
motion_encoder,
opt.activation,
opt.aggregator,
opt.debug_show_dim,
)
return HRDAE2d(
opt.in_channels,
out_channels,
Expand All @@ -61,6 +73,18 @@ def create_hrdae3d(out_channels: int, opt: HRDAE3dOption) -> nn.Module:
motion_encoder = create_motion_encoder2d(
opt.latent_dim, opt.debug_show_dim, opt.motion_encoder
)
if opt.cycle:
return CycleHRDAE3d(
opt.in_channels,
out_channels,
opt.hidden_channels,
opt.latent_dim,
opt.conv_params,
motion_encoder,
opt.activation,
opt.aggregator,
opt.debug_show_dim,
)
return HRDAE3d(
opt.in_channels,
out_channels,
Expand Down Expand Up @@ -287,19 +311,39 @@ def forward(
x_1d: Tensor,
x_2d_0: Tensor,
x_1d_0: Tensor | None = None,
) -> Tensor:
) -> tuple[Tensor, list[Tensor]]:
c, cs = self.content_encoder(x_2d_0)
m = self.motion_encoder(x_1d, x_1d_0)
b, t, c_, h = m.size()
m = m.reshape(b * t, c_, h)
c = c.repeat(t, 1, 1, 1)
cs = [c_.repeat(t, 1, 1, 1) for c_ in cs]
y = self.decoder(m, c, cs[::-1])
c_exp = c.repeat(t, 1, 1, 1)
cs_exp = [c_.repeat(t, 1, 1, 1) for c_ in cs]
y = self.decoder(m, c_exp, cs_exp[::-1])
_, c_, h, w = y.size()
y = y.reshape(b, t, c_, h, w)
if self.activation is not None:
y = self.activation(y)
return y
return y, [c] + cs


class CycleHRDAE2d(HRDAE2d):
    """HRDAE2d variant that re-encodes its own reconstruction.

    After the parent produces a reconstruction, every frame of it is passed
    back through ``self.content_encoder`` and the result is compared, by
    subtraction, with the content features the parent returned.
    """

    def forward(
        self,
        x_1d: Tensor,
        x_2d_0: Tensor,
        x_1d_0: Tensor | None = None,
    ) -> tuple[Tensor, list[Tensor]]:
        """Reconstruct the sequence and compute content-feature residuals.

        Args:
            x_1d: Passed unchanged to ``HRDAE2d.forward``.
            x_2d_0: Passed unchanged to ``HRDAE2d.forward``.
            x_1d_0: Passed unchanged to ``HRDAE2d.forward``; may be ``None``.

        Returns:
            ``(y, residuals)``. ``y`` is the parent's 5-D reconstruction
            ``(b, t, c, h, w)``. ``residuals[k]`` is the k-th output of the
            content encoder applied to the reconstruction, reshaped to
            ``(b, t, ...)``, minus the parent's k-th content feature
            broadcast along dim 1.
        """
        recon, references = super().forward(x_1d, x_2d_0, x_1d_0)
        batch, steps, channels, height, width = recon.size()
        n_frames = batch * steps

        # Fold the time axis into the batch axis so the content encoder
        # processes each reconstructed frame independently.
        frames = recon.reshape(n_frames, channels, height, width)
        top, extras = self.content_encoder(frames)

        def residual(encoded: Tensor, reference: Tensor) -> Tensor:
            # Restore (batch, time) and subtract the reference feature,
            # which gains a singleton dim 1 so it broadcasts over time.
            assert encoded.size(0) == n_frames
            per_step = encoded.reshape(batch, steps, *encoded.size()[1:])
            return per_step - reference.unsqueeze(1)

        top_residual = residual(top, references[0])
        # The parent returns one leading feature followed by one entry per
        # additional encoder output, so the two lists must line up.
        assert len(references) == 1 + len(extras)
        extra_residuals = [
            residual(feature, references[1 + k])
            for k, feature in enumerate(extras)
        ]
        return recon, [top_residual] + extra_residuals


class HRDAE3d(nn.Module):
Expand Down Expand Up @@ -339,16 +383,36 @@ def forward(
x_2d: Tensor,
x_3d_0: Tensor,
x_2d_0: Tensor | None = None,
) -> Tensor:
) -> tuple[Tensor, list[Tensor]]:
c, cs = self.content_encoder(x_3d_0)
m = self.motion_encoder(x_2d, x_2d_0)
b, t, c_, h, w = m.size()
m = m.reshape(b * t, c_, h, w)
c = c.repeat(t, 1, 1, 1, 1)
cs = [c_.repeat(t, 1, 1, 1, 1) for c_ in cs]
y = self.decoder(m, c, cs[::-1])
c_exp = c.repeat(t, 1, 1, 1, 1)
cs_exp = [c_.repeat(t, 1, 1, 1, 1) for c_ in cs]
y = self.decoder(m, c_exp, cs_exp[::-1])
_, c_, d, h, w = y.size()
y = y.reshape(b, t, c_, d, h, w)
if self.activation is not None:
y = self.activation(y)
return y
return y, [c] + cs


class CycleHRDAE3d(HRDAE3d):
    """HRDAE3d variant that re-encodes its own reconstruction.

    After the parent produces a reconstruction, every frame of it is passed
    back through ``self.content_encoder`` and the result is compared, by
    subtraction, with the content features the parent returned.
    """

    def forward(
        self,
        x_2d: Tensor,
        x_3d_0: Tensor,
        x_2d_0: Tensor | None = None,
    ) -> tuple[Tensor, list[Tensor]]:
        """Reconstruct the sequence and compute content-feature residuals.

        Args:
            x_2d: Passed unchanged to ``HRDAE3d.forward``.
            x_3d_0: Passed unchanged to ``HRDAE3d.forward``.
            x_2d_0: Passed unchanged to ``HRDAE3d.forward``; may be ``None``.

        Returns:
            ``(y, residuals)``. ``y`` is the parent's 6-D reconstruction
            ``(b, t, c, d, h, w)``. ``residuals[k]`` is the k-th output of
            the content encoder applied to the reconstruction, reshaped to
            ``(b, t, ...)``, minus the parent's k-th content feature
            broadcast along dim 1.
        """
        recon, references = super().forward(x_2d, x_3d_0, x_2d_0)
        batch, steps, channels, depth, height, width = recon.size()
        n_frames = batch * steps

        # Fold the time axis into the batch axis so the content encoder
        # processes each reconstructed volume independently.
        volumes = recon.reshape(n_frames, channels, depth, height, width)
        top, extras = self.content_encoder(volumes)

        def residual(encoded: Tensor, reference: Tensor) -> Tensor:
            # Restore (batch, time) and subtract the reference feature,
            # which gains a singleton dim 1 so it broadcasts over time.
            assert encoded.size(0) == n_frames
            per_step = encoded.reshape(batch, steps, *encoded.size()[1:])
            return per_step - reference.unsqueeze(1)

        top_residual = residual(top, references[0])
        # The parent returns one leading feature followed by one entry per
        # additional encoder output, so the two lists must line up.
        assert len(references) == 1 + len(extras)
        extra_residuals = [
            residual(feature, references[1 + k])
            for k, feature in enumerate(extras)
        ]
        return recon, [top_residual] + extra_residuals
8 changes: 4 additions & 4 deletions hrdae/models/networks/r_ae.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def forward(
x_1d: Tensor,
x_2d_0: Tensor | None = None,
x_1d_0: Tensor | None = None,
) -> Tensor:
) -> tuple[Tensor, list[Tensor]]:
m = self.motion_encoder(x_1d, x_1d_0)
b, t, c_, h = m.size()
m = m.reshape(b * t, c_, h, 1)
Expand All @@ -117,7 +117,7 @@ def forward(
y = y.reshape(b, t, c_, h, w)
if self.activation is not None:
y = self.activation(y)
return y
return y, []


class RAE3d(nn.Module):
Expand Down Expand Up @@ -149,7 +149,7 @@ def forward(
x_2d: Tensor,
x_3d_0: Tensor | None = None,
x_2d_0: Tensor | None = None,
) -> Tensor:
) -> tuple[Tensor, list[Tensor]]:
m = self.motion_encoder(x_2d, x_2d_0)
b, t, c_, d, h = m.size()
m = m.reshape(b * t, c_, d, h, 1)
Expand All @@ -161,4 +161,4 @@ def forward(
y = y.reshape(b, t, c_, d, h, w)
if self.activation is not None:
y = self.activation(y)
return y
return y, []
56 changes: 46 additions & 10 deletions hrdae/models/networks/r_dae.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,14 @@
class RDAE2dOption(RAE2dOption):
    # Option fields layered on top of RAE2dOption; create_rdae2d takes an
    # instance of this class as its `opt` argument.

    # Number of input channels; the default of 1 becomes 2 when
    # content_phase = "all".
    in_channels: int = 1  # 2 if content_phase = "all"
    # Name of the aggregator to use. Presumably it selects how the motion and
    # content tensors are combined -- TODO confirm against the factory.
    aggregator: str = "attention"
    # Off by default. Presumably switches the factory to the Cycle* network
    # variant -- TODO confirm against create_rdae2d.
    cycle: bool = False


@dataclass
class RDAE3dOption(RAE3dOption):
    # Option fields layered on top of RAE3dOption. Presumably consumed by the
    # 3D RDAE factory function -- TODO confirm.

    # Number of input channels; the default of 1 becomes 2 when
    # content_phase = "all".
    in_channels: int = 1  # 2 if content_phase = "all"
    # Name of the aggregator to use. Presumably it selects how the motion and
    # content tensors are combined -- TODO confirm against the factory.
    aggregator: str = "attention"
    # Off by default. Presumably switches the factory to the Cycle* network
    # variant -- TODO confirm against the 3D factory.
    cycle: bool = False


def create_rdae2d(out_channels: int, opt: RDAE2dOption) -> nn.Module:
Expand Down Expand Up @@ -106,20 +108,37 @@ def forward(
x_1d: Tensor,
x_2d_0: Tensor,
x_1d_0: Tensor | None = None,
) -> Tensor:
) -> tuple[Tensor, list[Tensor]]:
c = self.content_encoder(x_2d_0)
m = self.motion_encoder(x_1d, x_1d_0)
b, t, c_, h_ = m.size()
m = m.reshape(b * t, c_, h_)
c = c.repeat(t, 1, 1, 1)
m = upsample_motion_tensor(m, c)
h = self.aggregator((m, c))
c_exp = c.repeat(t, 1, 1, 1)
m = upsample_motion_tensor(m, c_exp)
h = self.aggregator((m, c_exp))
y = self.decoder(h)
_, c_, h_, w = y.size()
y = y.reshape(b, t, c_, h_, w)
if self.activation is not None:
y = self.activation(y)
return y
return y, [c]


class CycleRDAE2d(RDAE2d):
    """RDAE2d variant that re-encodes its own reconstruction.

    The parent's reconstruction is passed back through
    ``self.content_encoder`` frame by frame, and the parent's content
    feature is subtracted from the result.
    """

    def forward(
        self,
        x_1d: Tensor,
        x_2d_0: Tensor,
        x_1d_0: Tensor | None = None,
    ) -> tuple[Tensor, list[Tensor]]:
        """Reconstruct the sequence and compute the content residual.

        Args:
            x_1d: Passed unchanged to ``RDAE2d.forward``.
            x_2d_0: Passed unchanged to ``RDAE2d.forward``.
            x_1d_0: Passed unchanged to ``RDAE2d.forward``; may be ``None``.

        Returns:
            ``(y, [residual])``. ``y`` is the parent's 5-D reconstruction
            ``(b, t, c, h, w)``. ``residual`` is the content encoding of the
            reconstruction, reshaped to ``(b, t, ...)``, minus the parent's
            single content feature broadcast along dim 1.
        """
        recon, references = super().forward(x_1d, x_2d_0, x_1d_0)
        batch, steps, channels, height, width = recon.size()
        n_frames = batch * steps

        # Fold the time axis into the batch axis so the content encoder
        # processes each reconstructed frame independently.
        frames = recon.reshape(n_frames, channels, height, width)
        encoded = self.content_encoder(frames)

        # The parent is expected to hand back exactly one content feature.
        assert len(references) == 1
        assert encoded.size(0) == n_frames

        per_step = encoded.reshape(batch, steps, *encoded.size()[1:])
        # unsqueeze(1) adds a singleton time axis so the reference
        # broadcasts across every step.
        reference = references[0].unsqueeze(1)
        return recon, [per_step - reference]


class RDAE3d(nn.Module):
Expand Down Expand Up @@ -159,17 +178,34 @@ def forward(
x_2d: Tensor,
x_3d_0: Tensor,
x_2d_0: Tensor | None = None,
) -> Tensor:
) -> tuple[Tensor, list[Tensor]]:
c = self.content_encoder(x_3d_0)
m = self.motion_encoder(x_2d, x_2d_0)
b, t, c_, d, h_ = m.size()
m = m.reshape(b * t, c_, d, h_)
c = c.repeat(t, 1, 1, 1, 1)
m = upsample_motion_tensor(m, c)
h = self.aggregator((m, c))
c_exp = c.repeat(t, 1, 1, 1, 1)
m = upsample_motion_tensor(m, c_exp)
h = self.aggregator((m, c_exp))
y = self.decoder(h)
_, c_, d, h_, w = y.size()
y = y.reshape(b, t, c_, d, h_, w)
if self.activation is not None:
y = self.activation(y)
return y
return y, [c]


class CycleRDAE3d(RDAE3d):
    """RDAE3d variant that re-encodes its own reconstruction.

    The parent's reconstruction is passed back through
    ``self.content_encoder`` frame by frame, and the parent's content
    feature is subtracted from the result.
    """

    def forward(
        self,
        x_2d: Tensor,
        x_3d_0: Tensor,
        x_2d_0: Tensor | None = None,
    ) -> tuple[Tensor, list[Tensor]]:
        """Reconstruct the sequence and compute the content residual.

        Args:
            x_2d: Passed unchanged to ``RDAE3d.forward``.
            x_3d_0: Passed unchanged to ``RDAE3d.forward``.
            x_2d_0: Passed unchanged to ``RDAE3d.forward``; may be ``None``.

        Returns:
            ``(y, [residual])``. ``y`` is the parent's 6-D reconstruction
            ``(b, t, c, d, h, w)``. ``residual`` is the content encoding of
            the reconstruction, reshaped to ``(b, t, ...)``, minus the
            parent's single content feature broadcast along dim 1.
        """
        recon, references = super().forward(x_2d, x_3d_0, x_2d_0)
        batch, steps, channels, depth, height, width = recon.size()
        n_frames = batch * steps

        # Fold the time axis into the batch axis so the content encoder
        # processes each reconstructed volume independently.
        volumes = recon.reshape(n_frames, channels, depth, height, width)
        encoded = self.content_encoder(volumes)

        # The parent is expected to hand back exactly one content feature.
        assert len(references) == 1
        assert encoded.size(0) == n_frames

        per_step = encoded.reshape(batch, steps, *encoded.size()[1:])
        # unsqueeze(1) adds a singleton time axis so the reference
        # broadcasts across every step.
        reference = references[0].unsqueeze(1)
        return recon, [per_step - reference]
6 changes: 3 additions & 3 deletions hrdae/models/vr_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def train(
idx_expanded = data["idx_expanded"].to(self.device)

self.optimizer.zero_grad()
y = self.network(xm, xp_0, xm_0)
y, _ = self.network(xm, xp_0, xm_0)

loss = self.criterion(y, xp, idx_expanded=idx_expanded)
loss.backward()
Expand Down Expand Up @@ -108,7 +108,7 @@ def train(
xp = data["xp"].to(self.device)
xp_0 = data["xp_0"].to(self.device)
idx_expanded = data["idx_expanded"].to(self.device)
y = self.network(xm, xp_0, xm_0)
y, _ = self.network(xm, xp_0, xm_0)

loss = self.criterion(y, xp, idx_expanded=idx_expanded)
total_val_loss += loss.item()
Expand Down Expand Up @@ -145,7 +145,7 @@ def train(
xp = data["xp"].to(self.device)
xp_0 = data["xp_0"].to(self.device)

y = self.network(xm, xp_0, xm_0)
y, _ = self.network(xm, xp_0, xm_0)

save_reconstructed_images(
xp.data.cpu().clone().detach().numpy(),
Expand Down
Loading

0 comments on commit 5cfc272

Please sign in to comment.