Commit
update notebooks
nnaakkaaii committed Jul 17, 2024
1 parent 8a5df30 commit f46f219
Showing 12 changed files with 88,392 additions and 300 deletions.
48 changes: 28 additions & 20 deletions hrdae/models/gan_model.py
@@ -80,6 +80,7 @@ def train(
if debug:
max_iter = 5
adv_ratio = 0.1
train_discriminator = 5

least_val_loss_g = float("inf")
training_history: dict[str, list[dict[str, int | float]]] = {"history": []}
@@ -110,19 +111,16 @@
indices = torch.randint(0, num_frames, (batch_size, 2))
state1 = latent_m[torch.arange(batch_size), indices[:, 0]]
state2 = latent_m[torch.arange(batch_size), indices[:, 1]]
mixed_state1 = state1[shuffled_indices(batch_size)]

# same video, different frame
same = self.discriminator(torch.cat([state1, state2], dim=1))
# diff = self.discriminator(torch.cat([state1, mixed_state1], dim=1))

loss_g_basic = self.criterion(
y,
xp,
latent=latent_c,
cycled_latent=cycled_latent,
)
# if same == ones, the discriminator identified the pair as the same video, so the state encoder's loss is maximal
# if diff == zeros, the discriminator identified the pair as different videos, so the state encoder's loss is maximal
loss_g_adv = self.criterion_g(
same, torch.zeros_like(same)
) # + self.criterion_g(diff, torch.ones_like(diff))
@@ -135,22 +133,32 @@ def train(
running_loss_g_adv += loss_g_adv.item()
running_loss_g += loss_g.item()

self.optimizer_d.zero_grad()
# if same == ones, the discriminator correctly identified the same video, so the discriminator's loss is minimal
# if diff == zeros, the discriminator correctly identified different videos, so the discriminator's loss is minimal
same = self.discriminator(
torch.cat([state1.detach(), state2.detach()], dim=1)
)
diff = self.discriminator(
torch.cat([state1.detach(), mixed_state1.detach()], dim=1)
)
loss_d_adv_same = self.criterion_d(same, torch.ones_like(same))
loss_d_adv_diff = self.criterion_d(diff, torch.zeros_like(diff))
loss_d_adv = (loss_d_adv_same + loss_d_adv_diff) / 2
loss_d_adv.backward()
self.optimizer_d.step()
for _ in range(train_discriminator):
self.optimizer_d.zero_grad()
with torch.no_grad():
y, latent_c, latent_m, cycled_latent = self.generator(xm, xp_0, xm_0)

indices = torch.randint(0, num_frames, (batch_size, 2))
state1 = latent_m[torch.arange(batch_size), indices[:, 0]]
state2 = latent_m[torch.arange(batch_size), indices[:, 1]]
mixed_state1 = state1[shuffled_indices(batch_size)]
# same video, different frame
same = self.discriminator(
torch.cat([state1.detach(), state2.detach()], dim=1)
)
# different video
diff = self.discriminator(
torch.cat([state1.detach(), mixed_state1.detach()], dim=1)
)
                # if same == ones, the discriminator correctly identified the same video, so the discriminator's loss is minimal
loss_d_adv_same = self.criterion_d(same, torch.ones_like(same))
                # if diff == zeros, the discriminator correctly identified different videos, so the discriminator's loss is minimal
loss_d_adv_diff = self.criterion_d(diff, torch.zeros_like(diff))
loss_d_adv = (loss_d_adv_same + loss_d_adv_diff) / 2
loss_d_adv.backward()
self.optimizer_d.step()

running_loss_d_adv += loss_d_adv.item()
running_loss_d_adv += loss_d_adv.item()

if idx % 100 == 0:
print(
@@ -167,7 +175,7 @@ def train(
running_loss_g /= len(train_loader)
running_loss_g_basic /= len(train_loader)
running_loss_g_adv /= len(train_loader)
running_loss_d_adv /= len(train_loader)
running_loss_d_adv /= len(train_loader) * train_discriminator

self.scheduler_g.step()
self.scheduler_d.step()
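For context on the change above: the discriminator update is moved into a loop so it runs train_discriminator (here 5) times per generator step, with the latent states regenerated under torch.no_grad() on each pass, and the accumulated discriminator loss is normalized by len(train_loader) * train_discriminator. Below is a minimal sketch of the pattern; the generator signature, loss names, and optimizers are simplified stand-ins rather than this repository's exact API, and the reconstruction ("basic") loss is omitted.

import torch

def adversarial_step(generator, discriminator, opt_g, opt_d, bce,
                     xm, xp_0, xm_0, d_steps=5):
    """One generator update followed by d_steps discriminator updates."""
    b = xm.size(0)

    def sample_pair(latent_m):
        # latent_m: (batch, frames, dim); pick two random frames per video
        t = torch.randint(0, latent_m.size(1), (b, 2))
        rows = torch.arange(b)
        return latent_m[rows, t[:, 0]], latent_m[rows, t[:, 1]]

    # Generator step: push same-video pairs toward the "different video"
    # label (zeros) so per-frame latents carry no video identity.
    opt_g.zero_grad()
    _, _, latent_m, _ = generator(xm, xp_0, xm_0)
    s1, s2 = sample_pair(latent_m)
    same = discriminator(torch.cat([s1, s2], dim=1))
    loss_g_adv = bce(same, torch.zeros_like(same))
    loss_g_adv.backward()
    opt_g.step()

    # Discriminator: several updates per generator step, on latents
    # regenerated without gradients so nothing flows back into G.
    for _ in range(d_steps):
        opt_d.zero_grad()
        with torch.no_grad():
            _, _, latent_m, _ = generator(xm, xp_0, xm_0)
        s1, s2 = sample_pair(latent_m)
        mixed = s1[torch.randperm(b)]  # cross-video pairing
        same = discriminator(torch.cat([s1, s2], dim=1))
        diff = discriminator(torch.cat([s1, mixed], dim=1))
        loss_d = (bce(same, torch.ones_like(same))
                  + bce(diff, torch.zeros_like(diff))) / 2
        loss_d.backward()
        opt_d.step()

Training D more often than G is a common stabilizer when the discriminator's task (here, telling same-video latent pairs from cross-video ones) lags behind the generator's.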
50 changes: 24 additions & 26 deletions hrdae/models/networks/hr_dae.py
@@ -22,8 +22,8 @@
create_aggregator3d,
)
from .motion_encoder import (
MotionEncoder1d,
MotionEncoder2d,
MotionEncoder1dOption,
MotionEncoder2dOption,
create_motion_encoder1d,
create_motion_encoder2d,
)
@@ -41,17 +41,14 @@ class HRDAE3dOption(RDAE3dOption):


def create_hrdae2d(out_channels: int, opt: HRDAE2dOption) -> nn.Module:
motion_encoder = create_motion_encoder1d(
opt.latent_dim, opt.debug_show_dim, opt.motion_encoder
)
if opt.cycle:
return CycleHRDAE2d(
opt.in_channels,
out_channels,
opt.hidden_channels,
opt.latent_dim,
opt.conv_params,
motion_encoder,
opt.motion_encoder,
opt.activation,
opt.aggregator,
opt.connection_aggregation,
@@ -63,7 +60,7 @@ def create_hrdae2d(out_channels: int, opt: HRDAE2dOption) -> nn.Module:
opt.hidden_channels,
opt.latent_dim,
opt.conv_params,
motion_encoder,
opt.motion_encoder,
opt.activation,
opt.aggregator,
opt.connection_aggregation,
Expand All @@ -72,17 +69,14 @@ def create_hrdae2d(out_channels: int, opt: HRDAE2dOption) -> nn.Module:


def create_hrdae3d(out_channels: int, opt: HRDAE3dOption) -> nn.Module:
motion_encoder = create_motion_encoder2d(
opt.latent_dim, opt.debug_show_dim, opt.motion_encoder
)
if opt.cycle:
return CycleHRDAE3d(
opt.in_channels,
out_channels,
opt.hidden_channels,
opt.latent_dim,
opt.conv_params,
motion_encoder,
opt.motion_encoder,
opt.activation,
opt.aggregator,
opt.connection_aggregation,
@@ -94,7 +88,7 @@ def create_hrdae3d(out_channels: int, opt: HRDAE3dOption) -> nn.Module:
opt.hidden_channels,
opt.latent_dim,
opt.conv_params,
motion_encoder,
opt.motion_encoder,
opt.activation,
opt.aggregator,
opt.connection_aggregation,
@@ -248,34 +242,36 @@ def __init__(
hidden_channels: int,
latent_dim: int,
conv_params: list[dict[str, list[int]]],
motion_encoder: MotionEncoder1d,
motion_encoder: MotionEncoder1dOption,
activation: str,
aggregator: str,
connection_aggregation: str,
debug_show_dim: bool = False,
) -> None:
super().__init__()

dec_hidden_channels = hidden_channels
if aggregator == "concatenation":
dec_hidden_channels += latent_dim
self.content_encoder = HierarchicalEncoder2d(
in_channels,
hidden_channels,
latent_dim,
conv_params + [IdenticalConvBlockConvParams],
debug_show_dim,
)
self.motion_encoder = motion_encoder
self.motion_encoder = create_motion_encoder1d(motion_encoder)
dec_latent_dim = latent_dim
dec_hidden_channels = hidden_channels
if aggregator == "concatenation":
dec_latent_dim += motion_encoder.latent_dim
dec_hidden_channels += motion_encoder.latent_dim
self.decoder = HierarchicalDecoder2d(
out_channels,
dec_hidden_channels,
2 * latent_dim if aggregator == "concatenation" else latent_dim,
dec_latent_dim,
conv_params[::-1],
connection_aggregation,
debug_show_dim,
)
self.aggregator = create_aggregator2d(aggregator, latent_dim, latent_dim)
self.aggregator = create_aggregator2d(aggregator, latent_dim, motion_encoder.latent_dim)
# motion guided connection
# (Mutual Suppression Network for Video Prediction using Disentangled Features)
self.mgc = nn.ModuleList()
@@ -349,34 +345,36 @@ def __init__(
hidden_channels: int,
latent_dim: int,
conv_params: list[dict[str, list[int]]],
motion_encoder: MotionEncoder2d,
motion_encoder: MotionEncoder2dOption,
activation: str,
aggregator: str,
connection_aggregation: str,
debug_show_dim: bool = False,
) -> None:
super().__init__()

dec_hidden_channels = hidden_channels
if aggregator == "concatenation":
dec_hidden_channels += latent_dim
self.content_encoder = HierarchicalEncoder3d(
in_channels,
hidden_channels,
latent_dim,
conv_params + [IdenticalConvBlockConvParams],
debug_show_dim,
)
self.motion_encoder = motion_encoder
self.motion_encoder = create_motion_encoder2d(motion_encoder)
dec_latent_dim = latent_dim
dec_hidden_channels = hidden_channels
if aggregator == "concatenation":
dec_latent_dim += motion_encoder.latent_dim
dec_hidden_channels += motion_encoder.latent_dim
self.decoder = HierarchicalDecoder3d(
out_channels,
dec_hidden_channels,
2 * latent_dim if aggregator == "concatenation" else latent_dim,
dec_latent_dim,
conv_params[::-1],
connection_aggregation,
debug_show_dim,
)
self.aggregator = create_aggregator2d(aggregator, latent_dim, latent_dim)
self.aggregator = create_aggregator2d(aggregator, latent_dim, motion_encoder.latent_dim)
# motion guided connection
# (Mutual Suppression Network for Video Prediction using Disentangled Features)
self.mgc = nn.ModuleList()
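For context on the hunks above: the motion encoder is now constructed inside the module from its option, and because its latent width may differ from the content latent_dim, the decoder input grows by motion_encoder.latent_dim (rather than doubling) when the aggregator is "concatenation". A minimal sketch of that shape rule follows, using a hypothetical aggregator class rather than the repository's create_aggregator2d.

import torch
from torch import nn

class ConcatAggregator2d(nn.Module):
    """Concatenate content (c) and motion (m) features along channels."""

    def __init__(self, content_dim: int, motion_dim: int) -> None:
        super().__init__()
        self.out_dim = content_dim + motion_dim  # decoder must expect this

    def forward(self, c: torch.Tensor, m: torch.Tensor) -> torch.Tensor:
        # c: (b, content_dim, h, w); m: (b, motion_dim, h, w)
        return torch.cat([c, m], dim=1)

# With content latent_dim=64 and motion latent_dim=32, the decoder input
# width is 96, not 128: hence dec_latent_dim += motion_encoder.latent_dim.
agg = ConcatAggregator2d(64, 32)
z = agg(torch.randn(2, 64, 8, 8), torch.randn(2, 32, 8, 8))
assert z.shape == (2, 96, 8, 8)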
37 changes: 18 additions & 19 deletions hrdae/models/networks/motion_encoder.py
@@ -19,6 +19,7 @@ class MotionEncoder1dOption:
# num slices * (1 (agg=diff|none) | 2 (phase=0, t) | 3 (phase=all))
in_channels: int
hidden_channels: int
latent_dim: int
conv_params: list[dict[str, list[int]]] = field(
default_factory=lambda: [{"kernel_size": [3], "stride": [2], "padding": [1]}]
* 3,
@@ -27,12 +28,14 @@
default_factory=lambda: [{"kernel_size": [3], "stride": [2], "padding": [1]}]
* 3,
)
debug_show_dim: bool = False


@dataclass
class MotionEncoder2dOption:
in_channels: int
hidden_channels: int
latent_dim: int
conv_params: list[dict[str, list[int]]] = field(
default_factory=lambda: [{"kernel_size": [3], "stride": [2], "padding": [1]}]
* 3,
@@ -41,23 +44,22 @@
default_factory=lambda: [{"kernel_size": [3], "stride": [2], "padding": [1]}]
* 3,
)
debug_show_dim: bool = False


def create_motion_encoder1d(
latent_dim: int, debug_show_dim: bool, opt: MotionEncoder1dOption
) -> "MotionEncoder1d":
def create_motion_encoder1d(opt: MotionEncoder1dOption) -> "MotionEncoder1d":
if (
isinstance(opt, MotionRNNEncoder1dOption)
and type(opt) is MotionRNNEncoder1dOption
):
return MotionRNNEncoder1d(
opt.in_channels,
opt.hidden_channels,
latent_dim,
opt.latent_dim,
opt.conv_params,
opt.deconv_params,
create_rnn1d(opt.hidden_channels, opt.rnn),
debug_show_dim,
opt.debug_show_dim,
)
if (
isinstance(opt, MotionNormalEncoder1dOption)
@@ -66,10 +68,10 @@ def create_motion_encoder1d(
return MotionNormalEncoder1d(
opt.in_channels,
opt.hidden_channels,
latent_dim,
opt.latent_dim,
opt.conv_params,
opt.deconv_params,
debug_show_dim,
opt.debug_show_dim,
)
if (
isinstance(opt, MotionConv2dEncoder1dOption)
Expand All @@ -78,29 +80,27 @@ def create_motion_encoder1d(
return MotionConv2dEncoder1d(
opt.in_channels,
opt.hidden_channels,
latent_dim,
opt.latent_dim,
opt.conv_params,
opt.deconv_params,
debug_show_dim,
opt.debug_show_dim,
)
raise NotImplementedError(f"{opt.__class__.__name__} not implemented")


def create_motion_encoder2d(
latent_dim: int, debug_show_dim: bool, opt: MotionEncoder2dOption
) -> "MotionEncoder2d":
def create_motion_encoder2d(opt: MotionEncoder2dOption) -> "MotionEncoder2d":
if (
isinstance(opt, MotionRNNEncoder2dOption)
and type(opt) is MotionRNNEncoder2dOption
):
return MotionRNNEncoder2d(
opt.in_channels,
opt.hidden_channels,
latent_dim,
opt.latent_dim,
opt.conv_params,
opt.deconv_params,
create_rnn2d(opt.hidden_channels, opt.rnn),
debug_show_dim,
opt.debug_show_dim,
)
if (
isinstance(opt, MotionNormalEncoder2dOption)
Expand All @@ -109,10 +109,10 @@ def create_motion_encoder2d(
return MotionNormalEncoder2d(
opt.in_channels,
opt.hidden_channels,
latent_dim,
opt.latent_dim,
opt.conv_params,
opt.deconv_params,
debug_show_dim,
opt.debug_show_dim,
)
if (
isinstance(opt, MotionConv3dEncoder2dOption)
Expand All @@ -121,10 +121,10 @@ def create_motion_encoder2d(
return MotionConv3dEncoder2d(
opt.in_channels,
opt.hidden_channels,
latent_dim,
opt.latent_dim,
opt.conv_params,
opt.deconv_params,
debug_show_dim,
opt.debug_show_dim,
)
raise NotImplementedError(f"{opt.__class__.__name__} not implemented")

@@ -382,7 +382,6 @@ def forward(
x = x.reshape(b * t, *x.size()[2:])
y = self.cnn(x)
y = y.reshape(b, t, *y.size()[1:])
print(y.shape)
y, _ = self.rnn(y)
y = y.reshape(b * t, *y.size()[2:])
y = y.unsqueeze(-1)
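The net effect of this file's changes: latent_dim and debug_show_dim become fields on the option dataclasses, so the factories take only the option object. A hedged usage sketch follows; the field values are illustrative, and MotionNormalEncoder1dOption is assumed to be a dataclass subclass of MotionEncoder1dOption exposing these fields, consistent with the dispatch above.

from hrdae.models.networks.motion_encoder import (
    MotionNormalEncoder1dOption,
    create_motion_encoder1d,
)

# Before this commit: create_motion_encoder1d(latent_dim, debug_show_dim, opt)
# After: the option object carries both.
opt = MotionNormalEncoder1dOption(
    in_channels=2,
    hidden_channels=64,
    latent_dim=32,        # new dataclass field
    debug_show_dim=True,  # new dataclass field
)
encoder = create_motion_encoder1d(opt)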
8 changes: 2 additions & 6 deletions hrdae/models/networks/r_ae.py
@@ -47,9 +47,7 @@ class RAE3dOption(NetworkOption):


def create_rae2d(out_channels: int, opt: RAE2dOption) -> nn.Module:
motion_encoder = create_motion_encoder1d(
opt.latent_dim, opt.debug_show_dim, opt.motion_encoder
)
motion_encoder = create_motion_encoder1d(opt.motion_encoder)
return RAE2d(
out_channels,
opt.hidden_channels,
@@ -63,9 +61,7 @@ def create_rae3d(out_channels: int, opt: RAE3dOption) -> nn.Module:


def create_rae3d(out_channels: int, opt: RAE3dOption) -> nn.Module:
motion_encoder = create_motion_encoder2d(
opt.latent_dim, opt.debug_show_dim, opt.motion_encoder
)
motion_encoder = create_motion_encoder2d(opt.motion_encoder)
return RAE3d(
out_channels,
opt.hidden_channels,
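The RAE factories pick up the same simplification, so only the call site changes. Schematically, with the lines taken from this diff:

# before
motion_encoder = create_motion_encoder1d(
    opt.latent_dim, opt.debug_show_dim, opt.motion_encoder
)
# after
motion_encoder = create_motion_encoder1d(opt.motion_encoder)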