Skip to content

Commit

Permalink
update hierarchical decoder latent dim
Browse files Browse the repository at this point in the history
  • Loading branch information
nnaakkaaii committed Jun 27, 2024
1 parent 74859c4 commit 3ddbe37
Showing 1 changed file with 12 additions and 6 deletions.
18 changes: 12 additions & 6 deletions hrdae/models/networks/modules/conv_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,14 @@ def __init__(
super().__init__()
assert len(conv_params) > 1

d = 2 if aggregation == "concatenation" else 1

self.layers = nn.ModuleList()
for i, conv_param in enumerate(conv_params[:-1]):
self.layers.append(
nn.Sequential(
ConvBlock1d(
2 * latent_dim if i > 0 else 2 * in_channels,
d * latent_dim if i > 0 else d * in_channels,
latent_dim, # 2*latent_dim ?
kernel_size=conv_param["kernel_size"],
stride=conv_param["stride"],
Expand All @@ -46,7 +48,7 @@ def __init__(
)
self.layers.append(
ConvBlock1d(
2 * latent_dim,
d * latent_dim,
out_channels,
kernel_size=conv_params[-1]["kernel_size"],
stride=conv_params[-1]["stride"],
Expand Down Expand Up @@ -78,12 +80,14 @@ def __init__(
super().__init__()
assert len(conv_params) > 1

d = 2 if aggregation == "concatenation" else 1

self.layers = nn.ModuleList()
for i, conv_param in enumerate(conv_params[:-1]):
self.layers.append(
nn.Sequential(
ConvBlock2d(
2 * latent_dim if i > 0 else 2 * in_channels,
d * latent_dim if i > 0 else d * in_channels,
latent_dim, # 2*latent_dim ?
kernel_size=conv_param["kernel_size"],
stride=conv_param["stride"],
Expand All @@ -100,7 +104,7 @@ def __init__(
)
self.layers.append(
ConvBlock2d(
2 * latent_dim,
d * latent_dim,
out_channels,
kernel_size=conv_params[-1]["kernel_size"],
stride=conv_params[-1]["stride"],
Expand Down Expand Up @@ -132,12 +136,14 @@ def __init__(
super().__init__()
assert len(conv_params) > 1

d = 2 if aggregation == "concatenation" else 1

self.layers = nn.ModuleList()
for i, conv_param in enumerate(conv_params[:-1]):
self.layers.append(
nn.Sequential(
ConvBlock3d(
2 * latent_dim if i > 0 else 2 * in_channels,
d * latent_dim if i > 0 else d * in_channels,
latent_dim, # 2*latent_dim ?
kernel_size=conv_param["kernel_size"],
stride=conv_param["stride"],
Expand All @@ -154,7 +160,7 @@ def __init__(
)
self.layers.append(
ConvBlock3d(
2 * latent_dim,
d * latent_dim,
out_channels,
kernel_size=conv_params[-1]["kernel_size"],
stride=conv_params[-1]["stride"],
Expand Down

0 comments on commit 3ddbe37

Please sign in to comment.