Skip to content

Commit

Permalink
Fix: consider dropped filters when re-init #2
Browse files Browse the repository at this point in the history
  • Loading branch information
siahuat0727 committed Mar 25, 2019
1 parent 412309b commit 4fa1b8d
Showing 1 changed file with 12 additions and 8 deletions.
20 changes: 12 additions & 8 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,15 +153,18 @@ def pruning(conv_weights, prune_ratio):
threshold = np.percentile(ranks, 100*(1-prune_ratio))
# get indice of bad filters
mask = {}
drop_filters = {}
for name, W in conv_weights:
mask[name] = np.where(inter_filter_ortho[name] > threshold)[0]
with torch.no_grad():
W.data[mask[name]] = 0
if mask[name].size > 0:
with torch.no_grad():
drop_filters[name] = W.data[mask[name]].view(mask[name].size, -1).cpu().numpy()
W.data[mask[name]] = 0

test_filter_sparsity(conv_weights)
return mask
return mask, drop_filters

def reinitialize(mask, conv_weights, fc_weights):
def reinitialize(mask, drop_filters, conv_weights, fc_weights):
print('Reinitializing...')
with torch.no_grad():
prev_layer_name = None
Expand All @@ -170,8 +173,8 @@ def reinitialize(mask, conv_weights, fc_weights):
if W.dim() == 4: # conv weights
# find null space
size = W.size()
W2d = W.view(size[0], -1)
null_space = qr_null(W2d.cpu().detach().numpy())
W2d = W.view(size[0], -1).cpu().numpy()
null_space = qr_null(np.vstack((drop_filters[name], W2d)))
null_space = torch.from_numpy(null_space).cuda()
null_space = null_space.transpose(0, 1).view(-1, size[1], size[2], size[3])

Expand Down Expand Up @@ -255,15 +258,16 @@ def main():
"repr" if args.repr else "norepr", args.epochs, args.comment))

mask = None
drop_filters = None
best_acc = 0 # best test accuracy
for epoch in range(args.epochs):
if args.repr:
# check if the end of S1 stage
if any(epoch == s for s in range(args.S1, args.epochs, args.S1+args.S2)):
mask = pruning(conv_weights, args.prune_ratio)
mask, drop_filters = pruning(conv_weights, args.prune_ratio)
# check if the end of S2 stage
if any(epoch == s for s in range(args.S1+args.S2, args.epochs, args.S1+args.S2)):
reinitialize(mask, conv_weights, fc_weights)
reinitialize(mask, drop_filters, conv_weights, fc_weights)
train(trainloader, criterion, optimizer, epoch, model, writer, mask, args, conv_weights)
acc = validate(testloader, criterion, model, writer, args, epoch, best_acc)
best_acc = max(best_acc, acc)
Expand Down

0 comments on commit 4fa1b8d

Please sign in to comment.