 import torch.distributed as dist
 from lightning import LightningModule, Trainer
 from lightning.pytorch.callbacks import Callback
-from lightning.pytorch.utilities import rank_zero_only
 from omegaconf import ListConfig
 from torch import Tensor, no_grad
 from torch.optim import Optimizer
@@ -77,9 +76,9 @@ def create_optimizer(model: YOLO, optim_cfg: OptimizerConfig) -> Optimizer:
     conv_params = [p for name, p in model.named_parameters() if "weight" in name and "bn" not in name]

     model_parameters = [
-        {"params": bias_params, "momentum": 0.8, "weight_decay": 0},
-        {"params": conv_params, "momentum": 0.8},
-        {"params": norm_params, "momentum": 0.8, "weight_decay": 0},
+        {"params": bias_params, "momentum": 0.937, "weight_decay": 0},
+        {"params": conv_params, "momentum": 0.937},
+        {"params": norm_params, "momentum": 0.937, "weight_decay": 0},
     ]

     def next_epoch(self, batch_num, epoch_idx):
@@ -89,8 +88,8 @@ def next_epoch(self, batch_num, epoch_idx):
         # 0.937: Start Momentum
         # 0.8  : Normal Momentum
         # 3    : The warm-up epoch num
-        self.min_mom = lerp(0.937, 0.8, max(epoch_idx, 3), 3)
-        self.max_mom = lerp(0.937, 0.8, max(epoch_idx + 1, 3), 3)
+        self.min_mom = lerp(0.937, 0.8, min(epoch_idx, 3), 3)
+        self.max_mom = lerp(0.937, 0.8, min(epoch_idx + 1, 3), 3)
         self.batch_num = batch_num
         self.batch_idx = 0

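The max to min swap above is the functional change in this hunk. A minimal standalone sketch of the resulting momentum warm-up, assuming lerp is the usual linear-interpolation helper of the form start + (end - start) * step / total (the helper name comes from this file; everything else below is illustrative only):

def lerp(start: float, end: float, step: float, total: int = 1) -> float:
    # plain linear interpolation from start to end over `total` steps
    return start + (end - start) * step / total

warmup_epochs = 3  # the hard-coded "3" from the hunk above
for epoch_idx in range(6):
    # min() clamps the schedule at the end value once warm-up is over; with max(),
    # every index below 3 is pushed up to 3, so momentum starts at 0.8 and then
    # extrapolates past it, instead of ramping 0.937 -> 0.8 over the first 3 epochs.
    momentum = lerp(0.937, 0.8, min(epoch_idx, warmup_epochs), warmup_epochs)
    print(epoch_idx, round(momentum, 4))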
@@ -100,7 +99,7 @@ def next_batch(self):
         for lr_idx, param_group in enumerate(self.param_groups):
             min_lr, max_lr = self.min_lr[lr_idx], self.max_lr[lr_idx]
             param_group["lr"] = lerp(min_lr, max_lr, self.batch_idx, self.batch_num)
-            param_group["momentum"] = lerp(self.min_mom, self.max_mom, self.batch_idx, self.batch_num)
+            # param_group["momentum"] = lerp(self.min_mom, self.max_mom, self.batch_idx, self.batch_num)
             lr_dict[f"LR/{lr_idx}"] = param_group["lr"]
         return lr_dict

@@ -125,7 +124,7 @@ def create_scheduler(optimizer: Optimizer, schedule_cfg: SchedulerConfig) -> _LRScheduler:
         lambda1 = lambda epoch: (epoch + 1) / wepoch if epoch < wepoch else 1
         lambda2 = lambda epoch: 10 - 9 * ((epoch + 1) / wepoch) if epoch < wepoch else 1
         warmup_schedule = LambdaLR(optimizer, lr_lambda=[lambda2, lambda1, lambda1])
-        schedule = SequentialLR(optimizer, schedulers=[warmup_schedule, schedule], milestones=[2])
+        schedule = SequentialLR(optimizer, schedulers=[warmup_schedule, schedule], milestones=[wepoch - 1])
     return schedule

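On the milestone change: a minimal sketch (not from this repo; wepoch, the toy model, and the learning rates below are made up) of how SequentialLR hands control from one scheduler to the next, which is why the switch point should track the configured warm-up length rather than a hard-coded 2:

import torch
from torch.optim import SGD
from torch.optim.lr_scheduler import LambdaLR, SequentialLR

wepoch = 5  # hypothetical warm-up length, standing in for schedule_cfg.warmup.epochs
model = torch.nn.Linear(4, 4)
optimizer = SGD(model.parameters(), lr=0.1)

# warm-up ramps the lr multiplier from 1/wepoch to 1; the "main" schedule holds it at 1
warmup = LambdaLR(optimizer, lr_lambda=lambda epoch: (epoch + 1) / wepoch)
main = LambdaLR(optimizer, lr_lambda=lambda epoch: 1.0)

# with milestones=[m], the first scheduler drives epochs 0..m-1 and the second
# takes over from epoch m onward, so m should be derived from wepoch
schedule = SequentialLR(optimizer, schedulers=[warmup, main], milestones=[wepoch - 1])

for epoch in range(wepoch + 2):
    print(epoch, round(optimizer.param_groups[0]["lr"], 4))
    schedule.step()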