From 7a03fc7ce8b40d2929eb6ab086d3aa5be4b47a24 Mon Sep 17 00:00:00 2001 From: Sjoerd Groot Date: Mon, 11 Mar 2024 18:06:49 +0100 Subject: [PATCH] Fix split pruning for multiple iterative_steps --- tests/test_concat_split.py | 2 +- tests/test_non_feature_dim_cat.py | 2 +- tests/test_split.py | 2 +- torch_pruning/ops.py | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/test_concat_split.py b/tests/test_concat_split.py index 16f36bdd..1e4e8602 100644 --- a/tests/test_concat_split.py +++ b/tests/test_concat_split.py @@ -51,7 +51,7 @@ def test_pruner(): if isinstance(m, torch.nn.Linear) and m.out_features == 1000: ignored_layers.append(m) - iterative_steps = 1 + iterative_steps = 2 pruner = tp.pruner.MagnitudePruner( model, example_inputs, diff --git a/tests/test_non_feature_dim_cat.py b/tests/test_non_feature_dim_cat.py index 297cf54c..fde8e0d8 100644 --- a/tests/test_non_feature_dim_cat.py +++ b/tests/test_non_feature_dim_cat.py @@ -49,7 +49,7 @@ def test_pruner(): if isinstance(m, torch.nn.Linear) and m.out_features == 1000: ignored_layers.append(m) - iterative_steps = 1 + iterative_steps = 2 pruner = tp.pruner.MagnitudePruner( model, example_inputs, diff --git a/tests/test_split.py b/tests/test_split.py index 003ed192..eeaea3c8 100644 --- a/tests/test_split.py +++ b/tests/test_split.py @@ -53,7 +53,7 @@ def test_pruner(): if isinstance(m, torch.nn.Linear) and m.out_features == 1000: ignored_layers.append(m) - iterative_steps = 1 + iterative_steps = 2 pruner = tp.pruner.MagnitudePruner( model, example_inputs, diff --git a/torch_pruning/ops.py b/torch_pruning/ops.py index 3befdcaa..861d1cb0 100644 --- a/torch_pruning/ops.py +++ b/torch_pruning/ops.py @@ -135,7 +135,7 @@ def prune_out_channels(self, layer, idxs): offsets = [0] for i in range(len(new_split_sizes)): offsets.append(offsets[i] + new_split_sizes[i]) - self.offsets = offsets + layer.offsets = offsets prune_in_channels = prune_out_channels