Skip to content

Commit

Permalink
fix bugs in gan model
Browse files Browse the repository at this point in the history
  • Loading branch information
nnaakkaaii committed Jul 14, 2024
1 parent c361269 commit 6fb53b5
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 34 deletions.
10 changes: 0 additions & 10 deletions hrdae/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,16 +257,6 @@
name="discriminator3d",
node=Discriminator3dOption,
)
cs.store(
group="config/experiment/model/generator",
name="autoencoder2d",
node=AutoEncoder2dNetworkOption,
)
cs.store(
group="config/experiment/model/generator",
name="autoencoder3d",
node=AutoEncoder3dNetworkOption,
)
cs.store(
group="config/experiment/model/optimizer", name="adam", node=AdamOptimizerOption
)
Expand Down
10 changes: 3 additions & 7 deletions hrdae/models/gan_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

@dataclass
class GANModelOption(ModelOption):
generator: NetworkOption = MISSING
network: NetworkOption = MISSING
discriminator: NetworkOption = MISSING
optimizer_g: OptimizerOption = MISSING
optimizer_d: OptimizerOption = MISSING
Expand All @@ -29,7 +29,6 @@ class GANModelOption(ModelOption):
loss_coef: dict[str, float] = MISSING
loss_g: LossOption = MISSING
loss_d: LossOption = MISSING
serialize: bool = False


class GANModel(Model):
Expand All @@ -44,7 +43,6 @@ def __init__(
criterion: nn.Module,
criterion_g: nn.Module,
criterion_d: nn.Module,
serialize: bool = False,
) -> None:
self.generator = generator
self.discriminator = discriminator
Expand All @@ -55,7 +53,6 @@ def __init__(
self.criterion = criterion
self.criterion_g = criterion_g
self.criterion_d = criterion_d
self.serialize = serialize

if torch.cuda.is_available():
print("GPU is enabled")
Expand All @@ -77,7 +74,7 @@ def train(
max_iter = None
if debug:
max_iter = 5
adv_ratio = 0.01
adv_ratio = 0.1

least_val_loss_g = float("inf")
training_history: dict[str, list[dict[str, int | float]]] = {"history": []}
Expand Down Expand Up @@ -326,7 +323,7 @@ def create_gan_model(
n_epoch: int,
steps_per_epoch: int,
) -> Model:
generator = create_network(1, opt.generator)
generator = create_network(1, opt.network)
discriminator = create_network(2, opt.discriminator)
optimizer_g = create_optimizer(
opt.optimizer_g,
Expand Down Expand Up @@ -364,5 +361,4 @@ def create_gan_model(
criterion,
criterion_g,
criterion_d,
opt.serialize,
)
48 changes: 32 additions & 16 deletions notebook/mmnist.ipynb

Large diffs are not rendered by default.

0 comments on commit 6fb53b5

Please sign in to comment.