Skip to content

Commit

Permalink
update vr model to use pretrained
Browse files Browse the repository at this point in the history
  • Loading branch information
nnaakkaaii committed Jul 15, 2024
1 parent 11ddc82 commit 95d06b4
Show file tree
Hide file tree
Showing 4 changed files with 81 additions and 51 deletions.
34 changes: 21 additions & 13 deletions hrdae/models/gan_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
@dataclass
class GANModelOption(ModelOption):
network: NetworkOption = MISSING
network_weight: str = ""
discriminator: NetworkOption = MISSING
optimizer_g: OptimizerOption = MISSING
optimizer_d: OptimizerOption = MISSING
Expand All @@ -35,6 +36,7 @@ class GANModel(Model):
def __init__(
self,
generator: nn.Module,
generator_weight: str,
discriminator: nn.Module,
optimizer_g: Optimizer,
optimizer_d: Optimizer,
Expand All @@ -54,6 +56,9 @@ def __init__(
self.criterion_g = criterion_g
self.criterion_d = criterion_d

if generator_weight != "":
self.generator.load_state_dict(torch.load(generator_weight))

if torch.cuda.is_available():
print("GPU is enabled")
self.device = torch.device("cuda:0")
Expand Down Expand Up @@ -108,7 +113,7 @@ def train(
mixed_state1 = state1[shuffled_indices(batch_size)]

same = self.discriminator(torch.cat([state1, state2], dim=1))
diff = self.discriminator(torch.cat([state1, mixed_state1], dim=1))
# diff = self.discriminator(torch.cat([state1, mixed_state1], dim=1))

loss_g_basic = self.criterion(
y,
Expand All @@ -120,7 +125,7 @@ def train(
# If diff == zeros, the discriminator has seen through the pair as coming from different videos, so the state encoder's loss is at its maximum
loss_g_adv = self.criterion_g(
same, torch.zeros_like(same)
) + self.criterion_g(diff, torch.ones_like(diff))
) # + self.criterion_g(diff, torch.ones_like(diff))

loss_g = loss_g_basic + adv_ratio * loss_g_adv
loss_g.backward()
Expand All @@ -139,9 +144,9 @@ def train(
diff = self.discriminator(
torch.cat([state1.detach(), mixed_state1.detach()], dim=1)
)
loss_d_adv = self.criterion_d(
same, torch.ones_like(same)
) + self.criterion_d(diff, torch.zeros_like(diff))
loss_d_adv_same = self.criterion_d(same, torch.ones_like(same))
loss_d_adv_diff = self.criterion_d(diff, torch.zeros_like(diff))
loss_d_adv = (loss_d_adv_same + loss_d_adv_diff) / 2
loss_d_adv.backward()
self.optimizer_d.step()

Expand All @@ -152,6 +157,8 @@ def train(
f"Epoch: {epoch+1}, "
f"Batch: {idx}, "
f"Loss D Adv: {loss_d_adv.item():.6f}, "
f"Loss D Adv (same): {loss_d_adv_same.item():.6f}, "
f"Loss D Adv (diff): {loss_d_adv_diff.item():.6f}, "
f"Loss G: {loss_g.item():.6f}, "
f"Loss G Adv: {loss_g_adv.item():.6f}, "
f"Loss G Basic: {loss_g_basic.item():.6f}, "
Expand Down Expand Up @@ -194,7 +201,7 @@ def train(
mixed_state1 = state1[shuffled_indices(batch_size)]

same = self.discriminator(torch.cat([state1, state2], dim=1))
diff = self.discriminator(torch.cat([state1, mixed_state1], dim=1))
# diff = self.discriminator(torch.cat([state1, mixed_state1], dim=1))

y = y.detach().clone()
loss_g_basic = self.criterion(
Expand All @@ -205,12 +212,12 @@ def train(
)
loss_g_adv = self.criterion_g(
same, torch.zeros_like(same)
) + self.criterion_g(diff, torch.ones_like(diff))
) # + self.criterion_g(diff, torch.ones_like(diff))

loss_g = loss_g_basic + adv_ratio * loss_g_adv
loss_d_adv = self.criterion_d(
same, torch.ones_like(same)
) + self.criterion_d(diff, torch.zeros_like(diff))
loss_d_adv_same = self.criterion_d(same, torch.ones_like(same))
loss_d_adv_diff = self.criterion_d(diff, torch.zeros_like(diff))
loss_d_adv = (loss_d_adv_same + loss_d_adv_diff) / 2

total_val_loss_g += loss_g.item()
total_val_loss_g_basic += loss_g_basic.item()
Expand Down Expand Up @@ -272,6 +279,9 @@ def train(
}
)

with open(result_dir / "training_history.json", "w") as f:
json.dump(training_history, f, indent=2)

if epoch % 10 == 0:
data = next(iter(val_loader))

Expand All @@ -295,9 +305,6 @@ def train(
f"epoch_{epoch}",
)

with open(result_dir / "training_history.json", "w") as f:
json.dump(training_history, f)

return least_val_loss_g


Expand Down Expand Up @@ -353,6 +360,7 @@ def create_gan_model(
criterion_d = create_loss(opt.loss_d)
return GANModel(
generator,
opt.network_weight,
discriminator,
optimizer_g,
optimizer_d,
Expand Down
51 changes: 27 additions & 24 deletions hrdae/models/networks/discriminator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from torch import Tensor, nn

from .modules import ConvModule2d, ConvModule3d
from .modules import ConvModule3d
from .option import NetworkOption


Expand All @@ -11,12 +11,8 @@ class Discriminator2dOption(NetworkOption):
in_channels: int = 8
hidden_channels: int = 256
image_size: list[int] = field(default_factory=lambda: [4, 4])
conv_params: list[dict[str, list[int]]] = field(
default_factory=lambda: [
{"kernel_size": [3], "stride": [2], "padding": [1], "output_padding": [1]},
]
)
debug_show_dim: bool = False
dropout_rate: float = 0.5
fc_layer: int = 3


def create_discriminator2d(opt: Discriminator2dOption) -> nn.Module:
Expand All @@ -25,8 +21,8 @@ def create_discriminator2d(opt: Discriminator2dOption) -> nn.Module:
out_channels=1,
hidden_channels=opt.hidden_channels,
image_size=opt.image_size,
conv_params=opt.conv_params,
debug_show_dim=opt.debug_show_dim,
dropout_rate=opt.dropout_rate,
fc_layer=opt.fc_layer,
)


Expand All @@ -37,29 +33,27 @@ def __init__(
out_channels: int,
hidden_channels: int,
image_size: list[int],
conv_params: list[dict[str, list[int]]],
debug_show_dim: bool,
fc_layer: int,
dropout_rate: float,
) -> None:
super().__init__()
self.cnn = ConvModule2d(
in_channels,
hidden_channels,
hidden_channels,
conv_params,
transpose=False,
act_norm=False,
debug_show_dim=debug_show_dim,
)
size = image_size[0] * image_size[1]
self.bottleneck = nn.Sequential(
nn.Linear(size * hidden_channels, hidden_channels),
self.fc = nn.Sequential(
nn.Linear(in_channels * size, hidden_channels),
nn.BatchNorm1d(hidden_channels),
nn.ReLU(),
nn.Dropout1d(dropout_rate),
*[
nn.Linear(hidden_channels, hidden_channels),
nn.BatchNorm1d(hidden_channels),
nn.ReLU(),
nn.Dropout1d(dropout_rate),
] * fc_layer,
nn.Linear(hidden_channels, out_channels),
)

def forward(self, x: Tensor) -> Tensor:
h = self.cnn(x)
z = self.bottleneck(h.reshape(h.size(0), -1))
z = self.fc(x.reshape(x.size(0), -1))
return z


Expand All @@ -73,6 +67,7 @@ class Discriminator3dOption(NetworkOption):
{"kernel_size": [3], "stride": [2], "padding": [1], "output_padding": [1]},
]
)
dropout_rate: float = 0.5
debug_show_dim: bool = False


Expand All @@ -83,6 +78,7 @@ def create_discriminator3d(opt: Discriminator3dOption) -> nn.Module:
hidden_channels=opt.hidden_channels,
image_size=opt.image_size,
conv_params=opt.conv_params,
dropout_rate=opt.dropout_rate,
debug_show_dim=opt.debug_show_dim,
)

Expand All @@ -95,6 +91,7 @@ def __init__(
hidden_channels: int,
image_size: list[int],
conv_params: list[dict[str, list[int]]],
dropout_rate: float,
debug_show_dim: bool,
) -> None:
super().__init__()
Expand All @@ -110,7 +107,13 @@ def __init__(
size = image_size[0] * image_size[1] * image_size[2]
self.bottleneck = nn.Sequential(
nn.Linear(size * hidden_channels, hidden_channels),
nn.BatchNorm1d(hidden_channels),
nn.ReLU(),
nn.Dropout1d(dropout_rate),
nn.Linear(hidden_channels, hidden_channels),
nn.BatchNorm1d(hidden_channels),
nn.ReLU(),
nn.Dropout1d(dropout_rate),
nn.Linear(hidden_channels, out_channels),
)

Expand Down
6 changes: 6 additions & 0 deletions hrdae/models/vr_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
@dataclass
class VRModelOption(ModelOption):
network: NetworkOption = MISSING
network_weight: str = ""
optimizer: OptimizerOption = MISSING
scheduler: SchedulerOption = MISSING
loss: dict[str, LossOption] = MISSING
Expand All @@ -31,6 +32,7 @@ class VRModel(Model):
def __init__(
self,
network: nn.Module,
network_weight: str,
optimizer: Optimizer,
scheduler: LRScheduler,
criterion: nn.Module,
Expand All @@ -42,6 +44,9 @@ def __init__(
self.criterion = criterion
self.use_triplet = use_triplet

if network_weight != "":
self.network.load_state_dict(torch.load(network_weight))

if torch.cuda.is_available():
print("GPU is enabled")
self.device = torch.device("cuda:0")
Expand Down Expand Up @@ -213,6 +218,7 @@ def create_vr_model(
)
return VRModel(
network,
opt.network_weight,
optimizer,
scheduler,
criterion,
Expand Down
41 changes: 27 additions & 14 deletions notebook/mmnist.ipynb

Large diffs are not rendered by default.

0 comments on commit 95d06b4

Please sign in to comment.