Skip to content

Commit

Permalink
add connection_aggregation option
Browse files Browse the repository at this point in the history
  • Loading branch information
nnaakkaaii committed Jun 27, 2024
1 parent c34571c commit 4b1c795
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 3 deletions.
14 changes: 12 additions & 2 deletions hrdae/models/networks/hr_dae.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,12 @@

@dataclass
class HRDAE2dOption(RDAE2dOption):
pass
connection_aggregation: str = "concatenation"


@dataclass
class HRDAE3dOption(RDAE3dOption):
pass
connection_aggregation: str = "concatenation"


def create_hrdae2d(out_channels: int, opt: HRDAE2dOption) -> nn.Module:
Expand All @@ -54,6 +54,7 @@ def create_hrdae2d(out_channels: int, opt: HRDAE2dOption) -> nn.Module:
motion_encoder,
opt.activation,
opt.aggregator,
opt.connection_aggregation,
opt.debug_show_dim,
)
return HRDAE2d(
Expand All @@ -65,6 +66,7 @@ def create_hrdae2d(out_channels: int, opt: HRDAE2dOption) -> nn.Module:
motion_encoder,
opt.activation,
opt.aggregator,
opt.connection_aggregation,
opt.debug_show_dim,
)

Expand Down Expand Up @@ -167,6 +169,7 @@ def __init__(
hidden_channels: int,
latent_dim: int,
conv_params: list[dict[str, list[int]]],
aggregation: str,
debug_show_dim: bool = False,
) -> None:
super().__init__()
Expand All @@ -180,6 +183,7 @@ def __init__(
out_channels,
hidden_channels,
conv_params,
aggregation,
debug_show_dim,
)
self.debug_show_dim = debug_show_dim
Expand All @@ -203,6 +207,7 @@ def __init__(
hidden_channels: int,
latent_dim: int,
conv_params: list[dict[str, list[int]]],
aggregation: str,
debug_show_dim: bool = False,
) -> None:
super().__init__()
Expand All @@ -216,6 +221,7 @@ def __init__(
out_channels,
hidden_channels,
conv_params,
aggregation,
debug_show_dim,
)
self.debug_show_dim = debug_show_dim
Expand Down Expand Up @@ -243,6 +249,7 @@ def __init__(
motion_encoder: MotionEncoder1d,
activation: str,
aggregator: str,
connection_aggregation: str,
debug_show_dim: bool = False,
) -> None:
super().__init__()
Expand All @@ -263,6 +270,7 @@ def __init__(
dec_hidden_channels,
2 * latent_dim if aggregator == "concatenation" else latent_dim,
conv_params[::-1],
connection_aggregation,
debug_show_dim,
)
self.aggregator = create_aggregator2d(aggregator, latent_dim, latent_dim)
Expand Down Expand Up @@ -342,6 +350,7 @@ def __init__(
motion_encoder: MotionEncoder2d,
activation: str,
aggregator: str,
connection_aggregation: str,
debug_show_dim: bool = False,
) -> None:
super().__init__()
Expand All @@ -362,6 +371,7 @@ def __init__(
dec_hidden_channels,
2 * latent_dim if aggregator == "concatenation" else latent_dim,
conv_params[::-1],
connection_aggregation,
debug_show_dim,
)
self.aggregator = create_aggregator2d(aggregator, latent_dim, latent_dim)
Expand Down
11 changes: 10 additions & 1 deletion hrdae/models/networks/modules/conv_block.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,7 @@ class ConvModuleBase(nn.Module):
layers: nn.ModuleList
use_skip: bool
debug_show_dim: bool
aggregation: str

def _forward(
self, x: Tensor, hs: list[Tensor] | None = None
Expand All @@ -323,7 +324,12 @@ def _forward(
for i, layer in enumerate(self.layers):
if self.use_skip:
assert hs is not None
x = cat([x, hs[i]], dim=1)
if self.aggregation == "concatenation":
x = cat([x, hs[i]], dim=1)
elif self.aggregation == "addition":
x = x + hs[i]
else:
raise ValueError(f"Invalid aggregation: {self.aggregation}")
x = layer(x)
if self.debug_show_dim:
print(f"{self.__class__.__name__} Layer {i}", x.size())
Expand Down Expand Up @@ -381,6 +387,7 @@ def __init__(

self.debug_show_dim = debug_show_dim
self.use_skip = False
self.aggregation = "concatenation"

def forward(self, x: Tensor) -> Tensor:
return self._forward(x)[0]
Expand Down Expand Up @@ -435,6 +442,7 @@ def __init__(

self.debug_show_dim = debug_show_dim
self.use_skip = False
self.aggregation = "concatenation"

def forward(self, x: Tensor) -> Tensor:
return self._forward(x)[0]
Expand Down Expand Up @@ -489,6 +497,7 @@ def __init__(

self.debug_show_dim = debug_show_dim
self.use_skip = False
self.aggregation = "concatenation"

def forward(self, x: Tensor) -> Tensor:
return self._forward(x)[0]
6 changes: 6 additions & 0 deletions hrdae/models/networks/modules/conv_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ def __init__(
out_channels: int,
latent_dim: int,
conv_params: list[dict[str, list[int]]],
aggregation: str,
debug_show_dim: bool = False,
) -> None:
super().__init__()
Expand Down Expand Up @@ -58,6 +59,7 @@ def __init__(

self.debug_show_dim = debug_show_dim
self.use_skip = True
self.aggregation = aggregation

def forward(self, x: Tensor, hs: list[Tensor]) -> Tensor:
return self._forward(x, hs)[0]
Expand All @@ -70,6 +72,7 @@ def __init__(
out_channels: int,
latent_dim: int,
conv_params: list[dict[str, list[int]]],
aggregation: str,
debug_show_dim: bool = False,
) -> None:
super().__init__()
Expand Down Expand Up @@ -110,6 +113,7 @@ def __init__(

self.debug_show_dim = debug_show_dim
self.use_skip = True
self.aggregation = aggregation

def forward(self, x: Tensor, hs: list[Tensor]) -> Tensor:
return self._forward(x, hs)[0]
Expand All @@ -122,6 +126,7 @@ def __init__(
out_channels: int,
latent_dim: int,
conv_params: list[dict[str, list[int]]],
aggregation: str,
debug_show_dim: bool = False,
) -> None:
super().__init__()
Expand Down Expand Up @@ -162,6 +167,7 @@ def __init__(

self.debug_show_dim = debug_show_dim
self.use_skip = True
self.aggregation = aggregation

def forward(self, x: Tensor, hs: list[Tensor]) -> Tensor:
return self._forward(x, hs)[0]

0 comments on commit 4b1c795

Please sign in to comment.