55import logging
66from dataclasses import dataclass
77from functools import partial
8- from typing import Any , Callable , Dict , List , Optional , Set , Tuple , Type , Union
8+ from typing import Any , Callable , Collection , Dict , List , Optional , Set , Tuple , Type , Union
99from fnmatch import fnmatch
1010import importlib
1111
@@ -234,7 +234,8 @@ def create_optimizer(
234234 momentum : float = 0.9 ,
235235 foreach : Optional [bool ] = None ,
236236 weight_decay_exclude_1d : bool = True ,
237- simple_no_weight_decay : bool = False ,
237+ fallback_list : Collection [str ] = (),
238+ fallback_no_weight_decay : bool = False ,
238239 layer_decay : Optional [float ] = None ,
239240 layer_decay_min_scale : Optional [float ] = None ,
240241 layer_decay_no_opt_scale : Optional [float ] = None ,
@@ -251,7 +252,8 @@ def create_optimizer(
251252 momentum: Momentum factor for applicable optimizers
252253 foreach: Enable/disable foreach operation
253254 weight_decay_exclude_1d: Whether to skip weight decay for 1d params (biases and norm affine)
254- simple_no_weight_decay: If True, params in no_weight_decay list will use simple/fallback optimizer (e.g., AdamW for Muon)
255+ fallback_list: Collection of parameter name patterns to use fallback optimizer for hybrid optimizers
256+ fallback_no_weight_decay: If True, params in no_weight_decay list will use fallback optimizer (e.g., AdamW for Muon)
255257 layer_decay: Layer-wise learning rate decay
256258 layer_decay_min_scale: Minimum layer scale factor clamp value
257259 layer_decay_no_opt_scale: Layer scale below which optimization is disabled
@@ -279,7 +281,8 @@ def create_optimizer(
279281 weight_decay = weight_decay ,
280282 layer_decay = layer_decay ,
281283 no_weight_decay_list = no_weight_decay ,
282- simple_no_weight_decay = simple_no_weight_decay ,
284+ fallback_list = fallback_list ,
285+ fallback_no_weight_decay = fallback_no_weight_decay ,
283286 weight_decay_exclude_1d = weight_decay_exclude_1d ,
284287 min_scale = layer_decay_min_scale ,
285288 no_opt_scale = layer_decay_no_opt_scale ,
@@ -290,7 +293,8 @@ def create_optimizer(
290293 model_or_params ,
291294 weight_decay = weight_decay ,
292295 no_weight_decay_list = no_weight_decay ,
293- simple_no_weight_decay = simple_no_weight_decay ,
296+ fallback_list = fallback_list ,
297+ fallback_no_weight_decay = fallback_no_weight_decay ,
294298 )
295299 weight_decay = 0.
296300 else :
@@ -1167,7 +1171,8 @@ def create_optimizer_v2(
11671171 momentum : float = 0.9 ,
11681172 foreach : Optional [bool ] = None ,
11691173 filter_bias_and_bn : bool = True ,
1170- simple_no_weight_decay : bool = False ,
1174+ fallback_list : Collection [str ] = (),
1175+ fallback_no_weight_decay : bool = False ,
11711176 layer_decay : Optional [float ] = None ,
11721177 layer_decay_min_scale : float = 0.0 ,
11731178 layer_decay_no_opt_scale : Optional [float ] = None ,
@@ -1195,8 +1200,10 @@ def create_optimizer_v2(
11951200 filter_bias_and_bn: If True, bias, norm layer parameters (all 1d params) will not have
11961201 weight decay applied. Only used when model_or_params is a model and
11971202 weight_decay > 0.
1198- simple_no_weight_decay: If True, params in model's no_weight_decay() list will use
1199- simple/fallback optimizer for hybrid optimizers (e.g., AdamW for Muon).
1203+ fallback_list: Collection of parameter name patterns to use fallback optimizer for
1204+ hybrid optimizers (e.g., AdamW for Muon). Supports wildcard matching.
1205+ fallback_no_weight_decay: If True, params in model's no_weight_decay() list will use
1206+ fallback optimizer for hybrid optimizers (e.g., AdamW for Muon).
12001207 layer_decay: Optional layer-wise learning rate decay factor. If provided,
12011208 learning rates will be scaled by layer_decay^(max_depth - layer_depth).
12021209 Only used when model_or_params is a model.
@@ -1247,7 +1254,8 @@ def create_optimizer_v2(
12471254 momentum = momentum ,
12481255 foreach = foreach ,
12491256 weight_decay_exclude_1d = filter_bias_and_bn ,
1250- simple_no_weight_decay = simple_no_weight_decay ,
1257+ fallback_list = fallback_list ,
1258+ fallback_no_weight_decay = fallback_no_weight_decay ,
12511259 layer_decay = layer_decay ,
12521260 layer_decay_min_scale = layer_decay_min_scale ,
12531261 layer_decay_no_opt_scale = layer_decay_no_opt_scale ,
0 commit comments