Skip to content

Commit

Permalink
update gan model
Browse files Browse the repository at this point in the history
  • Loading branch information
nnaakkaaii committed Jul 14, 2024
1 parent 6553a71 commit c361269
Show file tree
Hide file tree
Showing 20 changed files with 197 additions and 204 deletions.
4 changes: 3 additions & 1 deletion hrdae/models/basic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,4 +227,6 @@ def create_basic_model(
criterion = LossMixer(
{k: create_loss(v) for k, v in opt.loss.items()}, opt.loss_coef
)
return BasicModel(network, opt.network_weight, optimizer, scheduler, criterion, opt.serialize)
return BasicModel(
network, opt.network_weight, optimizer, scheduler, criterion, opt.serialize
)
99 changes: 53 additions & 46 deletions hrdae/models/gan_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from torch.optim.optimizer import Optimizer
from torch.utils.data import DataLoader

from .functions import save_model, save_reconstructed_images
from .functions import save_model, save_reconstructed_images, shuffled_indices
from .losses import LossMixer, LossOption, create_loss
from .networks import NetworkOption, create_network
from .optimizers import OptimizerOption, create_optimizer
Expand Down Expand Up @@ -77,6 +77,7 @@ def train(
max_iter = None
if debug:
max_iter = 5
adv_ratio = 0.01

least_val_loss_g = float("inf")
training_history: dict[str, list[dict[str, int | float]]] = {"history": []}
Expand All @@ -89,8 +90,6 @@ def train(
running_loss_g_basic = 0.0
running_loss_g_adv = 0.0
running_loss_d_adv = 0.0
running_loss_d_adv_fake = 0.0
running_loss_d_adv_real = 0.0

for idx, data in enumerate(train_loader):
if max_iter is not None and idx >= max_iter:
Expand All @@ -100,21 +99,33 @@ def train(
xm_0 = data["xm_0"].to(self.device)
xp = data["xp"].to(self.device)
xp_0 = data["xp_0"].to(self.device)
batch_size, num_frames = xm.size()[:2]

# train generator
self.optimizer_g.zero_grad()
y, latent, cycled_latent = self.generator(xm, xp_0, xm_0)
y, latent_c, latent_m, cycled_latent = self.generator(xm, xp_0, xm_0)

indices = torch.randint(0, num_frames, (batch_size, 2))
state1 = latent_m[torch.arange(batch_size), indices[:, 0]]
state2 = latent_m[torch.arange(batch_size), indices[:, 1]]
mixed_state1 = state1[shuffled_indices(batch_size)]

same = self.discriminator(torch.cat([state1, state2], dim=1))
diff = self.discriminator(torch.cat([state1, mixed_state1], dim=1))

y_pred = self.discriminator(y, xp)
loss_g_basic = self.criterion(
y,
xp,
latent=latent,
latent=latent_c,
cycled_latent=cycled_latent,
)
loss_g_adv = self.criterion_g(y_pred, torch.ones_like(y_pred))
# same == onesなら、同じビデオと見破られたことになるため、state encoderのロスは最大となる
# diff == zerosなら、異なるビデオと見破られたことになるため、state encoderのロスは最大となる
loss_g_adv = self.criterion_g(
same, torch.zeros_like(same)
) + self.criterion_g(diff, torch.ones_like(diff))

loss_g = loss_g_basic + 0.001 * loss_g_adv
loss_g = loss_g_basic + adv_ratio * loss_g_adv
loss_g.backward()
self.optimizer_g.step()

Expand All @@ -123,25 +134,27 @@ def train(
running_loss_g += loss_g.item()

self.optimizer_d.zero_grad()
xp_pred = self.discriminator(xp, xp)
y_pred = self.discriminator(y.detach(), xp)
loss_d_adv_real = self.criterion_d(xp_pred, torch.ones_like(xp_pred))
loss_d_adv_fake = self.criterion_d(y_pred, torch.zeros_like(y_pred))
loss_d_adv = loss_d_adv_real + loss_d_adv_fake
# same == onesなら、同じビデオと見破ったことになるため、discriminatorのロスは最小となる
# diff == zerosなら、異なるビデオと見破ったことになるため、discriminatorのロスは最小となる
same = self.discriminator(
torch.cat([state1.detach(), state2.detach()], dim=1)
)
diff = self.discriminator(
torch.cat([state1.detach(), mixed_state1.detach()], dim=1)
)
loss_d_adv = self.criterion_d(
same, torch.ones_like(same)
) + self.criterion_d(diff, torch.zeros_like(diff))
loss_d_adv.backward()
self.optimizer_d.step()

running_loss_d_adv_real += loss_d_adv_real.item()
running_loss_d_adv_fake += loss_d_adv_fake.item()
running_loss_d_adv += loss_d_adv.item()

if idx % 100 == 0:
print(
f"Epoch: {epoch+1}, "
f"Batch: {idx}, "
f"Loss D Adv: {loss_d_adv.item():.6f}, "
f"Loss D Adv (Fake): {loss_d_adv_fake.item():.6f}, "
f"Loss D Adv (Real): {loss_d_adv_real.item():.6f}, "
f"Loss G: {loss_g.item():.6f}, "
f"Loss G Adv: {loss_g_adv.item():.6f}, "
f"Loss G Basic: {loss_g_basic.item():.6f}, "
Expand All @@ -151,8 +164,6 @@ def train(
running_loss_g_basic /= len(train_loader)
running_loss_g_adv /= len(train_loader)
running_loss_d_adv /= len(train_loader)
running_loss_d_adv_real /= len(train_loader)
running_loss_d_adv_fake /= len(train_loader)

self.scheduler_g.step()
self.scheduler_d.step()
Expand All @@ -164,8 +175,6 @@ def train(
total_val_loss_g_basic = 0.0
total_val_loss_g_adv = 0.0
total_val_loss_d_adv = 0.0
total_val_loss_d_adv_fake = 0.0
total_val_loss_d_adv_real = 0.0
xp = torch.tensor([0.0], device=self.device)
y = torch.tensor([0.0], device=self.device)

Expand All @@ -177,52 +186,54 @@ def train(
xm_0 = data["xm_0"].to(self.device)
xp = data["xp"].to(self.device)
xp_0 = data["xp_0"].to(self.device)
y, cs, ds = self.generator(xm, xp_0, xm_0)
batch_size, num_frames = xm.size()[:2]
y, latent_c, latent_m, cycled_latent = self.generator(
xm, xp_0, xm_0
)

indices = torch.randint(0, num_frames, (batch_size, 2))
state1 = latent_m[torch.arange(batch_size), indices[:, 0]]
state2 = latent_m[torch.arange(batch_size), indices[:, 1]]
mixed_state1 = state1[shuffled_indices(batch_size)]

same = self.discriminator(torch.cat([state1, state2], dim=1))
diff = self.discriminator(torch.cat([state1, mixed_state1], dim=1))

y_pred = self.discriminator(y, xp)
xp_pred = self.discriminator(xp, xp)
y = y.detach().clone()
loss_g_basic = self.criterion(
y,
xp,
latent=cs,
cycled_latent=ds,
latent=latent_c,
cycled_latent=cycled_latent,
)
loss_g_adv = self.criterion_g(y_pred, torch.ones_like(y_pred))
loss_g = loss_g_basic + loss_g_adv
loss_d_adv_real = self.criterion_d(
xp_pred, torch.ones_like(xp_pred)
)
loss_d_adv_fake = self.criterion_d(y_pred, torch.zeros_like(y_pred))
loss_d_adv = loss_d_adv_fake + loss_d_adv_real
loss_g_adv = self.criterion_g(
same, torch.zeros_like(same)
) + self.criterion_g(diff, torch.ones_like(diff))

loss_g = loss_g_basic + adv_ratio * loss_g_adv
loss_d_adv = self.criterion_d(
same, torch.ones_like(same)
) + self.criterion_d(diff, torch.zeros_like(diff))

total_val_loss_g += loss_g.item()
total_val_loss_g_basic += loss_g_basic.item()
total_val_loss_g_adv += loss_g_adv.item()
total_val_loss_d_adv += loss_d_adv.item()
total_val_loss_d_adv_fake += loss_d_adv_fake.item()
total_val_loss_d_adv_real += loss_d_adv_real.item()

total_val_loss_g /= len(val_loader)
total_val_loss_g_basic /= len(val_loader)
total_val_loss_g_adv /= len(val_loader)
total_val_loss_d_adv /= len(val_loader)
total_val_loss_d_adv_fake /= len(val_loader)
total_val_loss_d_adv_real /= len(val_loader)

print(
f"Epoch: {epoch+1} "
f"[train] "
f"Loss D Adv: {running_loss_d_adv:.6f}, "
f"Loss D Adv (Fake): {running_loss_d_adv_fake:.6f}, "
f"Loss D Adv (Real): {running_loss_d_adv_real:.6f}, "
f"Loss G: {running_loss_g:.6f}, "
f"Loss G Adv: {running_loss_g_adv:.6f}, "
f"Loss G Basic: {running_loss_g_basic:.6f}, "
f"[val] "
f"Loss D Adv: {total_val_loss_d_adv:.6f}, "
f"Loss D Adv (Fake): {total_val_loss_d_adv_fake:.6f}, "
f"Loss D Adv (Real): {total_val_loss_d_adv_real:.6f}, "
f"Loss G: {total_val_loss_g:.6f}, "
f"Loss G Adv: {total_val_loss_g_adv:.6f}, "
f"Loss G Basic: {total_val_loss_g_basic:.6f}, "
Expand Down Expand Up @@ -257,14 +268,10 @@ def train(
"train_loss_g_basic": float(running_loss_g_basic),
"train_loss_g_adv": float(running_loss_g_adv),
"train_loss_d_adv": float(running_loss_d_adv),
"train_loss_d_adv_fake": float(running_loss_d_adv_fake),
"train_loss_d_adv_real": float(running_loss_d_adv_real),
"val_loss_g": float(total_val_loss_g),
"val_loss_g_basic": float(total_val_loss_g_basic),
"val_loss_g_adv": float(total_val_loss_g_adv),
"val_loss_d_adv": float(total_val_loss_d_adv),
"val_loss_d_adv_fake": float(total_val_loss_d_adv_fake),
"val_loss_d_adv_real": float(total_val_loss_d_adv_real),
}
)

Expand All @@ -276,7 +283,7 @@ def train(
xp = data["xp"].to(self.device)
xp_0 = data["xp_0"].to(self.device)

y, _, _ = self.generator(xm, xp_0, xm_0)
y, _, _, _ = self.generator(xm, xp_0, xm_0)

save_reconstructed_images(
xp.data.cpu().clone().detach().numpy()[:10],
Expand Down
4 changes: 2 additions & 2 deletions hrdae/models/networks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,9 @@

def create_network(out_channels: int, opt: NetworkOption) -> nn.Module:
if isinstance(opt, Discriminator2dOption) and type(opt) is Discriminator2dOption:
return create_discriminator2d(out_channels, opt)
return create_discriminator2d(opt)
if isinstance(opt, Discriminator3dOption) and type(opt) is Discriminator3dOption:
return create_discriminator3d(out_channels, opt)
return create_discriminator3d(opt)
if (
isinstance(opt, AEEncoder2dNetworkOption)
and type(opt) is AEEncoder2dNetworkOption
Expand Down
44 changes: 13 additions & 31 deletions hrdae/models/networks/discriminator.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
from dataclasses import dataclass, field

from torch import Tensor, arange, cat, nn, randint
from torch import Tensor, nn

from .modules import ConvModule2d, ConvModule3d
from .option import NetworkOption


@dataclass
class Discriminator2dOption(NetworkOption):
hidden_channels: int = 64
in_channels: int = 8
hidden_channels: int = 256
image_size: list[int] = field(default_factory=lambda: [4, 4])
conv_params: list[dict[str, list[int]]] = field(
default_factory=lambda: [
Expand All @@ -18,9 +19,9 @@ class Discriminator2dOption(NetworkOption):
debug_show_dim: bool = False


def create_discriminator2d(out_channels: int, opt: Discriminator2dOption) -> nn.Module:
def create_discriminator2d(opt: Discriminator2dOption) -> nn.Module:
return Discriminator2d(
in_channels=out_channels,
in_channels=opt.in_channels,
out_channels=1,
hidden_channels=opt.hidden_channels,
image_size=opt.image_size,
Expand Down Expand Up @@ -56,25 +57,16 @@ def __init__(
nn.Linear(hidden_channels, out_channels),
)

def forward(self, x: Tensor) -> Tensor:
    """Score a pre-assembled discriminator input.

    The caller is responsible for pairing/concatenating states along the
    channel dimension before calling (e.g. ``torch.cat([s1, s2], dim=1)``
    as done in the GAN training loop) — this method no longer samples
    frames itself.

    Args:
        x: input tensor whose first dimension is the batch.

    Returns:
        Logits from the bottleneck head, one row per batch element.
    """
    h = self.cnn(x)
    # Flatten every per-sample feature map; h.size(0) is the batch size.
    # (An earlier revision used an undefined local `b` here.)
    z = self.bottleneck(h.reshape(h.size(0), -1))
    return z


@dataclass
class Discriminator3dOption(NetworkOption):
hidden_channels: int = 64
in_channels: int = 8
hidden_channels: int = 256
image_size: list[int] = field(default_factory=lambda: [4, 4, 4])
conv_params: list[dict[str, list[int]]] = field(
default_factory=lambda: [
Expand All @@ -84,9 +76,9 @@ class Discriminator3dOption(NetworkOption):
debug_show_dim: bool = False


def create_discriminator3d(out_channels: int, opt: Discriminator3dOption) -> nn.Module:
def create_discriminator3d(opt: Discriminator3dOption) -> nn.Module:
return Discriminator3d(
in_channels=out_channels,
in_channels=opt.in_channels,
out_channels=1,
hidden_channels=opt.hidden_channels,
image_size=opt.image_size,
Expand Down Expand Up @@ -122,17 +114,7 @@ def __init__(
nn.Linear(hidden_channels, out_channels),
)

def forward(self, x: Tensor) -> Tensor:
    """Score a pre-assembled discriminator input (3D variant).

    The caller concatenates the states along the channel dimension before
    calling; this method no longer samples frames internally.

    Args:
        x: input tensor whose first dimension is the batch.

    Returns:
        Logits from the bottleneck head, one row per batch element.
    """
    h = self.cnn(x)
    # Flatten per sample using the batch dimension h.size(0);
    # the removed revision referenced an undefined local `b` here.
    z = self.bottleneck(h.reshape(h.size(0), -1))
    return z
Loading

0 comments on commit c361269

Please sign in to comment.