Skip to content

Commit

Permalink
remove in_channel parameter
Browse files Browse the repository at this point in the history
  • Loading branch information
nnaakkaaii committed Jun 15, 2024
1 parent 033c301 commit d612fe0
Show file tree
Hide file tree
Showing 30 changed files with 119 additions and 587 deletions.
24 changes: 0 additions & 24 deletions hrdae/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,6 @@
)
from .models.networks import (
AutoEncoder2dNetworkOption,
Discriminator2dOption,
Discriminator3dOption,
FiveBranchAutoencoder2dOption,
FiveBranchAutoencoder3dOption,
HRDAE2dOption,
HRDAE3dOption,
RAE2dOption,
Expand Down Expand Up @@ -129,26 +125,6 @@
name="autoencoder2d",
node=AutoEncoder2dNetworkOption,
)
cs.store(
group="config/experiment/model/network",
name="discriminator2d",
node=Discriminator2dOption,
)
cs.store(
group="config/experiment/model/network",
name="discriminator3d",
node=Discriminator3dOption,
)
cs.store(
group="config/experiment/model/network",
name="fb_autoencoder2d",
node=FiveBranchAutoencoder2dOption,
)
cs.store(
group="config/experiment/model/network",
name="fb_autoencoder3d",
node=FiveBranchAutoencoder3dOption,
)
cs.store(group="config/experiment/model/network", name="hrdae2d", node=HRDAE2dOption)
cs.store(group="config/experiment/model/network", name="hrdae3d", node=HRDAE3dOption)
cs.store(group="config/experiment/model/network", name="rae2d", node=RAE2dOption)
Expand Down
2 changes: 1 addition & 1 deletion hrdae/conf/config.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
task_name: mmnist-hrdae2d-conv2d
task_name: mmnist-hrdae2d-guided1d
defaults:
- config_schema
- experiment: train
Expand Down
1 change: 0 additions & 1 deletion hrdae/conf/experiment/model/network/autoencoder2d.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
in_channels: 1
latent_dim: 16
conv_params:
- kernel_size: [3]
Expand Down
18 changes: 0 additions & 18 deletions hrdae/conf/experiment/model/network/discriminator2d.yaml

This file was deleted.

22 changes: 0 additions & 22 deletions hrdae/conf/experiment/model/network/discriminator3d.yaml

This file was deleted.

34 changes: 0 additions & 34 deletions hrdae/conf/experiment/model/network/fb_autoencoder2d.yaml

This file was deleted.

42 changes: 0 additions & 42 deletions hrdae/conf/experiment/model/network/fb_autoencoder3d.yaml

This file was deleted.

4 changes: 1 addition & 3 deletions hrdae/conf/experiment/model/network/hrdae2d.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
in_channels: 2
out_channels: 1
latent_dim: 32
conv_params:
- kernel_size: [3]
Expand All @@ -19,4 +17,4 @@ activation: sigmoid
debug_show_dim: false
defaults:
- /config/experiment/model/network/hrdae2d@_here_
- motion_encoder: conv2d
- motion_encoder: guided1d
2 changes: 0 additions & 2 deletions hrdae/conf/experiment/model/network/hrdae3d.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
in_channels: 2
out_channels: 1
latent_dim: 64
conv_params:
- kernel_size: [3]
Expand Down
6 changes: 6 additions & 0 deletions hrdae/conf/experiment/model/network/motion_encoder/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# motion_encoder

note

* phase should be "0" or "t" in vr.yaml to use "guided" or "tsn" motion_encoder
* in_channels should be same as the number of slices in the dataset
2 changes: 0 additions & 2 deletions hrdae/conf/experiment/model/network/rae2d.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
in_channels: 2
out_channels: 1
latent_dim: 32
conv_params:
- kernel_size: [3]
Expand Down
2 changes: 0 additions & 2 deletions hrdae/conf/experiment/model/network/rae3d.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
in_channels: 2
out_channels: 1
latent_dim: 64
conv_params:
- kernel_size: [3]
Expand Down
2 changes: 0 additions & 2 deletions hrdae/conf/experiment/model/network/rdae2d.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
in_channels: 2
out_channels: 1
latent_dim: 32
conv_params:
- kernel_size: [3]
Expand Down
2 changes: 0 additions & 2 deletions hrdae/conf/experiment/model/network/rdae3d.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
in_channels: 2
out_channels: 1
latent_dim: 64
conv_params:
- kernel_size: [3]
Expand Down
2 changes: 2 additions & 0 deletions hrdae/conf/experiment/model/vr.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
loss_coef:
wmse: 0.5
pjc2d: 0.5
phase: "0"
pred_diff: false
defaults:
- /config/experiment/model/vr@_here_
- [email protected]: wmse
Expand Down
2 changes: 1 addition & 1 deletion hrdae/conf/experiment/test.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
run_name: 2024-06-16-0152
run_name: 2024-06-16-0158
result_dir: results
defaults:
- /config/experiment/test@_here_
Expand Down
2 changes: 1 addition & 1 deletion hrdae/conf/experiment/train.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
run_name: 2024-06-16-0152
run_name: 2024-06-16-0158
result_dir: results
n_epoch: 30
debug: false
Expand Down
2 changes: 1 addition & 1 deletion hrdae/models/basic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ def create_basic_model(
n_epoch: int,
steps_per_epoch: int,
) -> Model:
network = create_network(opt.network)
network = create_network(1, 1, opt.network)
optimizer = create_optimizer(
opt.optimizer,
network.parameters(),
Expand Down
44 changes: 10 additions & 34 deletions hrdae/models/networks/__init__.py
Original file line number Diff line number Diff line change
@@ -1,61 +1,37 @@
from torch import nn

from .autoencoder import AutoEncoder2dNetworkOption, create_autoencoder2d
from .discriminator import (
Discriminator2dOption,
Discriminator3dOption,
create_discriminator2d,
create_discriminator3d,
)
from .fb_autoencoder import (
FiveBranchAutoencoder2d,
FiveBranchAutoencoder2dOption,
FiveBranchAutoencoder3dOption,
create_fb_autoencoder3d,
)
from .hr_dae import HRDAE2dOption, HRDAE3dOption, create_hrdae2d, create_hrdae3d
from .option import NetworkOption
from .r_ae import RAE2dOption, RAE3dOption, create_rae2d, create_rae3d
from .r_dae import RDAE2dOption, RDAE3dOption, create_rdae2d, create_rdae3d


def create_network(opt: NetworkOption) -> nn.Module:
def create_network(
in_channels: int, out_channels: int, opt: NetworkOption
) -> nn.Module:
if (
isinstance(opt, AutoEncoder2dNetworkOption)
and type(opt) is AutoEncoder2dNetworkOption
):
return create_autoencoder2d(opt)
if isinstance(opt, Discriminator2dOption) and type(opt) is Discriminator2dOption:
return create_discriminator2d(opt)
if isinstance(opt, Discriminator3dOption) and type(opt) is Discriminator3dOption:
return create_discriminator3d(opt)
if (
isinstance(opt, FiveBranchAutoencoder3dOption)
and type(opt) is FiveBranchAutoencoder3dOption
):
return create_fb_autoencoder3d(opt)
return create_autoencoder2d(out_channels, opt)
if isinstance(opt, HRDAE2dOption) and type(opt) is HRDAE2dOption:
return create_hrdae2d(opt)
return create_hrdae2d(in_channels, out_channels, opt)
if isinstance(opt, HRDAE3dOption) and type(opt) is HRDAE3dOption:
return create_hrdae3d(opt)
return create_hrdae3d(in_channels, out_channels, opt)
if isinstance(opt, RAE2dOption) and type(opt) is RAE2dOption:
return create_rae2d(opt)
return create_rae2d(in_channels, out_channels, opt)
if isinstance(opt, RAE3dOption) and type(opt) is RAE3dOption:
return create_rae3d(opt)
return create_rae3d(in_channels, out_channels, opt)
if isinstance(opt, RDAE2dOption) and type(opt) is RDAE2dOption:
return create_rdae2d(opt)
return create_rdae2d(in_channels, out_channels, opt)
if isinstance(opt, RDAE3dOption) and type(opt) is RDAE3dOption:
return create_rdae3d(opt)
return create_rdae3d(in_channels, out_channels, opt)
raise NotImplementedError(f"network {opt.__class__.__name__} not implemented")


__all__ = [
"AutoEncoder2dNetworkOption",
"Discriminator2dOption",
"Discriminator3dOption",
"FiveBranchAutoencoder2d",
"FiveBranchAutoencoder2dOption",
"FiveBranchAutoencoder3dOption",
"HRDAE2dOption",
"HRDAE3dOption",
"RAE2dOption",
Expand Down
7 changes: 4 additions & 3 deletions hrdae/models/networks/autoencoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@

@dataclass
class AutoEncoder2dNetworkOption(NetworkOption):
in_channels: int = 1
latent_dim: int = 64
conv_params: list[dict[str, list[int]]] = field(
default_factory=lambda: [
Expand All @@ -19,9 +18,11 @@ class AutoEncoder2dNetworkOption(NetworkOption):
debug_show_dim: bool = False


def create_autoencoder2d(opt: AutoEncoder2dNetworkOption) -> nn.Module:
def create_autoencoder2d(
out_channels: int, opt: AutoEncoder2dNetworkOption
) -> nn.Module:
return AutoEncoder2d(
in_channels=opt.in_channels,
in_channels=out_channels,
latent_dim=opt.latent_dim,
conv_params=opt.conv_params,
activation=opt.activation,
Expand Down
51 changes: 0 additions & 51 deletions hrdae/models/networks/discriminator.py

This file was deleted.

Loading

0 comments on commit d612fe0

Please sign in to comment.