diff --git a/hrdae/models/networks/hr_dae.py b/hrdae/models/networks/hr_dae.py index 101304c..fcd7a03 100644 --- a/hrdae/models/networks/hr_dae.py +++ b/hrdae/models/networks/hr_dae.py @@ -32,12 +32,12 @@ @dataclass class HRDAE2dOption(RDAE2dOption): - pass + connection_aggregation: str = "concatenation" @dataclass class HRDAE3dOption(RDAE3dOption): - pass + connection_aggregation: str = "concatenation" def create_hrdae2d(out_channels: int, opt: HRDAE2dOption) -> nn.Module: @@ -54,6 +54,7 @@ def create_hrdae2d(out_channels: int, opt: HRDAE2dOption) -> nn.Module: motion_encoder, opt.activation, opt.aggregator, + opt.connection_aggregation, opt.debug_show_dim, ) return HRDAE2d( @@ -65,6 +66,7 @@ def create_hrdae2d(out_channels: int, opt: HRDAE2dOption) -> nn.Module: motion_encoder, opt.activation, opt.aggregator, + opt.connection_aggregation, opt.debug_show_dim, ) @@ -167,6 +169,7 @@ def __init__( hidden_channels: int, latent_dim: int, conv_params: list[dict[str, list[int]]], + aggregation: str, debug_show_dim: bool = False, ) -> None: super().__init__() @@ -180,6 +183,7 @@ def __init__( out_channels, hidden_channels, conv_params, + aggregation, debug_show_dim, ) self.debug_show_dim = debug_show_dim @@ -203,6 +207,7 @@ def __init__( hidden_channels: int, latent_dim: int, conv_params: list[dict[str, list[int]]], + aggregation: str, debug_show_dim: bool = False, ) -> None: super().__init__() @@ -216,6 +221,7 @@ def __init__( out_channels, hidden_channels, conv_params, + aggregation, debug_show_dim, ) self.debug_show_dim = debug_show_dim @@ -243,6 +249,7 @@ def __init__( motion_encoder: MotionEncoder1d, activation: str, aggregator: str, + connection_aggregation: str, debug_show_dim: bool = False, ) -> None: super().__init__() @@ -263,6 +270,7 @@ def __init__( dec_hidden_channels, 2 * latent_dim if aggregator == "concatenation" else latent_dim, conv_params[::-1], + connection_aggregation, debug_show_dim, ) self.aggregator = create_aggregator2d(aggregator, latent_dim, latent_dim) @@ -342,6 +350,7 @@ def __init__( motion_encoder: MotionEncoder2d, activation: str, aggregator: str, + connection_aggregation: str, debug_show_dim: bool = False, ) -> None: super().__init__() @@ -362,6 +371,7 @@ def __init__( dec_hidden_channels, 2 * latent_dim if aggregator == "concatenation" else latent_dim, conv_params[::-1], + connection_aggregation, debug_show_dim, ) self.aggregator = create_aggregator2d(aggregator, latent_dim, latent_dim) diff --git a/hrdae/models/networks/modules/conv_block.py b/hrdae/models/networks/modules/conv_block.py index 7c38e0e..4d6e10f 100644 --- a/hrdae/models/networks/modules/conv_block.py +++ b/hrdae/models/networks/modules/conv_block.py @@ -309,6 +309,7 @@ class ConvModuleBase(nn.Module): layers: nn.ModuleList use_skip: bool debug_show_dim: bool + aggregation: str def _forward( self, x: Tensor, hs: list[Tensor] | None = None @@ -323,7 +324,12 @@ def _forward( for i, layer in enumerate(self.layers): if self.use_skip: assert hs is not None - x = cat([x, hs[i]], dim=1) + if self.aggregation == "concatenation": + x = cat([x, hs[i]], dim=1) + elif self.aggregation == "addition": + x = x + hs[i] + else: + raise ValueError(f"Invalid aggregation: {self.aggregation}") x = layer(x) if self.debug_show_dim: print(f"{self.__class__.__name__} Layer {i}", x.size()) @@ -381,6 +387,7 @@ def __init__( self.debug_show_dim = debug_show_dim self.use_skip = False + self.aggregation = "concatenation" def forward(self, x: Tensor) -> Tensor: return self._forward(x)[0] @@ -435,6 +442,7 @@ def __init__( self.debug_show_dim = debug_show_dim self.use_skip = False + self.aggregation = "concatenation" def forward(self, x: Tensor) -> Tensor: return self._forward(x)[0] @@ -489,6 +497,7 @@ def __init__( self.debug_show_dim = debug_show_dim self.use_skip = False + self.aggregation = "concatenation" def forward(self, x: Tensor) -> Tensor: return self._forward(x)[0] diff --git a/hrdae/models/networks/modules/conv_decoder.py b/hrdae/models/networks/modules/conv_decoder.py index 5747e46..b2dbb88 100644 --- a/hrdae/models/networks/modules/conv_decoder.py +++ b/hrdae/models/networks/modules/conv_decoder.py @@ -18,6 +18,7 @@ def __init__( out_channels: int, latent_dim: int, conv_params: list[dict[str, list[int]]], + aggregation: str, debug_show_dim: bool = False, ) -> None: super().__init__() @@ -58,6 +59,7 @@ def __init__( self.debug_show_dim = debug_show_dim self.use_skip = True + self.aggregation = aggregation def forward(self, x: Tensor, hs: list[Tensor]) -> Tensor: return self._forward(x, hs)[0] @@ -70,6 +72,7 @@ def __init__( out_channels: int, latent_dim: int, conv_params: list[dict[str, list[int]]], + aggregation: str, debug_show_dim: bool = False, ) -> None: super().__init__() @@ -110,6 +113,7 @@ def __init__( self.debug_show_dim = debug_show_dim self.use_skip = True + self.aggregation = aggregation def forward(self, x: Tensor, hs: list[Tensor]) -> Tensor: return self._forward(x, hs)[0] @@ -122,6 +126,7 @@ def __init__( out_channels: int, latent_dim: int, conv_params: list[dict[str, list[int]]], + aggregation: str, debug_show_dim: bool = False, ) -> None: super().__init__() @@ -162,6 +167,7 @@ def __init__( self.debug_show_dim = debug_show_dim self.use_skip = True + self.aggregation = aggregation def forward(self, x: Tensor, hs: list[Tensor]) -> Tensor: return self._forward(x, hs)[0]