Skip to content

Commit e49f403

Browse files
author
zhangchen76
committed
Fix bugs.
1 parent b5d5b03 commit e49f403

File tree

2 files changed

+2
-2
lines changed

2 files changed

+2
-2
lines changed

run_pruning.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -173,12 +173,12 @@ def main():
173173
neuron_score /= norm_per_layer.unsqueeze(-1) + 1e-7
174174

175175
# Reorder for efficient indexing with module-wise sparsity.
176-
# reorder heads for each layer in model, reorder neuron in model
177176
base_model = getattr(model, model.base_model_prefix, model)
178177
head_score, head_indices = torch.sort(head_score, dim=1, descending=True)
179178
neuron_score, neuron_indices = torch.sort(neuron_score, dim=1, descending=True)
180179
head_indices = {layer_idx: indices for layer_idx, indices in enumerate(head_indices)}
181180
neuron_indices = {layer_idx: indices for layer_idx, indices in enumerate(neuron_indices)}
181+
base_model.reorder(head_indices, neuron_indices)
182182

183183
# Compute module-wise sparsity from overall sparsity.
184184
head_sort = [

run_sparsification.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -226,12 +226,12 @@ def main():
226226
neuron_score = args.lam * neuron_expressive_score + (1 - args.lam) * neuron_friendly_score
227227

228228
# Reorder for efficient indexing with module-wise sparsity.
229-
# reorder heads for each layer in model, reorder neuron in model
230229
base_model = getattr(t_model, t_model.base_model_prefix, t_model)
231230
head_score, head_indices = torch.sort(head_score, dim=1, descending=True)
232231
neuron_score, neuron_indices = torch.sort(neuron_score, dim=1, descending=True)
233232
head_indices = {layer_idx: indices for layer_idx, indices in enumerate(head_indices)}
234233
neuron_indices = {layer_idx: indices for layer_idx, indices in enumerate(neuron_indices)}
234+
base_model.reorder(head_indices, neuron_indices)
235235

236236
# Compute module-wise sparsity from overall sparsity.
237237
head_sort = [

0 commit comments

Comments
 (0)