
Commit cebc007 (parent: d8b4c34)

Rename the 'simple' flag for Muon to 'fallback', and add support for an inverted 'use_muon' flag for compatibility with other Muon implementations. Add a fallback_list arg to the full optim factory call chain.

File tree (3 files changed, +57 −43 lines)

timm/optim/_optim_factory.py
timm/optim/_param_groups.py
timm/optim/muon.py

timm/optim/_optim_factory.py

Lines changed: 17 additions & 9 deletions

@@ -5,7 +5,7 @@
 import logging
 from dataclasses import dataclass
 from functools import partial
-from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union
+from typing import Any, Callable, Collection, Dict, List, Optional, Set, Tuple, Type, Union
 from fnmatch import fnmatch
 import importlib

@@ -234,7 +234,8 @@ def create_optimizer(
     momentum: float = 0.9,
     foreach: Optional[bool] = None,
     weight_decay_exclude_1d: bool = True,
-    simple_no_weight_decay: bool = False,
+    fallback_list: Collection[str] = (),
+    fallback_no_weight_decay: bool = False,
     layer_decay: Optional[float] = None,
     layer_decay_min_scale: Optional[float] = None,
     layer_decay_no_opt_scale: Optional[float] = None,

@@ -251,7 +252,8 @@ def create_optimizer(
         momentum: Momentum factor for applicable optimizers
         foreach: Enable/disable foreach operation
         weight_decay_exclude_1d: Whether to skip weight decay for 1d params (biases and norm affine)
-        simple_no_weight_decay: If True, params in no_weight_decay list will use simple/fallback optimizer (e.g., AdamW for Muon)
+        fallback_list: Collection of parameter name patterns to use fallback optimizer for hybrid optimizers
+        fallback_no_weight_decay: If True, params in no_weight_decay list will use fallback optimizer (e.g., AdamW for Muon)
         layer_decay: Layer-wise learning rate decay
         layer_scale_min_scale: Minimum layer scale factor clamp value
         layer_scale_no_opt_scale: Layer scale below which optimization is disabled

@@ -279,7 +281,8 @@ def create_optimizer(
             weight_decay=weight_decay,
             layer_decay=layer_decay,
             no_weight_decay_list=no_weight_decay,
-            simple_no_weight_decay=simple_no_weight_decay,
+            fallback_list=fallback_list,
+            fallback_no_weight_decay=fallback_no_weight_decay,
             weight_decay_exclude_1d=weight_decay_exclude_1d,
             min_scale=layer_decay_min_scale,
             no_opt_scale=layer_decay_no_opt_scale,

@@ -290,7 +293,8 @@ def create_optimizer(
                 model_or_params,
                 weight_decay=weight_decay,
                 no_weight_decay_list=no_weight_decay,
-                simple_no_weight_decay=simple_no_weight_decay,
+                fallback_list=fallback_list,
+                fallback_no_weight_decay=fallback_no_weight_decay,
             )
             weight_decay = 0.
         else:

@@ -1167,7 +1171,8 @@ def create_optimizer_v2(
     momentum: float = 0.9,
     foreach: Optional[bool] = None,
     filter_bias_and_bn: bool = True,
-    simple_no_weight_decay: bool = False,
+    fallback_list: Collection[str] = (),
+    fallback_no_weight_decay: bool = False,
     layer_decay: Optional[float] = None,
     layer_decay_min_scale: float = 0.0,
     layer_decay_no_opt_scale: Optional[float] = None,

@@ -1195,8 +1200,10 @@ def create_optimizer_v2(
         filter_bias_and_bn: If True, bias, norm layer parameters (all 1d params) will not have
             weight decay applied. Only used when model_or_params is a model and
             weight_decay > 0.
-        simple_no_weight_decay: If True, params in model's no_weight_decay() list will use
-            simple/fallback optimizer for hybrid optimizers (e.g., AdamW for Muon).
+        fallback_list: Collection of parameter name patterns to use fallback optimizer for
+            hybrid optimizers (e.g., AdamW for Muon). Supports wildcard matching.
+        fallback_no_weight_decay: If True, params in model's no_weight_decay() list will use
+            fallback optimizer for hybrid optimizers (e.g., AdamW for Muon).
         layer_decay: Optional layer-wise learning rate decay factor. If provided,
             learning rates will be scaled by layer_decay^(max_depth - layer_depth).
             Only used when model_or_params is a model.

@@ -1247,7 +1254,8 @@ def create_optimizer_v2(
         momentum=momentum,
         foreach=foreach,
         weight_decay_exclude_1d=filter_bias_and_bn,
-        simple_no_weight_decay=simple_no_weight_decay,
+        fallback_list=fallback_list,
+        fallback_no_weight_decay=fallback_no_weight_decay,
         layer_decay=layer_decay,
         layer_decay_min_scale=layer_decay_min_scale,
         layer_decay_no_opt_scale=layer_decay_no_opt_scale,
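
With these changes, fallback routing can be requested directly from the factory. A minimal usage sketch, assuming a ViT-style timm model and that the hybrid optimizer is registered under the name 'muon'; the registry key, patterns, and hyperparameters here are illustrative, not prescribed by this commit:

import timm
from timm.optim import create_optimizer_v2

model = timm.create_model('vit_base_patch16_224')
optimizer = create_optimizer_v2(
    model,
    opt='muon',                      # hybrid optimizer with an AdamW fallback (assumed registry key)
    lr=0.02,
    weight_decay=0.05,
    fallback_list=('head.*',),       # wildcard patterns routed to the fallback optimizer
    fallback_no_weight_decay=True,   # also route params from model.no_weight_decay()
)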

timm/optim/_param_groups.py

Lines changed: 29 additions & 29 deletions

@@ -20,36 +20,36 @@ def param_groups_weight_decay(
     model: nn.Module,
     weight_decay: float = 1e-5,
     no_weight_decay_list: Collection[str] = (),
-    simple_params_list: Collection[str] = (),
-    simple_no_weight_decay: bool = False,
+    fallback_list: Collection[str] = (),
+    fallback_no_weight_decay: bool = False,
 ):
-    # Merge no_weight_decay into simple_params if requested
-    if simple_no_weight_decay:
-        simple_params_list = set(simple_params_list) | set(no_weight_decay_list)
+    # Merge no_weight_decay into fallback_list if requested
+    if fallback_no_weight_decay:
+        fallback_list = set(fallback_list) | set(no_weight_decay_list)

     decay = []
-    decay_simple = []
+    decay_fallback = []
     no_decay = []
-    no_decay_simple = []
+    no_decay_fallback = []
     for name, param in model.named_parameters():
         if not param.requires_grad:
             continue

-        # Determine if this is a "simple" parameter for fallback optimizer (if available)
-        is_simple = _matches_pattern(name, simple_params_list)
+        # Determine if this is a "fallback" parameter for fallback optimizer (if available)
+        is_fallback = _matches_pattern(name, fallback_list)

         # Determine weight decay
         matches_pattern = _matches_pattern(name, no_weight_decay_list)
         if param.ndim <= 1 or name.endswith(".bias") or matches_pattern:
             # No weight decay
-            if is_simple:
-                no_decay_simple.append(param)
+            if is_fallback:
+                no_decay_fallback.append(param)
             else:
                 no_decay.append(param)
         else:
             # With weight decay
-            if is_simple:
-                decay_simple.append(param)
+            if is_fallback:
+                decay_fallback.append(param)
             else:
                 decay.append(param)

@@ -58,10 +58,10 @@ def param_groups_weight_decay(
         groups.append({'params': no_decay, 'weight_decay': 0.})
     if decay:
         groups.append({'params': decay, 'weight_decay': weight_decay})
-    if no_decay_simple:
-        groups.append({'params': no_decay_simple, 'weight_decay': 0., 'simple': True})
-    if decay_simple:
-        groups.append({'params': decay_simple, 'weight_decay': weight_decay, 'simple': True})
+    if no_decay_fallback:
+        groups.append({'params': no_decay_fallback, 'weight_decay': 0., 'use_fallback': True})
+    if decay_fallback:
+        groups.append({'params': decay_fallback, 'weight_decay': weight_decay, 'use_fallback': True})

     return groups
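
The routing above keys off _matches_pattern, which is not part of this diff. A plausible sketch of the wildcard matching it performs, using fnmatch as _optim_factory.py does; the exact implementation is an assumption:

from fnmatch import fnmatch
from typing import Collection

def _matches_pattern(name: str, patterns: Collection[str]) -> bool:
    # fnmatch-style wildcards; an exact name is also a valid pattern,
    # so plain entries like 'pos_embed' still match themselves.
    return any(fnmatch(name, pattern) for pattern in patterns)

# e.g. 'blocks.*' catches every transformer block parameter:
assert _matches_pattern('blocks.0.attn.qkv.weight', ('blocks.*',))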

@@ -103,8 +103,8 @@ def param_groups_layer_decay(
     model: nn.Module,
     weight_decay: float = 0.05,
     no_weight_decay_list: Collection[str] = (),
-    simple_params_list: Collection[str] = (),
-    simple_no_weight_decay: bool = False,
+    fallback_list: Collection[str] = (),
+    fallback_no_weight_decay: bool = False,
     weight_decay_exclude_1d: bool = True,
     layer_decay: float = .75,
     min_scale: float = 0.,

@@ -115,9 +115,9 @@ def param_groups_layer_decay(
     Parameter groups for layer-wise lr decay & weight decay
     Based on BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L58
     """
-    # Merge no_weight_decay into simple_params if requested
-    if simple_no_weight_decay:
-        simple_params_list = set(simple_params_list) | set(no_weight_decay_list)
+    # Merge no_weight_decay into fallback_list if requested
+    if fallback_no_weight_decay:
+        fallback_list = set(fallback_list) | set(no_weight_decay_list)

     param_group_names = {}  # NOTE for debugging
     param_groups = {}

@@ -136,8 +136,8 @@ def param_groups_layer_decay(
         if not param.requires_grad:
             continue

-        # Determine if this is a "simple" parameter for fallback optimizer (if available)
-        is_simple = _matches_pattern(name, simple_params_list)
+        # Determine if this is a "fallback" parameter for fallback optimizer (if available)
+        is_fallback = _matches_pattern(name, fallback_list)

         # Determine weight decay
         if (weight_decay_exclude_1d and param.ndim <= 1) or _matches_pattern(name, no_weight_decay_list):

@@ -155,23 +155,23 @@ def param_groups_layer_decay(
             param.requires_grad = False
             continue

-        simple_suffix = "_simple" if is_simple else ""
-        group_name = "layer_%d_%s%s" % (layer_id, g_decay, simple_suffix)
+        fallback_suffix = "_fallback" if is_fallback else ""
+        group_name = "layer_%d_%s%s" % (layer_id, g_decay, fallback_suffix)

         if group_name not in param_groups:
             param_group_names[group_name] = {
                 "lr_scale": this_scale,
                 "weight_decay": this_decay,
-                "simple": is_simple,
+                "use_fallback": is_fallback,
                 "param_names": [],
             }
             param_groups[group_name] = {
                 "lr_scale": this_scale,
                 "weight_decay": this_decay,
                 "params": [],
             }
-            if is_simple:
-                param_groups[group_name]["simple"] = True
+            if is_fallback:
+                param_groups[group_name]["use_fallback"] = True

         param_group_names[group_name]["param_names"].append(name)
         param_groups[group_name]["params"].append(param)

timm/optim/muon.py

Lines changed: 11 additions & 5 deletions

@@ -376,7 +376,7 @@ class Muon(torch.optim.Optimizer):
    """Muon - MomentUm Orthogonalized by Newton-schulz

    Combines Muon for 2D+ parameters (weight matrices) with AdamW for 1D parameters (biases, norms) and
-   parameter groups with 'simple=True' set.
+   parameter groups with 'use_fallback=True' set (or 'use_muon=False' for compatibility).
    """

    def __init__(

@@ -423,7 +423,7 @@ def __init__(
        # Manual control over parameter groups
        optimizer = Muon([
            {'params': weight_matrices, 'lr': 0.02},
-           {'params': biases, 'simple': True, 'lr': 3e-4},  # use AdamW if simple=True
+           {'params': biases, 'use_fallback': True, 'lr': 3e-4},  # use AdamW if use_fallback=True
        ])
        ```
    """

@@ -494,12 +494,18 @@ def step(self, closure=None):

                # Determine routing on first encounter (cache in state)
                if "use_muon" not in state:
-                    # Check explicit simple flag first
+                    # Check explicit flags first (support both 'use_fallback' and 'use_muon' for compatibility)
                    reason = None
-                    if group.get("simple", False):
+                    if group.get("use_fallback", False):
+                        # use_fallback=True means use AdamW (use_muon=False)
                        state["use_muon"] = False
                        if verbose:
-                            reason = "simple_flag"
+                            reason = "use_fallback_flag"
+                    elif "use_muon" in group:
+                        # Explicit use_muon flag for compatibility with other Muon implementations
+                        state["use_muon"] = group["use_muon"]
+                        if verbose:
+                            reason = "use_muon_flag"
                    else:
                        # Check shape suitability
                        if verbose:
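
The inverted flag means parameter groups written for other Muon implementations now route the same way as use_fallback groups. A minimal sketch of the compatibility style; the parameter split and learning rates here are illustrative:

import torch
from timm.optim.muon import Muon

model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.LayerNorm(8))
matrix_params = [p for p in model.parameters() if p.ndim >= 2]
other_params = [p for p in model.parameters() if p.ndim < 2]

optimizer = Muon([
    {'params': matrix_params, 'use_muon': True},               # orthogonalized Muon updates
    {'params': other_params, 'use_muon': False, 'lr': 3e-4},   # routed to the AdamW fallback
])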
