
Commit b4dad5e

🔥 [Remove] & fix typo of momentum schedule
1 parent 67fbfa0 commit b4dad5e

File tree

yolo/config/config.py
yolo/utils/model_utils.py

2 files changed: +8 -8 lines changed


yolo/config/config.py

Lines changed: 1 addition & 0 deletions
@@ -66,6 +66,7 @@ class DataConfig:
 class OptimizerArgs:
     lr: float
     weight_decay: float
+    momentum: float
 
 
 @dataclass
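For context, a minimal sketch of what the new field means at the config level, using omegaconf structured configs (omegaconf is already a dependency, as the import in yolo/utils/model_utils.py below shows). The @dataclass decorator on OptimizerArgs and the merge call are assumptions for illustration, not code from the repo:

from dataclasses import dataclass

from omegaconf import OmegaConf


@dataclass
class OptimizerArgs:
    lr: float
    weight_decay: float
    momentum: float  # field added in this commit


# Fields without defaults are mandatory in a structured config: reading
# cfg.momentum would raise MissingMandatoryValue if the user config omitted it.
schema = OmegaConf.structured(OptimizerArgs)
cfg = OmegaConf.merge(schema, {"lr": 0.01, "weight_decay": 0.0005, "momentum": 0.937})
print(cfg.momentum)  # 0.937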

yolo/utils/model_utils.py

Lines changed: 7 additions & 8 deletions
@@ -8,7 +8,6 @@
 import torch.distributed as dist
 from lightning import LightningModule, Trainer
 from lightning.pytorch.callbacks import Callback
-from lightning.pytorch.utilities import rank_zero_only
 from omegaconf import ListConfig
 from torch import Tensor, no_grad
 from torch.optim import Optimizer
@@ -77,9 +76,9 @@ def create_optimizer(model: YOLO, optim_cfg: OptimizerConfig) -> Optimizer:
     conv_params = [p for name, p in model.named_parameters() if "weight" in name and "bn" not in name]
 
     model_parameters = [
-        {"params": bias_params, "momentum": 0.8, "weight_decay": 0},
-        {"params": conv_params, "momentum": 0.8},
-        {"params": norm_params, "momentum": 0.8, "weight_decay": 0},
+        {"params": bias_params, "momentum": 0.937, "weight_decay": 0},
+        {"params": conv_params, "momentum": 0.937},
+        {"params": norm_params, "momentum": 0.937, "weight_decay": 0},
     ]
 
     def next_epoch(self, batch_num, epoch_idx):
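The 0.8 → 0.937 change above only moves the starting value of each parameter group; with the schedule line commented out further down, momentum is no longer ramped during training. As a point of reference, here is a simplified, self-contained stand-in for that per-group setup with torch.optim.SGD (the bias/conv/norm split is condensed to two groups for brevity; it is not the repo's exact code):

import torch
from torch import nn
from torch.optim import SGD

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8))

# Condensed stand-in for the bias / conv / norm grouping in create_optimizer.
bias_params = [p for name, p in model.named_parameters() if "bias" in name]
weight_params = [p for name, p in model.named_parameters() if "weight" in name]

model_parameters = [
    {"params": bias_params, "momentum": 0.937, "weight_decay": 0},
    {"params": weight_params, "momentum": 0.937},
]

# Per-group keys override the defaults passed here, so both groups start at 0.937.
optimizer = SGD(model_parameters, lr=0.01, momentum=0.8, weight_decay=5e-4)
for group in optimizer.param_groups:
    print(group["momentum"], group["weight_decay"])  # 0.937 0, then 0.937 0.0005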
@@ -89,8 +88,8 @@ def next_epoch(self, batch_num, epoch_idx):
         # 0.937: Start Momentum
         # 0.8  : Normal Momemtum
         # 3    : The warm up epoch num
-        self.min_mom = lerp(0.937, 0.8, max(epoch_idx, 3), 3)
-        self.max_mom = lerp(0.937, 0.8, max(epoch_idx + 1, 3), 3)
+        self.min_mom = lerp(0.937, 0.8, min(epoch_idx, 3), 3)
+        self.max_mom = lerp(0.937, 0.8, min(epoch_idx + 1, 3), 3)
         self.batch_num = batch_num
         self.batch_idx = 0
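About the max → min fix above: lerp is the repo's interpolation helper and its exact definition is not shown in this diff. Assuming the usual lerp(start, end, step, total) = start + (end - start) * step / total, the sketch below shows why min(epoch_idx, 3) clamps the schedule to the three warm-up epochs, whereas the old max(epoch_idx, 3) skipped the warm-up entirely:

def lerp(start: float, end: float, step: int, total: int) -> float:
    # Assumed helper; the repo's own lerp may differ in signature or clamping.
    return start + (end - start) * (step / total)

start_mom, normal_mom, warmup_epochs = 0.937, 0.8, 3

for epoch_idx in range(5):
    buggy = lerp(start_mom, normal_mom, max(epoch_idx, warmup_epochs), warmup_epochs)  # old code
    fixed = lerp(start_mom, normal_mom, min(epoch_idx, warmup_epochs), warmup_epochs)  # this commit
    print(epoch_idx, round(buggy, 3), round(fixed, 3))

# buggy: 0.8, 0.8, 0.8, 0.8, 0.754  -> already at (or past) the end value from epoch 0
# fixed: 0.937, 0.891, 0.846, 0.8, 0.8 -> eases from 0.937 to 0.8 over the first 3 epochs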

@@ -100,7 +99,7 @@ def next_batch(self):
         for lr_idx, param_group in enumerate(self.param_groups):
             min_lr, max_lr = self.min_lr[lr_idx], self.max_lr[lr_idx]
             param_group["lr"] = lerp(min_lr, max_lr, self.batch_idx, self.batch_num)
-            param_group["momentum"] = lerp(self.min_mom, self.max_mom, self.batch_idx, self.batch_num)
+            # param_group["momentum"] = lerp(self.min_mom, self.max_mom, self.batch_idx, self.batch_num)
             lr_dict[f"LR/{lr_idx}"] = param_group["lr"]
         return lr_dict

@@ -125,7 +124,7 @@ def create_scheduler(optimizer: Optimizer, schedule_cfg: SchedulerConfig) -> _LRScheduler:
         lambda1 = lambda epoch: (epoch + 1) / wepoch if epoch < wepoch else 1
         lambda2 = lambda epoch: 10 - 9 * ((epoch + 1) / wepoch) if epoch < wepoch else 1
         warmup_schedule = LambdaLR(optimizer, lr_lambda=[lambda2, lambda1, lambda1])
-        schedule = SequentialLR(optimizer, schedulers=[warmup_schedule, schedule], milestones=[2])
+        schedule = SequentialLR(optimizer, schedulers=[warmup_schedule, schedule], milestones=[wepoch - 1])
     return schedule
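On the milestone change: SequentialLR switches from the first scheduler to the next when the milestone epoch is reached, so tying the milestone to wepoch keeps the warm-up length and the switch point in sync instead of hard-coding 2. Below is a standalone sketch of that mechanic with a single parameter group and a made-up main schedule; it is not the repo's configuration:

import torch
from torch.optim import SGD
from torch.optim.lr_scheduler import LambdaLR, LinearLR, SequentialLR

wepoch = 3  # assumed warm-up epoch count, for illustration only
model = torch.nn.Linear(4, 2)
optimizer = SGD(model.parameters(), lr=0.01, momentum=0.937)

# Warm-up ramps the LR toward its base value; afterwards a (made-up) linear decay takes over.
warmup = LambdaLR(optimizer, lr_lambda=lambda epoch: (epoch + 1) / wepoch if epoch < wepoch else 1)
main = LinearLR(optimizer, start_factor=1.0, end_factor=0.1, total_iters=10)

# milestones=[wepoch - 1]: `warmup` drives epochs 0 .. wepoch-2, `main` takes over from there.
schedule = SequentialLR(optimizer, schedulers=[warmup, main], milestones=[wepoch - 1])

for epoch in range(6):
    print(epoch, schedule.get_last_lr())
    optimizer.step()
    schedule.step()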
