schedular.py
import torch
import math


class Scheduler:
    """Bundles a model, optimizer, AMP loss scaler and LR scheduler, and drives the
    scaled backward pass, gradient clipping, optimizer step and LR update."""

    def __init__(self, model, optim, loss_scaler, lr_scheduler) -> None:
        self.model = model
        self.optim = optim
        self.scaler = loss_scaler
        self.scheduler = lr_scheduler

    def loss_scale(self, loss: torch.Tensor) -> torch.Tensor:
        return self.scaler.scale(loss)

    def zero_grad(self):
        self.optim.zero_grad()

    def loss_scale_and_backward(self, loss: torch.Tensor, create_graph=False):
        loss = self.loss_scale(loss)
        loss.backward(create_graph=create_graph)

    def step_and_lr_schedule(self, epoch, clip_grad=None, update_grad=True):
        if update_grad:
            if clip_grad is not None:
                # unscale the gradients of the optimizer's assigned params in-place
                self.scaler.unscale_(self.optim)
                norm = torch.nn.utils.clip_grad_norm_(self.model.parameters(), clip_grad)
            else:
                self.scaler.unscale_(self.optim)
                norm = get_grad_norm_(self.model.parameters())
            self.scaler.step(self.optim)
            self.scaler.update()
        else:
            norm = None
        lr = self.scheduler.lr_schedule(self.optim, epoch)
        return lr
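

# Usage sketch: one AMP training step driven by Scheduler, assuming the intended call
# order is zero_grad -> loss_scale_and_backward -> step_and_lr_schedule. `model`,
# `criterion`, `images`, `targets` and `epoch` are placeholder names supplied by the
# caller's training loop; this helper is illustrative, not part of the pipeline above.
def example_amp_step(scheduler, model, criterion, images, targets, epoch):
    scheduler.zero_grad()                     # clear gradients from the previous step
    loss = criterion(model(images), targets)  # forward pass
    scheduler.loss_scale_and_backward(loss)   # backward on the scaled loss
    # unscale, clip to 1.0, step the optimizer, then set the LR for this epoch
    return scheduler.step_and_lr_schedule(epoch, clip_grad=1.0)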


class Scheduler_fsdp:
    """Variant for FSDP training: there is no loss scaler, so the backward pass is
    unscaled and the optimizer steps directly."""

    def __init__(self, model, optim, lr_scheduler) -> None:
        self.model = model
        self.optim = optim
        self.scheduler = lr_scheduler

    def zero_grad(self):
        self.optim.zero_grad()

    def loss_scale_and_backward(self, loss: torch.Tensor, create_graph=False):
        # no loss scaling here; the method is kept for interface parity with Scheduler
        loss.backward(create_graph=create_graph)

    def step_and_lr_schedule(self, epoch, clip_grad=None, update_grad=True):
        if update_grad:
            # gradient unscaling/clipping is not applied in the FSDP path
            self.optim.step()
        else:
            norm = None
        lr = self.scheduler.lr_schedule(self.optim, epoch)
        return lr


def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor:
    """Return the total gradient norm of `parameters` without modifying the gradients
    (the same norm computation as torch.nn.utils.clip_grad_norm_, minus the clipping)."""
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    parameters = [p for p in parameters if p.grad is not None]
    norm_type = float(norm_type)
    if len(parameters) == 0:
        return torch.tensor(0.)
    device = parameters[0].grad.device
    if norm_type == math.inf:
        total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters)
    else:
        total_norm = torch.norm(
            torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]),
            norm_type,
        )
    return total_norm


class LinearScheduler:
    """Linear warmup for the first `warmup_epochs`, then linear decay from `lr` down to
    `min_lr` over the remaining epochs."""

    def __init__(self, args) -> None:
        self.warmup_epochs = args.warmup_epochs
        self.epochs = args.epochs
        self.lr = args.lr
        self.min_lr = args.min_lr

    def lr_schedule(self, optimizer, epoch):
        if epoch < self.warmup_epochs:
            # warmup: ramp linearly from 0 to lr
            lr = self.lr * epoch / self.warmup_epochs
        else:
            # decay: interpolate linearly from lr to min_lr
            lr = self.lr - (self.lr - self.min_lr) * \
                (epoch - self.warmup_epochs) / (self.epochs - self.warmup_epochs)
        for param_group in optimizer.param_groups:
            # honor per-group LR scaling (e.g. layer-wise LR decay) if present
            if "lr_scale" in param_group:
                param_group["lr"] = lr * param_group["lr_scale"]
            else:
                param_group["lr"] = lr
        return lr
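

# Minimal demo (a sketch, not part of the training pipeline): prints the learning rate
# produced by LinearScheduler across epochs. SimpleNamespace stands in for the argparse
# `args` object; warmup_epochs/epochs/lr/min_lr are the only fields the constructor reads.
if __name__ == "__main__":
    from types import SimpleNamespace

    args = SimpleNamespace(warmup_epochs=5, epochs=20, lr=1e-3, min_lr=1e-5)
    lr_sched = LinearScheduler(args)
    dummy_optim = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=args.lr)
    for epoch in range(args.epochs):
        lr = lr_sched.lr_schedule(dummy_optim, epoch)
        print(f"epoch {epoch:2d}: lr = {lr:.6f}")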