Remove dim=1 parameters from LARS updates and weight decay #301

Closed
turian opened this issue Sep 11, 2022 · 9 comments

@turian
Contributor

turian commented Sep 11, 2022

Describe the bug

In addition to solo.utils.misc.remove_bias_and_norm_from_weight_decay, we might also consider a solo.utils.misc.ignore_dim1_parameters helper that excludes these parameters from both weight decay and the LARS adaptation filter (lars_adaptation_filter).

Versions
solo-learn main

Additional comments
I see that solo.utils.misc.remove_bias_and_norm_from_weight_decay was added in #289. This is quite nice. (I can't tell from your patch, but bias and norm should also be excluded from the LARS adaptation. Is that the case? Your line 371 appears to do this, but it's hard to grok from the comment and the small diff. See the Facebook code here.)

The key thing is that one might wish to exclude ALL ndim==1 parameters from LARS updates and weight decay. (FB vicreg code) There might be other norms, etc., for which this is a good idea. In the FB code they only had batchnorm, so it didn't matter there, but there may be other settings where this would help.
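For concreteness, here is a minimal sketch of that ndim-based split, loosely in the spirit of the FB VICReg code; split_params_by_ndim and the optimizer setup below are illustrative names, not solo-learn's actual helper:

```python
import torch
from torch import nn

def split_params_by_ndim(model: nn.Module):
    """Split parameters into >1-D weights and ndim==1 parameters
    (biases, norm/scale parameters), in the spirit of the FB VICReg code."""
    decay, no_decay = [], []
    for p in model.parameters():
        if not p.requires_grad:
            continue
        (no_decay if p.ndim == 1 else decay).append(p)
    return decay, no_decay

# Hypothetical usage: two parameter groups, with weight decay disabled for the
# ndim==1 group (a LARS wrapper could likewise skip its adaptation for them).
model = nn.Sequential(nn.Linear(8, 8), nn.BatchNorm1d(8))
decay, no_decay = split_params_by_ndim(model)
optimizer = torch.optim.SGD(
    [
        {"params": decay, "weight_decay": 1e-4},
        {"params": no_decay, "weight_decay": 0.0},
    ],
    lr=0.1,
)
```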

If this works, I might pester/help pytorch (pytorch/pytorch#1402) and lightning (with whom I've been discussing this issue) to upstream this change.

@vturrisi
Owner

The issue is that, for some reason, I noticed slight performance differences with and without remove_bias_and_norm_from_weight_decay on a couple of methods that I tried. I'm not sure if the cause is just the interaction with the other hyperparameters or something else entirely. I think having them decoupled is beneficial.

Line 371 just excludes the parameters from the scheduler (only SimSiam does this).

@turian
Contributor Author

turian commented Sep 12, 2022

@vturrisi perhaps this could be a configurable option? (Or two separate options: one for disabling weight decay on ndim==1 parameters and one for disabling LARS updates on them.)

Here's more evidence:

  • FB vicreg code excludes ALL ndim==1 parameters from LARS updates and weight decay
  • fast.ai: "Recently, Tencent published a very nice paper showing <7 minute training of Imagenet on 2,048 GPUs. They mentioned a trick we hadn’t tried before, but makes perfect sense: removing weight decay from batchnorm layers. That allowed us to trim another couple of epochs from our training time. (The Tencent paper also used a dynamic learning rate approach developed by NVIDIA research, called LARS, which we’ve also been developing for fastai, but is not included yet in these results.)"
  • Weight decay in the optimizers is a bad idea (especially with BatchNorm) "Correct me if I’m wrong, but there is no reason the beta and gamma parameters in BatchNorm should ever be subject to weight decay, ie L2 regularization, that pulls them toward 0. In fact it seems like a very bad idea to pull them toward 0."
  • Lightning-Flash is considering adding this too, since they have a LARS optimizer that is not in core Lightning.

@vturrisi
Owner

@turian I see. As it is now, we should already have these two options, no? Maybe it's just the naming that's confusing.
Here, the LARS adaptation is only applied to parameters with ndim != 1 when exclude_bias_n_norm is True. As for excluding from weight decay, I believe I copied some old code from timm, but I think this makes more sense. Either way, exclude_bias_n_norm_wd should trigger that function.
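To make the LARS side concrete, here is a rough sketch of what such a per-parameter exclusion amounts to; this is illustrative code in the spirit of that behaviour, not solo-learn's actual LARS implementation:

```python
import torch

def lars_step(p: torch.Tensor, lr: float, weight_decay: float,
              eta: float = 0.02, exclude_bias_n_norm: bool = True):
    """Rough sketch of a single LARS parameter update (illustrative only).

    With exclude_bias_n_norm=True, ndim==1 parameters (biases, norm/scale
    parameters) skip both weight decay and the trust-ratio adaptation and
    just take a plain SGD step.
    """
    g = p.grad
    if exclude_bias_n_norm and p.ndim == 1:
        p.data.add_(g, alpha=-lr)               # plain step, no decay/adaptation
        return
    g = g.add(p, alpha=weight_decay)            # apply weight decay
    w_norm, g_norm = p.norm(), g.norm()
    trust = torch.where((w_norm > 0) & (g_norm > 0),
                        eta * w_norm / g_norm,
                        torch.ones_like(w_norm))
    p.data.add_(g, alpha=-(lr * trust.item()))  # per-layer scaled step
```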

Maybe the solution is just:

@turian
Contributor Author

turian commented Sep 12, 2022

Ah I see! Yes, I guess I found the names confusing. Maybe a docstring could help too, but the renames could be great.

@vturrisi
Owner

I re-checked my function and found an issue where we were overwriting the weight decay. Gonna push a fix for that in the upcoming release. About renaming, I thought a bit more about that and I think it's a fair enough name, since the only 1d parameters that we have are normalization layers and biases.

@turian
Contributor Author

turian commented Sep 14, 2022

About renaming, I thought a bit more about that and I think it's a fair enough name, since the only 1d parameters that we have are normalization layers and biases.

That doesn't seem very forward-thinking. For example, what if some SSL algorithm uses PReLU activations? If I'm not mistaken, those slope parameters are also ndim==1, and you probably don't want to apply weight decay or the LARS adaptation to them.
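A quick check with standard PyTorch shows the PReLU slope is indeed a 1-D parameter:

```python
import torch.nn as nn

# The PReLU slope is a learnable 1-D parameter, so an ndim==1 filter would
# catch it alongside biases and normalization parameters.
prelu = nn.PReLU(num_parameters=64)
print(prelu.weight.shape, prelu.weight.ndim)  # torch.Size([64]) 1
```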

@turian
Contributor Author

turian commented Sep 14, 2022

I re-checked my function and found an issue that we were overwriting the weight decay. Gonna push a fix for that in the upcoming release.

I'm happy to do code review of the PR. I'm excited to work on PTL SSL code, given what I'm currently building.

@vturrisi
Owner

Indeed, the name wouldn't be completely clear for that method, even though it would still cover those ndim=1 learnable parameters. Still, a name like exclude_ndim1 might not be clear to some people either.

@vturrisi
Owner

We now have exclude_bias_n_norm_wd as an entry to optimizer and exclude_bias_n_norm as an entry of optimizer/kwargs for when lars is enabled. This addresses both cases and I think it's clear enough.
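For reference, the layout described here would look roughly like this (an illustrative Python dict, not a verbatim solo-learn config file):

```python
# Illustrative sketch of the layout described above (not a verbatim
# solo-learn config): one flag controls the weight-decay exclusion, the
# other the LARS adaptation exclusion.
optimizer_cfg = {
    "name": "lars",
    "lr": 0.3,
    "weight_decay": 1e-6,
    "exclude_bias_n_norm_wd": True,   # drop ndim==1 params from weight decay
    "kwargs": {
        "exclude_bias_n_norm": True,  # drop ndim==1 params from LARS adaptation
    },
}
```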
