Skip to content

Commit

Permalink
fix bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
nnaakkaaii committed Jul 15, 2024
1 parent 95d06b4 commit 7896325
Show file tree
Hide file tree
Showing 7 changed files with 7 additions and 95 deletions.
3 changes: 1 addition & 2 deletions hrdae/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
ToTensorOption,
UniformShape3dOption,
)
from .models import BasicModelOption, GANModelOption, PVRModelOption, VRModelOption
from .models import BasicModelOption, GANModelOption, VRModelOption
from .models.losses import (
BCEWithLogitsLossOption,
ContrastiveLossOption,
Expand Down Expand Up @@ -139,7 +139,6 @@
)
cs.store(group="config/experiment/model", name="basic", node=BasicModelOption)
cs.store(group="config/experiment/model", name="vr", node=VRModelOption)
cs.store(group="config/experiment/model", name="pvr", node=PVRModelOption)
cs.store(group="config/experiment/model", name="gan", node=GANModelOption)
cs.store(group="config/experiment/model/loss", name="mse", node=MSELossOption)
cs.store(
Expand Down
3 changes: 0 additions & 3 deletions hrdae/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from .basic_model import BasicModelOption, create_basic_model
from .gan_model import GANModelOption, create_gan_model
from .option import ModelOption
from .pvr_model import PVRModelOption, create_pvr_model
from .vr_model import VRModelOption, create_vr_model


Expand All @@ -14,8 +13,6 @@ def create_model(
return create_basic_model(opt, n_epoch, steps_per_epoch)
if isinstance(opt, VRModelOption) and type(opt) is VRModelOption:
return create_vr_model(opt, n_epoch, steps_per_epoch)
if isinstance(opt, PVRModelOption) and type(opt) is PVRModelOption:
return create_pvr_model(opt, n_epoch, steps_per_epoch)
if isinstance(opt, GANModelOption) and type(opt) is GANModelOption:
return create_gan_model(opt, n_epoch, steps_per_epoch)
raise NotImplementedError(f"{opt.__class__.__name__} not implemented")
4 changes: 2 additions & 2 deletions hrdae/models/gan_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def train(
# diff == zerosなら、異なるビデオと見破られたことになるため、state encoderのロスは最大となる
loss_g_adv = self.criterion_g(
same, torch.zeros_like(same)
) # + self.criterion_g(diff, torch.ones_like(diff))
) # + self.criterion_g(diff, torch.ones_like(diff))

loss_g = loss_g_basic + adv_ratio * loss_g_adv
loss_g.backward()
Expand Down Expand Up @@ -212,7 +212,7 @@ def train(
)
loss_g_adv = self.criterion_g(
same, torch.zeros_like(same)
) # + self.criterion_g(diff, torch.ones_like(diff))
) # + self.criterion_g(diff, torch.ones_like(diff))

loss_g = loss_g_basic + adv_ratio * loss_g_adv
loss_d_adv_same = self.criterion_d(same, torch.ones_like(same))
Expand Down
3 changes: 2 additions & 1 deletion hrdae/models/networks/discriminator.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,8 @@ def __init__(
nn.BatchNorm1d(hidden_channels),
nn.ReLU(),
nn.Dropout1d(dropout_rate),
] * fc_layer,
]
* fc_layer,
nn.Linear(hidden_channels, out_channels),
)

Expand Down
87 changes: 0 additions & 87 deletions hrdae/models/pvr_model.py

This file was deleted.

1 change: 1 addition & 0 deletions test/models/test_gan_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,7 @@ def test_basic_model():

model = GANModel(
generator,
"",
discriminator,
optimizer_g,
optimizer_d,
Expand Down
1 change: 1 addition & 0 deletions test/models/test_vr_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ def test_vr_model():

model = VRModel(
network,
"",
optimizer,
scheduler,
criterion,
Expand Down

0 comments on commit 7896325

Please sign in to comment.