diff --git a/patch_conv/__version__.py b/patch_conv/__version__.py
index 03d66bc..25d3d73 100644
--- a/patch_conv/__version__.py
+++ b/patch_conv/__version__.py
@@ -1 +1 @@
-__version__ = "0.0.0beta0"
+__version__ = "0.0.1beta0"
diff --git a/patch_conv/module.py b/patch_conv/module.py
index dc5b844..015b5b5 100644
--- a/patch_conv/module.py
+++ b/patch_conv/module.py
@@ -4,14 +4,14 @@
 
 
 class PatchConv2d(nn.Module):
-    def __init__(self, splits: int = 4, conv2d: nn.Conv2d = None, *args, **kwargs):
+    def __init__(self, splits: int = 4, sequential: bool = True, conv2d: nn.Conv2d = None, *args, **kwargs):
         super(PatchConv2d, self).__init__()
+        self.splits = splits
+        self.sequential = sequential
         if conv2d is not None:
             self.conv2d = conv2d
-            self.splits = splits
         else:
             self.conv2d = nn.Conv2d(*args, **kwargs)
-            self.splits = splits
 
     def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
         b, c, h, w = x.shape
@@ -29,7 +29,14 @@ def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
         x_padded = x_padded.view(b * self.splits, c, -1, w)
         padding_bak = self.conv2d.padding
         self.conv2d.padding = (0, self.conv2d.padding[1])
-        output = self.conv2d(x_padded, *args, **kwargs)
+        if self.sequential:
+            outputs = []
+            for i in range(x_padded.shape[0]):
+                output = self.conv2d(x_padded[i : i + 1], *args, **kwargs)
+                outputs.append(output)
+            output = torch.cat(outputs, dim=0)
+        else:
+            output = self.conv2d(x_padded, *args, **kwargs)
         self.conv2d.padding = padding_bak
         _, oc, oh, ow = output.shape
         output = output.view(b, self.splits, oc, oh, ow).permute(0, 2, 1, 3, 4).reshape(b, oc, -1, ow)
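
For reviewers, a minimal usage sketch of the new `sequential` flag follows. The input shape, channel counts, tolerance, and import path are illustrative assumptions; only the flag, its default of True, and the PatchConv2d signature come from this patch.

# Usage sketch (assumptions noted above; not part of the patch itself).
import torch
import torch.nn as nn

from patch_conv.module import PatchConv2d

conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)

# sequential=True (the new default) pushes the split patches through
# conv2d one at a time and concatenates the results, trading loop
# overhead for a lower peak activation footprint; sequential=False
# keeps the previous behavior of a single batched conv2d call.
patch_seq = PatchConv2d(splits=4, sequential=True, conv2d=conv)
patch_batch = PatchConv2d(splits=4, sequential=False, conv2d=conv)

x = torch.randn(1, 3, 1024, 1024)
with torch.no_grad():
    y_seq = patch_seq(x)
    y_batch = patch_batch(x)
    y_ref = conv(x)

# The two modes differ only in how the patch batch is fed to conv2d,
# so their outputs should agree to numerical tolerance and have the
# same shape as the plain convolution the wrapper stands in for.
assert y_seq.shape == y_ref.shape
assert torch.allclose(y_seq, y_batch, atol=1e-6)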