Skip to content

Commit

Permalink
add sequential mode
Browse files Browse the repository at this point in the history
  • Loading branch information
lmxyy committed May 26, 2024
1 parent c192b69 commit 0b52e9d
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 5 deletions.
2 changes: 1 addition & 1 deletion patch_conv/__version__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.0.0beta0"
__version__ = "0.0.1beta0"
15 changes: 11 additions & 4 deletions patch_conv/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,14 @@


class PatchConv2d(nn.Module):
def __init__(self, splits: int = 4, conv2d: nn.Conv2d = None, *args, **kwargs):
def __init__(self, splits: int = 4, sequential: bool = True, conv2d: nn.Conv2d = None, *args, **kwargs):
super(PatchConv2d, self).__init__()
self.splits = splits
self.sequential = sequential
if conv2d is not None:
self.conv2d = conv2d
self.splits = splits
else:
self.conv2d = nn.Conv2d(*args, **kwargs)
self.splits = splits

def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
b, c, h, w = x.shape
Expand All @@ -29,7 +29,14 @@ def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
x_padded = x_padded.view(b * self.splits, c, -1, w)
padding_bak = self.conv2d.padding
self.conv2d.padding = (0, self.conv2d.padding[1])
output = self.conv2d(x_padded, *args, **kwargs)
if self.sequential:
outputs = []
for i in range(x_padded.shape[0]):
output = self.conv2d(x_padded[i : i + 1], *args, **kwargs)
outputs.append(output)
output = torch.cat(outputs, dim=0)
else:
output = self.conv2d(x_padded, *args, **kwargs)
self.conv2d.padding = padding_bak
_, oc, oh, ow = output.shape
output = output.view(b, self.splits, oc, oh, ow).permute(0, 2, 1, 3, 4).reshape(b, oc, -1, ow)
Expand Down

0 comments on commit 0b52e9d

Please sign in to comment.