
A question about reinitialize #9

Open
Yanxingang opened this issue May 7, 2019 · 9 comments

Comments

@Yanxingang

Hi, I feel a little confused about reinitialization.
After reinitialization, the weights of the network change immediately, so the accuracy should drop at the reinitialization point. Could you tell me why the curve is continuous?

@siahuat0727
Owner

Hi,
I mask the weights that are connected to the pruned filters with zeros (code).
Note that this operation isn't mentioned in the paper.
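(A minimal sketch of that idea, with made-up layer sizes and indices rather than the repo's actual code. If the masking zeroes the weights in the following layer that read from the re-initialized filters, which is my reading of the comment above, then whatever values the new filters take, they cannot change the next layer's output, so the accuracy curve does not jump at the re-initialization point.)

import torch
import torch.nn as nn

# Illustrative only: filters 0 and 2 of conv1 were pruned and just re-initialized.
conv1 = nn.Conv2d(3, 4, kernel_size=3, padding=1)
conv2 = nn.Conv2d(4, 8, kernel_size=3, padding=1)
reinit = [0, 2]

with torch.no_grad():
    # Zero the conv2 weights connected to the re-initialized filters, so their
    # contribution to conv2's output is zero regardless of their new values.
    conv2.weight[:, reinit, :, :] = 0

x = torch.randn(1, 3, 32, 32)
y = conv2(conv1(x))   # unaffected by the re-initialized filters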

@machanic

machanic commented May 7, 2019

@siahuat0727 I may have found a bug: in the original paper, the author says there are two networks, the original network and the pruned one, but your code seems to have just one network. One consequence is that the pruning does not work properly: you prune the conv kernels by setting them to 0, but after that you train the network as usual, and the conv layers update all of their weights, including the pruned ones!
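(A tiny self-contained check of this concern, with illustrative layer sizes rather than the repo's code: a filter that has merely been zeroed still receives a gradient and is moved away from zero by the next optimizer step unless something masks it.)

import torch
import torch.nn as nn

conv = nn.Conv2d(1, 2, kernel_size=3)
opt = torch.optim.SGD(conv.parameters(), lr=0.1)

with torch.no_grad():
    conv.weight[0] = 0                    # "prune" filter 0 by zeroing it

loss = conv(torch.randn(1, 1, 8, 8)).sum()
loss.backward()
opt.step()                                # nothing masks the gradient here

print(conv.weight[0].abs().sum())         # generally non-zero again after the step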

@machanic

machanic commented May 7, 2019

@siahuat0727 I think that in order to solve this problem, you need to freeze the pruned channels inside the kernels after pruning, or create a separate pruned network, copy the original network's weights (except the pruned channels) into it, and train that small pruned network, as the author describes in the paper.
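(A rough sketch of the second option, with made-up layer sizes and indices: copy the surviving filters into a smaller layer. The following layer's input channels would have to be sliced in the same way.)

import torch
import torch.nn as nn

full = nn.Conv2d(3, 8, kernel_size=3, padding=1)
pruned = [1, 5]                                   # illustrative indices to drop
keep = [i for i in range(8) if i not in pruned]

small = nn.Conv2d(3, len(keep), kernel_size=3, padding=1)
with torch.no_grad():
    small.weight.copy_(full.weight[keep])         # copy only the un-pruned filters
    small.bias.copy_(full.bias[keep])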

@siahuat0727
Owner

siahuat0727 commented May 8, 2019

@sharpstill
Hi,
if I understand correctly, I already freeze it here.

RePr/main.py

Lines 68 to 73 in b0f46f4

if args.repr and any(s1 <= epoch < s1+S2 for s1 in range(S1, args.epochs, S1+S2)):
    if i == 0:
        print('freeze for this epoch')
    with torch.no_grad():
        for name, W in conv_weights:
            W.grad[mask[name]] = 0
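(For reference, a self-contained toy version of the same mechanism, with illustrative names: zeroing the gradients of the pruned filters between backward() and step() keeps plain SGD from moving them off zero.)

import torch
import torch.nn as nn

conv = nn.Conv2d(1, 4, kernel_size=3)
mask = torch.tensor([0, 2])                  # indices of pruned filters (illustrative)
opt = torch.optim.SGD(conv.parameters(), lr=0.1)

with torch.no_grad():
    conv.weight[mask] = 0                    # prune
loss = conv(torch.randn(1, 1, 8, 8)).pow(2).sum()
loss.backward()
with torch.no_grad():
    conv.weight.grad[mask] = 0               # freeze the pruned filters
opt.step()
print(conv.weight[mask].abs().sum())         # stays 0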

@machanic

machanic commented May 8, 2019

@siahuat0727 I still think this step has a problem, because you use

with torch.no_grad():
    for name, W in conv_weights:
        W.grad[mask[name]] = 0

but this freezes all the weights of the conv layers, because conv_weights contains the weights of all the channels, according to

RePr/main.py

Line 266 in b0f46f4

conv_weights.append((name, W))

I mean you need to freeze only the pruned channels of the conv kernels, not all of them.
I recommend you consider the suggestion from my last response (create a small sub-network, copy the parameters of the un-pruned weights into it, then train that sub-network).
Another issue is that some conv layers have a bias; your W.grad[mask[name]] = 0 does not cover the bias case.
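(One possible way to extend the same gradient masking to the bias term; the names below are illustrative, not the repo's code.)

import torch
import torch.nn as nn

conv = nn.Conv2d(1, 4, kernel_size=3)
pruned = torch.tensor([0, 2])
opt = torch.optim.SGD(conv.parameters(), lr=0.1)

loss = conv(torch.randn(1, 1, 8, 8)).pow(2).sum()
loss.backward()
with torch.no_grad():
    conv.weight.grad[pruned] = 0   # freeze the pruned filter weights
    conv.bias.grad[pruned] = 0     # ...and the corresponding biases
opt.step()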

@machanic

machanic commented May 8, 2019

Furthermore, the BN layers are also affected by the number of channels; since you train with all of the channels, the BN layers won't work properly either.
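(A hedged illustration of this point, not taken from the repo or the paper: the BatchNorm that follows a pruned conv still carries per-channel affine parameters and running statistics for the pruned channels, and one would have to handle them explicitly, e.g.)

import torch
import torch.nn as nn

bn = nn.BatchNorm2d(4)
pruned = torch.tensor([0, 2])
with torch.no_grad():
    bn.weight[pruned] = 0          # gamma: silence the pruned channels after BN
    bn.bias[pruned] = 0            # beta
    bn.running_mean[pruned] = 0
    bn.running_var[pruned] = 1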

@siahuat0727
Owner

Hi,
I didn't freeze all of them; mask[name] contains only the indices of the pruned filters.

RePr/main.py

Lines 161 to 166 in b0f46f4

mask = {}
drop_filters = {}
for name, W in conv_weights:
    prune[name] = inter_filter_ortho[name] > threshold  # e.g. [True, False, True, True, False]
    # get indice of bad filters
    mask[name] = np.where(prune[name])[0]  # e.g. [0, 2, 3]

I think the only remaining difference is the bias term, but I don't think that's the point. You can try it and see how much difference it makes.
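(A quick check of this point with made-up values: indexing the gradient with mask[name] only touches the pruned filters, not the whole layer.)

import numpy as np
import torch

prune = np.array([True, False, True, True, False])
mask = np.where(prune)[0]                  # -> array([0, 2, 3])

grad = torch.ones(5, 3, 3, 3)              # fake gradient of a 5-filter conv layer
grad[mask] = 0                             # only filters 0, 2 and 3 are zeroed
print(grad.sum(dim=(1, 2, 3)))             # tensor([ 0., 27.,  0.,  0., 27.])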

@machanic

machanic commented May 8, 2019

Thank you for your reply. I will try my solution and report the result so we can compare it with yours.

@siahuat0727
Owner

Thanks. Looking forward to your experiment.
