-
Notifications
You must be signed in to change notification settings - Fork 0
/
train.py
114 lines (93 loc) · 3.39 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
"""
Main script for training the YOLOv5 model on the Pascal VOC and COCO datasets.
"""
import config
import torch
import torch.optim as optim
from model import Yolov5,config1
from tqdm.auto import tqdm
from utils import (
mean_average_precision,
cells_to_bboxes,
get_evaluation_bboxes,
save_checkpoint,
load_checkpoint,
check_class_accuracy,
get_loaders,
plot_couple_examples
)
from loss import YoloLoss
import warnings
# Install a process-wide "ignore" filter for every warning category.
# NOTE(review): this also hides deprecation and numerical warnings coming from
# torch and the project helpers; consider narrowing it by category/module.
warnings.simplefilter("ignore")

# Enable cuDNN's built-in auto-tuner so it benchmarks the available
# convolution algorithms and picks the fastest one for this hardware.
torch.backends.cudnn.benchmark = True
def train_fn(train_loader, model, optimizer, loss_fn, scaler, scaled_anchors):
    """Run one pass over ``train_loader`` with mixed-precision training.

    For every batch, the three outputs of ``model`` are scored against the
    three matching targets and anchor sets with ``loss_fn``. The summed loss
    is back-propagated through ``scaler`` and the optimizer is stepped once.
    The tqdm progress bar shows the running mean loss.

    Args:
        train_loader: iterable of ``(x, y)`` batches; ``y[0]``, ``y[1]`` and
            ``y[2]`` are the per-scale targets.
        model: network whose output is indexable as ``out[0..2]``, one
            prediction per scale.
        optimizer: optimizer stepped once per batch through ``scaler``.
        loss_fn: callable ``loss_fn(prediction, target, anchors)``.
        scaler: gradient scaler (``torch.cuda.amp.GradScaler`` in ``main``).
        scaled_anchors: per-scale anchors; ``scaled_anchors[i]`` pairs with
            ``out[i]`` and ``y[i]``.

    Returns:
        The mean batch loss over the pass as a float (0.0 if the loader
        yielded no batches).
    """
    loop = tqdm(train_loader, leave=True)
    running_total = 0.0
    num_batches = 0
    for x, y in loop:
        x = x.to(config.DEVICE)
        y0, y1, y2 = (
            y[0].to(config.DEVICE),
            y[1].to(config.DEVICE),
            y[2].to(config.DEVICE),
        )

        # Forward pass and loss under autocast (mixed precision).
        with torch.cuda.amp.autocast():
            out = model(x)
            loss = (
                loss_fn(out[0], y0, scaled_anchors[0])
                + loss_fn(out[1], y1, scaled_anchors[1])
                + loss_fn(out[2], y2, scaled_anchors[2])
            )

        optimizer.zero_grad()
        # The scaler scales the loss before backward and handles unscaling
        # when stepping the optimizer.
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # Keep a running sum instead of re-summing a growing list every batch
        # (O(n^2) over the pass). Sequential addition yields exactly the value
        # sum(losses) / len(losses) produced.
        running_total += loss.item()
        num_batches += 1
        loop.set_postfix(loss=running_total / num_batches)

    return running_total / num_batches if num_batches else 0.0
def main():
    """Build the YOLOv5 model, train it, and periodically evaluate on the test set.

    Hyper-parameters are read from ``config``, and the architecture from
    ``config1``. Weights are restored from ``config.CHECKPOINT_FILE`` when
    ``config.LOAD_MODEL`` is set. Training runs for ``config.NUM_EPOCHS``
    epochs with Adam and a gradient scaler. On every epoch that is a positive
    multiple of 3, class accuracy and mAP on ``test_loader`` are reported.
    """
    model = Yolov5(
        config=config1, n_channels=3, num_classes=config.NUM_CLASSES
    ).to(config.DEVICE)
    optimizer = optim.Adam(
        model.parameters(), lr=config.LEARNING_RATE, weight_decay=config.WEIGHT_DECAY
    )
    loss_fn = YoloLoss()
    scaler = torch.cuda.amp.GradScaler()
    # The third loader (train-set evaluation) is currently unused.
    train_loader, test_loader, _train_eval_loader = get_loaders()

    if config.LOAD_MODEL:
        load_checkpoint(
            config.CHECKPOINT_FILE, model, optimizer, config.LEARNING_RATE
        )

    # Multiply each scale's anchors by that scale's grid size config.S[i].
    # The multiplier has shape (len(config.S), 3, 2).
    # Presumably config.ANCHORS is (3, 3, 2) in image-relative units, so the
    # result is in grid-cell units — TODO confirm.
    scaled_anchors = (
        torch.tensor(config.ANCHORS)
        * torch.tensor(config.S).unsqueeze(1).unsqueeze(1).repeat(1, 3, 2)
    ).to(config.DEVICE)

    # NOTE(review): checkpoint saving (config.SAVE_MODEL / save_checkpoint) was
    # commented out in this script, so trained weights are never persisted here.
    for epoch in range(config.NUM_EPOCHS):
        train_fn(train_loader, model, optimizer, loss_fn, scaler, scaled_anchors)

        if epoch > 0 and epoch % 3 == 0:
            check_class_accuracy(model, test_loader, threshold=config.CONF_THRESHOLD)
            pred_boxes, true_boxes = get_evaluation_bboxes(
                test_loader,
                model,
                iou_threshold=config.NMS_IOU_THRESH,
                anchors=config.ANCHORS,
                threshold=config.CONF_THRESHOLD,
            )
            mapval = mean_average_precision(
                pred_boxes,
                true_boxes,
                iou_threshold=config.MAP_IOU_THRESH,
                box_format="midpoint",
                num_classes=config.NUM_CLASSES,
            )
            print(f"MAP: {mapval.item()}")
            # Return to training mode in case the evaluation helpers switched
            # the model to eval mode.
            model.train()
if __name__ == "__main__":
main()