
Commit

fixed convexpand?
AndreaBrg committed Jun 5, 2023
1 parent 564235c commit cb3723d
Showing 6 changed files with 54 additions and 58 deletions.
1 change: 1 addition & 0 deletions .idea/simplify.iml

Some generated files are not rendered by default.

36 changes: 19 additions & 17 deletions simplify/layers.py
@@ -40,19 +40,31 @@ def forward(self, x):
return x + self.bf


class BatchNormB(nn.BatchNorm2d):
@staticmethod
def from_bn(module: nn.BatchNorm2d, bias):
module.__class__ = BatchNormB
module.register_parameter('bf', torch.nn.Parameter(bias))
return module

def forward(self, x):
x = super().forward(x)
return x + self.bf[:, None, None].expand_as(x[0])


class ConvExpand(nn.Conv2d):
@staticmethod
def from_conv(module: nn.Conv2d, idxs: torch.Tensor, bias):
def from_conv(module: nn.Conv2d, idxs: torch.Tensor, bias, use_bf, shape):
module.__class__ = ConvExpand

module.register_parameter('bf', torch.nn.Parameter(bias))
setattr(module, "use_bf", bias.abs().sum() != 0)
setattr(module, "use_bf", use_bf)

module.register_buffer('idxs', idxs.to(module.weight.device))
module.register_buffer('zeros', torch.zeros(1, *bias.shape, dtype=bias.dtype, device=module.weight.device))
module.register_buffer('zeros', torch.zeros(1, *shape[1:], dtype=bias.dtype, device=module.weight.device))

setattr(module, 'idxs_cache', module.idxs)
setattr(module, 'zero_cache', module.zeros)
setattr(module, 'idxs_cache', module.idxs)

return module

@@ -61,6 +73,7 @@ def forward(self, x):

zeros = self.zero_cache
index = self.idxs_cache

if zeros.shape[0] != x.shape[0]:
zeros = self.zeros.expand(x.shape[0], *self.zeros.shape[1:])
self.zero_cache = zeros
@@ -70,24 +83,13 @@ def forward(self, x):
self.idxs_cache = index

expanded = torch.scatter(zeros, 1, index, x)
return expanded + self.bf if self.use_bf else expanded
bf = self.bf if self.use_bf else self.bf[:, None, None].expand_as(expanded)
return expanded + bf

def __repr__(self):
return f'ConvExpand({self.in_channels}, {self.out_channels}, exp={len(self.idxs)})'


class BatchNormB(nn.BatchNorm2d):
@staticmethod
def from_bn(module: nn.BatchNorm2d, bias):
module.__class__ = BatchNormB
module.register_parameter('bf', torch.nn.Parameter(bias))
return module

def forward(self, x):
x = super().forward(x)
return x + self.bf[:, None, None].expand_as(x[0])


class BatchNormExpand(nn.BatchNorm2d):
@staticmethod
def from_bn(module: nn.BatchNorm2d, idxs: torch.Tensor, bias, shape):
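For context, a minimal standalone sketch of the expansion step that the reworked ConvExpand performs after the pruned convolution runs: the reduced output is scattered back into the full channel dimension and a bias is added, either a full-size bf (the new use_bf path) or a per-channel bias broadcast over the spatial dimensions. All sizes and index values below are toy assumptions, not the library's exact code path.

import torch

# Toy sizes (assumed): 3 of 5 output channels survived pruning.
batch, kept, full, h, w = 4, 3, 5, 8, 8
x = torch.randn(batch, kept, h, w)            # output of the pruned conv
idxs = torch.tensor([0, 2, 4])                # original positions of the kept channels

# Scatter the kept channels back to their original positions along dim 1.
zeros = torch.zeros(batch, full, h, w)
index = idxs.view(1, -1, 1, 1).expand_as(x)   # scatter expects an index shaped like x
expanded = torch.scatter(zeros, 1, index, x)  # shape [batch, full, h, w]

# Bias handling mirrors the new use_bf flag: a ConvB-style bf already has the
# full spatial shape, while a plain per-channel bias is broadcast first.
bias = torch.randn(full)
out = expanded + bias[:, None, None].expand_as(expanded)
print(out.shape)                              # torch.Size([4, 5, 8, 8])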
8 changes: 4 additions & 4 deletions simplify/propagate.py
@@ -119,14 +119,14 @@ def __propagate_biases_hook(module, input, output, name=None):
else:
error('Unsupported module type:', module)

#############################################
## STEP 2. Propagate output to next module ##
#############################################
####################################################
## STEP 2. Propagate output (bias) to next module ##
####################################################

shape = module.weight.shape # Compute mask of zeroed (pruned) channels
pruned_channels = torch.abs(module.weight.view(shape[0], -1)).sum(dim=1) == 0

if isinstance(module, nn.Conv2d) and module.groups > 1:
if name in pinned_out or (isinstance(module, nn.Conv2d) and module.groups > 1):
            # No bias is propagated for pinned layers or grouped convolutions
return output * float('nan')

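To make STEP 2 concrete, here is a small sketch of the zeroed-channel mask on a toy convolution; the layer sizes and the channels zeroed below are assumptions for illustration only.

import torch
import torch.nn as nn

conv = nn.Conv2d(3, 4, kernel_size=3)
with torch.no_grad():
    conv.weight[1].zero_()  # pretend channels 1 and 3 were pruned
    conv.weight[3].zero_()

shape = conv.weight.shape   # same mask computation as in STEP 2 above
pruned_channels = torch.abs(conv.weight.view(shape[0], -1)).sum(dim=1) == 0
print(pruned_channels)      # tensor([False,  True, False,  True])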
13 changes: 5 additions & 8 deletions simplify/remove.py
@@ -6,7 +6,7 @@
import torch
import torch.nn as nn

from .layers import BatchNormB, ConvExpand, BatchNormExpand, LinearExpand
from .layers import BatchNormB, ConvExpand, BatchNormExpand, LinearExpand, ConvB


@torch.no_grad()
@@ -113,18 +113,15 @@ def __remove_zeroed_channels_hook(module, input, output, name):

# Keep bias (bf) full size
elif isinstance(module, nn.Conv2d):
module_bf = getattr(module, 'bf', None)
if module_bf is None:
module_bf = torch.zeros_like(output[0])

module = ConvExpand.from_conv(module, idxs, module_bf)
bias = module.bf if isinstance(module, ConvB) else module.bias
module = ConvExpand.from_conv(module, idxs, bias, isinstance(module, ConvB), output.shape)

elif isinstance(module, nn.BatchNorm2d):
bias = module.bf if isinstance(module, BatchNormB) else module.bias
module = BatchNormExpand.from_bn(module, idxs, bias, output.shape)

if not isinstance(module, BatchNormB):
module.register_parameter("bias", None)
if not isinstance(module, (ConvB, BatchNormB)):
module.register_parameter("bias", None)
else:
if getattr(module, 'bf', None) is not None:
module.bf = nn.Parameter(module.bf[nonzero_idx])
2 changes: 1 addition & 1 deletion test/modules/test_remove.py
@@ -47,7 +47,7 @@ def test_arch(arch, x, fuse_bn):
for architecture in models:
print(f"Testing with {architecture.__name__}")

for i in range(1):
for i in range(10):
with self.subTest(arch=architecture, fuse_bn=True):
test_arch(architecture, x, fuse_bn=True)

52 changes: 24 additions & 28 deletions test/utils.py
@@ -6,15 +6,14 @@
from torch._C._onnx import TrainingMode
from torch.nn.utils import prune
from torchvision.models import SqueezeNet
from torchvision.models import alexnet, resnet18, resnet50, vgg11, vgg11_bn, densenet121, inception_v3, googlenet, \
shufflenet_v2_x0_5, mobilenet_v2, mobilenet_v3_small, resnext50_32x4d, wide_resnet50_2, mnasnet0_5, mnasnet1_0
from torchvision.models import alexnet, resnet18
from torchvision.models.squeezenet import squeezenet1_0


class ResidualNet(nn.Module):
def __init__(self, *args, **kwargs):
super().__init__()
self.conv1 = nn.Conv2d(3, 2, kernel_size=7, stride=2, padding=3, bias=False)
self.conv1 = nn.Conv2d(3, 2, kernel_size=7, stride=2, padding=0, bias=False)
self.bn1 = nn.BatchNorm2d(2)
self.relu = nn.ReLU(inplace=True)
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
@@ -51,34 +50,31 @@ def forward(self, x):
return x4


# models = [
# ResidualNet
# ]
models = [
alexnet,
ResidualNet,
resnet18,
squeezenet1_0
]


# models = [
# alexnet,
# resnet18,
# squeezenet1_0
# vgg11, vgg11_bn,
# resnet18, resnet50,
# squeezenet1_0,
# densenet121,
# inception_v3,
# googlenet,
# shufflenet_v2_x0_5,
# mobilenet_v2, mobilenet_v3_small,
# resnext50_32x4d,
# wide_resnet50_2,
# mnasnet0_5, mnasnet1_0,
# densenet121
# ]


models = [
alexnet,
vgg11, vgg11_bn,
resnet18, resnet50,
squeezenet1_0,
densenet121,
inception_v3,
googlenet,
shufflenet_v2_x0_5,
mobilenet_v2, mobilenet_v3_small,
resnext50_32x4d,
wide_resnet50_2,
mnasnet0_5, mnasnet1_0,
densenet121
]


def get_model(architecture, arch):
# random.seed(0)
# os.environ["PYTHONHASHSEED"] = str(0)
@@ -95,15 +91,15 @@ def get_model(architecture, arch):
pretrained = True

model = arch(pretrained, progress=False)
model(torch.randn(64, 3, 224, 224))
model(torch.randn(16, 3, 224, 224))
model.eval()

for name, module in model.named_modules():
if isinstance(model, SqueezeNet) and 'classifier.1' in name:
continue

if isinstance(module, nn.Conv2d):
prune.random_structured(module, 'weight', amount=0.5, dim=0)
prune.random_structured(module, 'weight', amount=0.63, dim=0)
prune.remove(module, 'weight')

# if isinstance(module, nn.BatchNorm2d):
@@ -115,4 +111,4 @@

if __name__ == '__main__':
model = ResidualNet()
torch.onnx.export(model, torch.randn(1, 3, 224, 224), "resnet18.onnx", training=TrainingMode.TRAINING)
torch.onnx.export(model, torch.randn(1, 3, 224, 224), "model.onnx", training=TrainingMode.TRAINING)
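As a sanity check on the stronger pruning now used in get_model, a toy run of prune.random_structured with the same arguments (the layer sizes here are assumed): amount=0.63 on dim=0 zeroes roughly 63% of the conv's output channels.

import torch.nn as nn
from torch.nn.utils import prune

conv = nn.Conv2d(3, 8, kernel_size=3)
prune.random_structured(conv, 'weight', amount=0.63, dim=0)
prune.remove(conv, 'weight')

zeroed = conv.weight.view(conv.weight.shape[0], -1).abs().sum(dim=1) == 0
print(int(zeroed.sum()), 'of', conv.out_channels, 'channels zeroed')  # expect 5 of 8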
