diff --git a/convlstm.py b/convlstm.py index e54a085..aec831f 100644 --- a/convlstm.py +++ b/convlstm.py @@ -138,12 +138,11 @@ def forward(self, input_tensor, hidden_state=None): b, _, _, h, w = input_tensor.size() # Implement stateful ConvLSTM - if hidden_state is not None: - raise NotImplementedError() - else: + if hidden_state is None: # Since the init is done in forward. Can send image size here - hidden_state = self._init_hidden(batch_size=b, - image_size=(h, w)) + hidden_state = self._init_hidden(batch_size=b, image_size=(h, w)) + elif len(hidden_state) != self.num_layers: + raise NotImplementedError() layer_output_list = [] last_state_list = [] @@ -168,7 +167,7 @@ def forward(self, input_tensor, hidden_state=None): if not self.return_all_layers: layer_output_list = layer_output_list[-1:] - last_state_list = last_state_list[-1:] + # last_state_list = last_state_list[-1:] return layer_output_list, last_state_list