[Feature] Implement GCSAM and LookSAM optimizers #344

Merged: 9 commits, Feb 9, 2025
README.md: 3 additions & 1 deletion

@@ -10,7 +10,7 @@

## The reasons why you use `pytorch-optimizer`.

* Wide range of supported optimizers. Currently, **97 optimizers (+ `bitsandbytes`, `qgalore`, `torchao`)**, **16 lr schedulers**, and **13 loss functions** are supported!
* Wide range of supported optimizers. Currently, **98 optimizers (+ `bitsandbytes`, `qgalore`, `torchao`)**, **16 lr schedulers**, and **13 loss functions** are supported!
* Including many variants such as `ADOPT`, `Cautious`, `AdamD`, `StableAdamW`, and `Gradient Centralization`
* Easy to use, clean, and tested codes
* Active maintenance
@@ -204,6 +204,8 @@ get_supported_optimizers(['adam*', 'ranger*'])
| FOCUS | *First Order Concentrated Updating Scheme* | [github](https://github.com/liuyz0/FOCUS) | <https://arxiv.org/abs/2501.12243> | [cite](https://ui.adsabs.harvard.edu/abs/2025arXiv250112243M/exportcitation) |
| PSGD | *Preconditioned Stochastic Gradient Descent* | [github](https://github.com/lixilinx/psgd_torch) | <https://arxiv.org/abs/1512.04202> | [cite](https://github.com/lixilinx/psgd_torch?tab=readme-ov-file#resources) |
| EXAdam | *The Power of Adaptive Cross-Moments* | [github](https://github.com/AhmedMostafa16/EXAdam) | <https://arxiv.org/abs/2412.20302> | [cite](https://github.com/AhmedMostafa16/EXAdam?tab=readme-ov-file#citation) |
| GCSAM | *Gradient Centralized Sharpness Aware Minimization* | [github](https://github.com/mhassann22/GCSAM) | <https://arxiv.org/abs/2501.11584> | [cite](https://github.com/mhassann22/GCSAM?tab=readme-ov-file#citation) |
| LookSAM | *Towards Efficient and Scalable Sharpness-Aware Minimization* | [github](https://github.com/rollovd/LookSAM) | <https://arxiv.org/abs/2203.02714> | [cite](https://ui.adsabs.harvard.edu/abs/2022arXiv220302714L/exportcitation) |

## Supported LR Scheduler

docs/changelogs/v3.4.1.md: 8 additions & 0 deletions

@@ -1,5 +1,13 @@
### Change Log

### Feature

* Support `GCSAM` optimizer. (#343, #344)
    * [Gradient Centralized Sharpness Aware Minimization](https://arxiv.org/abs/2501.11584)
    * You can use it via the `SAM` optimizer by setting `use_gc=True` (see the usage sketch below).
* Support `LookSAM` optimizer. (#343, #344)
    * [Towards Efficient and Scalable Sharpness-Aware Minimization](https://arxiv.org/abs/2203.02714)
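For reference, a minimal usage sketch of both features against the public API shown in this PR; `YourModel`, `criterion`, and `loader` are placeholders, and `AdamP` is just one possible base optimizer:

```python
from pytorch_optimizer import SAM, LookSAM, AdamP

model = YourModel()  # placeholder model

# GCSAM: plain SAM with gradient centralization applied during the perturbation step
optimizer = SAM(model.parameters(), base_optimizer=AdamP, rho=0.05, use_gc=True)

# LookSAM: re-computes the SAM perturbation only every `k` steps
# optimizer = LookSAM(model.parameters(), base_optimizer=AdamP, rho=0.1, k=10, alpha=0.7)

for x, y in loader:  # placeholder data loader
    # first forward-backward pass: perturb the weights towards the sharp direction
    criterion(model(x), y).backward()
    optimizer.first_step(zero_grad=True)

    # second forward-backward pass: update with the base optimizer from the perturbed point
    criterion(model(x), y).backward()
    optimizer.second_step(zero_grad=True)
```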

### Update

* Support alternative precision training for `Shampoo` optimizer. (#339)
docs/index.md: 3 additions & 1 deletion

@@ -10,7 +10,7 @@

## The reasons why you use `pytorch-optimizer`.

* Wide range of supported optimizers. Currently, **97 optimizers (+ `bitsandbytes`, `qgalore`, `torchao`)**, **16 lr schedulers**, and **13 loss functions** are supported!
* Wide range of supported optimizers. Currently, **98 optimizers (+ `bitsandbytes`, `qgalore`, `torchao`)**, **16 lr schedulers**, and **13 loss functions** are supported!
* Including many variants such as `ADOPT`, `Cautious`, `AdamD`, `StableAdamW`, and `Gradient Centralization`
* Easy to use, clean, and tested codes
* Active maintenance
@@ -204,6 +204,8 @@ get_supported_optimizers(['adam*', 'ranger*'])
| FOCUS | *First Order Concentrated Updating Scheme* | [github](https://github.com/liuyz0/FOCUS) | <https://arxiv.org/abs/2501.12243> | [cite](https://ui.adsabs.harvard.edu/abs/2025arXiv250112243M/exportcitation) |
| PSGD | *Preconditioned Stochastic Gradient Descent* | [github](https://github.com/lixilinx/psgd_torch) | <https://arxiv.org/abs/1512.04202> | [cite](https://github.com/lixilinx/psgd_torch?tab=readme-ov-file#resources) |
| EXAdam | *The Power of Adaptive Cross-Moments* | [github](https://github.com/AhmedMostafa16/EXAdam) | <https://arxiv.org/abs/2412.20302> | [cite](https://github.com/AhmedMostafa16/EXAdam?tab=readme-ov-file#citation) |
| GCSAM | *Gradient Centralized Sharpness Aware Minimization* | [github](https://github.com/mhassann22/GCSAM) | <https://arxiv.org/abs/2501.11584> | [cite](https://github.com/mhassann22/GCSAM?tab=readme-ov-file#citation) |
| LookSAM | *Towards Efficient and Scalable Sharpness-Aware Minimization* | [github](https://github.com/rollovd/LookSAM) | <https://arxiv.org/abs/2203.02714> | [cite](https://ui.adsabs.harvard.edu/abs/2022arXiv220302714L/exportcitation) |

## Supported LR Scheduler

docs/optimizer.md: 4 additions & 0 deletions

@@ -240,6 +240,10 @@
:docstring:
:members:

::: pytorch_optimizer.LookSAM
:docstring:
:members:

::: pytorch_optimizer.MADGRAD
:docstring:
:members:
pyproject.toml: 5 additions & 4 deletions

@@ -17,10 +17,11 @@ keywords = [
"DAdaptLion", "DeMo", "DiffGrad", "EXAdam", "FAdam", "FOCUS", "Fromage", "FTRL", "GaLore", "Grams", "Gravity",
"GrokFast", "GSAM", "Kate", "Lamb", "LaProp", "LARS", "Lion", "LOMO", "Lookahead", "MADGRAD", "MARS", "MSVAG",
"Muno", "Nero", "NovoGrad", "OrthoGrad", "PAdam", "PCGrad", "PID", "PNM", "Prodigy", "PSGD", "QHAdam", "QHM",
"RAdam", "Ranger", "Ranger21", "RotoGrad", "SAM", "ScheduleFreeSGD", "ScheduleFreeAdamW", "ScheduleFreeRAdam",
"SGDP", "Shampoo", "ScalableShampoo", "SGDW", "SignSGD", "SM3", "SOAP", "SopihaH", "SPAM", "SRMM", "StableAdamW",
"SWATS", "TAM", "Tiger", "TRAC", "WSAM", "Yogi", "BCE", "BCEFocal", "Focal", "FocalCosine", "SoftF1", "Dice",
"LDAM", "Jaccard", "Bi-Tempered", "Tversky", "FocalTversky", "LovaszHinge", "bitsandbytes", "WSD", "QGaLore",
"RAdam", "Ranger", "Ranger21", "RotoGrad", "SAM", "GCSAM", "LookSAM", "ScheduleFreeSGD", "ScheduleFreeAdamW",
"ScheduleFreeRAdam", "SGDP", "Shampoo", "ScalableShampoo", "SGDW", "SignSGD", "SM3", "SOAP", "SopihaH", "SPAM",
"SRMM", "StableAdamW", "SWATS", "TAM", "Tiger", "TRAC", "WSAM", "Yogi", "BCE", "BCEFocal", "Focal", "FocalCosine",
"SoftF1", "Dice", "LDAM", "Jaccard", "Bi-Tempered", "Tversky", "FocalTversky", "LovaszHinge", "bitsandbytes", "WSD",
"QGaLore",
]
classifiers = [
"License :: OSI Approved :: Apache Software License",
pytorch_optimizer/__init__.py: 1 addition & 0 deletions

@@ -119,6 +119,7 @@
LaProp,
Lion,
Lookahead,
LookSAM,
Muon,
Nero,
NovoGrad,
pytorch_optimizer/optimizer/__init__.py: 1 addition & 1 deletion

@@ -79,7 +79,7 @@
from pytorch_optimizer.optimizer.ranger import Ranger
from pytorch_optimizer.optimizer.ranger21 import Ranger21
from pytorch_optimizer.optimizer.rotograd import RotoGrad
from pytorch_optimizer.optimizer.sam import BSAM, GSAM, SAM, WSAM
from pytorch_optimizer.optimizer.sam import BSAM, GSAM, SAM, WSAM, LookSAM
from pytorch_optimizer.optimizer.schedulefree import ScheduleFreeAdamW, ScheduleFreeRAdam, ScheduleFreeSGD
from pytorch_optimizer.optimizer.sgd import ASGD, SGDW, AccSGD, SGDSaI, SignSGD
from pytorch_optimizer.optimizer.sgdp import SGDP
pytorch_optimizer/optimizer/sam.py: 195 additions & 3 deletions

@@ -11,6 +11,7 @@
from pytorch_optimizer.base.exception import NoClosureError
from pytorch_optimizer.base.optimizer import BaseOptimizer
from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, OPTIMIZER, PARAMETERS
from pytorch_optimizer.optimizer.gc import centralize_gradient
from pytorch_optimizer.optimizer.utils import disable_running_stats, enable_running_stats


@@ -58,6 +59,7 @@ def closure():
:param base_optimizer: OPTIMIZER. base optimizer.
:param rho: float. size of the neighborhood for computing the max loss.
:param adaptive: bool. element-wise Adaptive SAM.
:param use_gc: bool. perform gradient centralization, GCSAM variant.
:param perturb_eps: float. eps for perturbation.
:param kwargs: Dict. parameters for optimizer.
"""
@@ -68,12 +70,14 @@ def __init__(
base_optimizer: OPTIMIZER,
rho: float = 0.05,
adaptive: bool = False,
use_gc: bool = False,
perturb_eps: float = 1e-12,
**kwargs,
):
self.validate_non_negative(rho, 'rho')
self.validate_non_negative(perturb_eps, 'perturb_eps')

self.use_gc = use_gc
self.perturb_eps = perturb_eps

defaults: DEFAULTS = {'rho': rho, 'adaptive': adaptive}
@@ -92,16 +96,20 @@ def reset(self):

@torch.no_grad()
def first_step(self, zero_grad: bool = False):
grad_norm = self.grad_norm()
grad_norm = self.grad_norm().add_(self.perturb_eps)
for group in self.param_groups:
scale = group['rho'] / (grad_norm + self.perturb_eps)
scale = group['rho'] / grad_norm

for p in group['params']:
if p.grad is None:
continue

grad = p.grad
if self.use_gc:
centralize_gradient(grad, gc_conv_only=False)

self.state[p]['old_p'] = p.clone()
e_w = (torch.pow(p, 2) if group['adaptive'] else 1.0) * p.grad * scale.to(p)
e_w = (torch.pow(p, 2) if group['adaptive'] else 1.0) * grad * scale.to(p)

p.add_(e_w)

@@ -670,3 +678,187 @@ def step(self, closure: CLOSURE = None):
self.third_step()

return loss


class LookSAM(BaseOptimizer):
r"""Towards Efficient and Scalable Sharpness-Aware Minimization.

Example:
-------
Here's an example::

model = YourModel()
base_optimizer = Ranger21
optimizer = LookSAM(model.parameters(), base_optimizer)

for input, output in data:
# first forward-backward pass

loss = loss_function(output, model(input))
loss.backward()
optimizer.first_step(zero_grad=True)

# second forward-backward pass
# make sure to do a full forward pass
loss_function(output, model(input)).backward()
optimizer.second_step(zero_grad=True)

Alternative example with a single closure-based step function::

model = YourModel()
base_optimizer = Ranger21
optimizer = LookSAM(model.parameters(), base_optimizer)

def closure():
loss = loss_function(output, model(input))
loss.backward()
return loss

for input, output in data:
loss = loss_function(output, model(input))
loss.backward()
optimizer.step(closure)
optimizer.zero_grad()

:param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
:param base_optimizer: OPTIMIZER. base optimizer.
:param rho: float. size of the neighborhood for computing the max loss.
:param k: int. lookahead step.
:param alpha: float. lookahead blending alpha.
:param adaptive: bool. element-wise Adaptive SAM.
:param use_gc: bool. perform gradient centralization, GCSAM variant.
:param perturb_eps: float. eps for perturbation.
:param kwargs: Dict. parameters for optimizer.
"""

def __init__(
self,
params: PARAMETERS,
base_optimizer: OPTIMIZER,
rho: float = 0.1,
k: int = 10,
alpha: float = 0.7,
adaptive: bool = False,
use_gc: bool = False,
perturb_eps: float = 1e-12,
**kwargs,
):
self.validate_non_negative(rho, 'rho')
self.validate_positive(k, 'k')
self.validate_range(alpha, 'alpha', 0.0, 1.0, '()')
self.validate_non_negative(perturb_eps, 'perturb_eps')

self.k = k
self.alpha = alpha
self.use_gc = use_gc
self.perturb_eps = perturb_eps

defaults: DEFAULTS = {'rho': rho, 'adaptive': adaptive}
defaults.update(kwargs)

super().__init__(params, defaults)

self.base_optimizer: Optimizer = base_optimizer(self.param_groups, **kwargs)
self.param_groups = self.base_optimizer.param_groups

def __str__(self) -> str:
return 'LookSAM'

@torch.no_grad()
def reset(self):
pass

def get_step(self):
return (
self.param_groups[0]['step']
if 'step' in self.param_groups[0]
else next(iter(self.base_optimizer.state.values()))['step'] if self.base_optimizer.state else 0
)

@torch.no_grad()
def first_step(self, zero_grad: bool = False) -> None:
if self.get_step() % self.k != 0:
return

grad_norm = self.grad_norm().add_(self.perturb_eps)
for group in self.param_groups:
scale = group['rho'] / grad_norm

for i, p in enumerate(group['params']):
if p.grad is None:
continue

grad = p.grad
if self.use_gc:
centralize_gradient(grad, gc_conv_only=False)

self.state[p]['old_p'] = p.clone()
self.state[f'old_grad_p_{i}']['old_grad_p'] = grad.clone()

e_w = (torch.pow(p, 2) if group['adaptive'] else 1.0) * grad * scale.to(p)
p.add_(e_w)

if zero_grad:
self.zero_grad()

@torch.no_grad()
def second_step(self, zero_grad: bool = False):
step = self.get_step()

for group in self.param_groups:
for i, p in enumerate(group['params']):
if p.grad is None:
continue

grad = p.grad
grad_norm = grad.norm(p=2)

if step % self.k == 0:
old_grad_p = self.state[f'old_grad_p_{i}']['old_grad_p']

g_grad_norm = old_grad_p / old_grad_p.norm(p=2)
g_s_grad_norm = grad / grad_norm

self.state[f'gv_{i}']['gv'] = torch.sub(
grad, grad_norm * torch.sum(g_grad_norm * g_s_grad_norm) * g_grad_norm
)
else:
gv = self.state[f'gv_{i}']['gv']
grad.add_(grad_norm / (gv.norm(p=2) + 1e-8) * gv, alpha=self.alpha)

p.data = self.state[p]['old_p']

self.base_optimizer.step()

if zero_grad:
self.zero_grad()

@torch.no_grad()
def step(self, closure: CLOSURE = None):
if closure is None:
raise NoClosureError(str(self))

self.first_step(zero_grad=True)

with torch.enable_grad():
closure()

self.second_step()

def grad_norm(self) -> torch.Tensor:
shared_device = self.param_groups[0]['params'][0].device
return torch.norm(
torch.stack(
[
((torch.abs(p) if group['adaptive'] else 1.0) * p.grad).norm(p=2).to(shared_device)
for group in self.param_groups
for p in group['params']
if p.grad is not None
]
),
p=2,
)

def load_state_dict(self, state_dict: Dict):
super().load_state_dict(state_dict)
self.base_optimizer.param_groups = self.param_groups
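To recap the core trick in `second_step` above: on every k-th step LookSAM decomposes the SAM gradient g_s into a component parallel to the clean gradient g (saved in `first_step`) and an orthogonal component g_v; on the intermediate steps it reuses the stored g_v instead of paying for the extra forward-backward pass. A small sketch of that decomposition with illustrative names (not the library's internal state keys):

```python
import torch


def orthogonal_component(g: torch.Tensor, g_s: torch.Tensor) -> torch.Tensor:
    """Return g_v, the part of the SAM gradient g_s orthogonal to the clean gradient g."""
    g_hat = g / g.norm(p=2)        # unit vector along the clean gradient
    g_s_hat = g_s / g_s.norm(p=2)  # unit vector along the SAM gradient
    # subtract the projection of g_s onto g: g_v = g_s - ||g_s|| * cos(g, g_s) * g_hat
    return g_s - g_s.norm(p=2) * torch.sum(g_hat * g_s_hat) * g_hat


# On a step that is not a multiple of k, the stored g_v is rescaled and blended in:
#   grad += alpha * (||grad|| / ||g_v||) * g_v
```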
tests/constants.py: 1 addition & 0 deletions

@@ -102,6 +102,7 @@
'sam',
'gsam',
'wsam',
'looksam',
'pcgrad',
'lookahead',
'trac',
tests/test_gradients.py: 4 additions & 3 deletions

@@ -2,7 +2,7 @@
import torch

from pytorch_optimizer.base.exception import NoSparseGradientError
from pytorch_optimizer.optimizer import SAM, TRAC, WSAM, AdamP, Lookahead, OrthoGrad, load_optimizer
from pytorch_optimizer.optimizer import SAM, TRAC, WSAM, AdamP, Lookahead, LookSAM, OrthoGrad, load_optimizer
from tests.constants import NO_SPARSE_OPTIMIZERS, SPARSE_OPTIMIZERS, VALID_OPTIMIZER_NAMES
from tests.utils import build_environment, simple_parameter, simple_sparse_parameter, sphere_loss

@@ -116,12 +116,13 @@ def test_sparse_supported(sparse_optimizer):
optimizer.step()


def test_sam_no_gradient():
@pytest.mark.parametrize('optimizer', [SAM, LookSAM])
def test_sam_no_gradient(optimizer):
(x_data, y_data), model, loss_fn = build_environment()
model.fc1.weight.requires_grad = False
model.fc1.weight.grad = None

optimizer = SAM(model.parameters(), AdamP)
optimizer = optimizer(model.parameters(), AdamP)
optimizer.zero_grad()

loss = loss_fn(y_data, model(x_data))
tests/test_optimizer_parameters.py: 7 additions & 0 deletions

@@ -7,6 +7,7 @@
TRAC,
WSAM,
Lookahead,
LookSAM,
OrthoGrad,
PCGrad,
Ranger21,
@@ -110,6 +111,12 @@ def test_wsam_methods():
optimizer.load_state_dict(optimizer.state_dict())


def test_looksam_methods():
optimizer = LookSAM([simple_parameter()], load_optimizer('adamp'))
optimizer.reset()
optimizer.load_state_dict(optimizer.state_dict())


def test_safe_fp16_methods():
optimizer = SafeFP16Optimizer(load_optimizer('adamp')([simple_parameter()], lr=5e-1))
optimizer.load_state_dict(optimizer.state_dict())