
updates returns zeros #1101

Open
GeophyAI opened this issue Oct 10, 2024 · 3 comments

@GeophyAI

I'm using optax for training. I call updates, opt_state = opt.update(gradient, opt_state), but when I check updates it contains only zeros, even though the gradients do have non-zero values. In which situations can this happen?
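
For reference, a small helper for confirming that every leaf of the updates pytree is exactly zero (a sketch; the all_zero name is illustrative and not part of optax):

import jax
import jax.numpy as jnp

def all_zero(tree):
    # True if every array leaf in the pytree is exactly zero.
    return jax.tree_util.tree_all(
        jax.tree_util.tree_map(lambda x: bool(jnp.all(x == 0)), tree))

# Usage, after: updates, opt_state = opt.update(gradient, opt_state)
# print(all_zero(updates), all_zero(gradient))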

@vroulet
Collaborator

vroulet commented Oct 10, 2024

Hello @GeophyAI,

It depends on the optimizer. We cannot help without a concrete example of what you are trying to do.

@GeophyAI
Author

I'm using optax.masked and optax.chain to assign different learning rates to different parameter groups, something like the following implementation:

paras_counts = len(pars_need_by_eq)

def create_mask_fn(index, num_params):
    return tuple(i == index for i in range(num_params))

optimizers = []
for i, para in enumerate(pars_need_by_eq):
    # Set the learning rate for each parameter
    _lr = 0. if para not in pars_need_invert else lr[para]  # * scale_decay**idx_freq
    lr_schedule = optax.exponential_decay(_lr * scale_decay**idx_freq, 1, epoch_decay)
    # opt = optax.adam(lr_schedule, eps=1e-22)
    opt = optax.inject_hyperparams(optax.adam)(
        learning_rate=lambda count: lr_schedule(count), eps=1e-22)
    self.logger.print(f"Learning rate for {para}: {lr_schedule(0)}")
    mask = create_mask_fn(i, paras_counts)
    optimizers.append(optax.masked(opt, mask))

return optax.chain(*optimizers)

When my model has only one parameter group, it works fine; when there is more than one group, the updates are always zeros.
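
For context, here is a small self-contained sketch (names and values are illustrative, not the original code) of how optax.masked and optax.chain are meant to compose per-group transformations: each masked optimizer only transforms its own group and passes the other group's updates through unchanged.

import jax.numpy as jnp
import optax

# Two dummy parameter groups; the boolean mask tuples match this structure.
params = (jnp.ones(3), jnp.ones(3))
grads = (jnp.ones(3), jnp.ones(3))

def create_mask_fn(index, num_params):
    return tuple(i == index for i in range(num_params))

# One optax.sgd per group, each with its own learning rate, each masked so it
# only sees its own group.
tx = optax.chain(
    optax.masked(optax.sgd(1e-1), create_mask_fn(0, 2)),
    optax.masked(optax.sgd(1e-3), create_mask_fn(1, 2)),
)

opt_state = tx.init(params)
updates, opt_state = tx.update(grads, opt_state)
print(updates)  # (-0.1 * grads[0], -0.001 * grads[1])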

@GeophyAI
Author

I found that when I replace the line opt = optax.inject_hyperparams(optax.adam)(learning_rate=lambda count: lr_schedule(count), eps=1e-22) with opt = optax.adam(lr_schedule, eps=1e-22), it works for both single and multiple parameter groups.
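
For reference, a minimal self-contained sketch of the working variant (names and values are illustrative, not taken from the original code). One possible, unconfirmed explanation for the difference is that the lambda captures the name lr_schedule by reference, so after the loop every group's lambda evaluates the schedule built in the last iteration, whereas optax.adam(lr_schedule, eps=1e-22) binds the schedule object created in that same iteration.

import jax.numpy as jnp
import optax

params = (jnp.ones(3), jnp.ones(3))       # two dummy parameter groups
grads = (jnp.ones(3), jnp.ones(3) * 2.0)  # dummy gradients

def create_mask_fn(index, num_params):
    return tuple(i == index for i in range(num_params))

optimizers = []
for i, lr in enumerate((1e-2, 1e-3)):  # illustrative per-group learning rates
    lr_schedule = optax.exponential_decay(lr, transition_steps=1, decay_rate=0.99)
    opt = optax.adam(lr_schedule, eps=1e-22)  # schedule passed to adam directly
    optimizers.append(optax.masked(opt, create_mask_fn(i, len(params))))

tx = optax.chain(*optimizers)
opt_state = tx.init(params)
updates, opt_state = tx.update(grads, opt_state)
# Both parameter groups receive non-zero updates here.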
