-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathengine.py
84 lines (67 loc) · 2.2 KB
/
engine.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
from pathlib import Path
import torch
from tqdm.auto import tqdm
from util import save_model
def train_step(model, optimizer, loss_fn, x, y):
    """Run one forward/backward/optimizer step on a single batch.

    The model output must be 3-D. It is flattened to ``(-1, C)``, where ``C``
    is its last dimension (presumably class scores — confirm with the model).
    The targets are flattened to ``(-1,)`` before ``loss_fn`` is applied.

    Args:
        model: Module called as ``model(x)``; its output must have exactly
            three dimensions.
        optimizer: Optimizer that is stepped after backward (presumably over
            ``model``'s parameters — verify against the caller).
        loss_fn: Callable ``(logits_2d, targets_1d) -> scalar tensor``.
        x: Input batch, already on the model's device.
        y: Target batch whose element count matches the first two dims of
            the model output.

    Returns:
        The batch loss as a detached scalar tensor. The original returned
        ``None``. Returning the loss lets callers log it without another
        forward pass.
    """
    logits = model(x)
    # Unpacking enforces a 3-D output; only the last dim is needed.
    _, _, num_classes = logits.shape
    # reshape rather than view: view() raises on non-contiguous tensors
    # (e.g. a transposed model output), while reshape copies when needed.
    loss = loss_fn(logits.reshape(-1, num_classes), y.reshape(-1))

    # Clear stale gradients right before backward; set_to_none drops the
    # grad tensors instead of zero-filling them.
    model.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()
    return loss.detach()
def _batch_loss(model, loss_fn, batch, device):
    """Move one ``(x, y)`` batch to ``device`` and return the flattened loss."""
    x, y = batch
    x, y = x.to(device), y.to(device)
    logits = model(x)
    _, _, num_classes = logits.shape  # output must be 3-D
    # reshape rather than view so non-contiguous tensors are accepted.
    return loss_fn(logits.reshape(-1, num_classes), y.reshape(-1))


def _endless_batches(loader):
    """Yield batches from ``loader`` forever, starting a new pass when it runs out.

    The loader is re-iterated rather than wrapped in ``itertools.cycle``.
    ``cycle`` would cache every batch in memory and replay them in the same
    order, ignoring any reshuffling the loader does per pass.

    Raises:
        ValueError: if a full pass over ``loader`` yields nothing.
    """
    while True:
        produced = False
        for batch in loader:
            produced = True
            yield batch
        if not produced:
            raise ValueError("data loader yielded no batches")


def evaluate_model(model, steps, loss_fn, train_data_loader, test_data_loader, device):
    """Average the loss over ``steps`` batches from each of two loaders.

    Runs under ``torch.inference_mode()``, so no autograd graph is built.
    This function does not change the model's train/eval mode; set it
    beforehand.

    Each step draws one batch from each loader. If a loader has fewer than
    ``steps`` batches, a new pass over it is started instead of letting
    ``StopIteration`` escape, as the original code did.

    Args:
        model: Module whose output is 3-D; the last dim is flattened against.
        steps: Number of batches to average over; must be positive.
        loss_fn: Callable ``(logits_2d, targets_1d) -> scalar tensor``.
        train_data_loader: Iterable yielding ``(x, y)`` pairs.
        test_data_loader: Iterable yielding ``(x, y)`` pairs.
        device: Passed to ``Tensor.to`` for every batch.

    Returns:
        ``(mean_train_loss, mean_test_loss)``, each whatever ``loss_fn``
        returns summed and divided by ``steps`` (a 0-d tensor for standard
        PyTorch losses).

    Raises:
        ValueError: if ``steps`` is not positive (the original raised
            ``ZeroDivisionError``), or if a loader yields no batches at all.
    """
    if steps <= 0:
        raise ValueError(f"steps must be positive, got {steps}")

    with torch.inference_mode():
        train_batches = _endless_batches(train_data_loader)
        test_batches = _endless_batches(test_data_loader)
        train_loss_total, test_loss_total = 0, 0
        for _ in tqdm(range(steps), desc="Evaluating", leave=False):
            train_loss_total += _batch_loss(model, loss_fn, next(train_batches), device)
            test_loss_total += _batch_loss(model, loss_fn, next(test_batches), device)
        return train_loss_total / steps, test_loss_total / steps
def train(
    model,
    optimizer,
    loss_fn,
    train_data_loader,
    test_data_loader,
    train_steps,
    log_interval,
    eval_steps,
    device,
):
    """Train ``model`` for ``train_steps`` optimizer steps, logging losses periodically.

    The model is moved to ``device`` and put in train mode. One batch per
    step is drawn from ``train_data_loader``. When the loader is exhausted, a
    new pass over it begins instead of raising ``StopIteration``.

    After the first step, and after every ``log_interval``-th step, the model
    is switched to eval mode. ``evaluate_model`` then estimates train/test
    loss over ``eval_steps`` batches. The result is printed via
    ``tqdm.write``, and the model is returned to train mode.

    Args:
        model: Module to train.
        optimizer: Optimizer passed to ``train_step``.
        loss_fn: Loss callable passed to ``train_step`` and ``evaluate_model``.
        train_data_loader: Iterable of ``(x, y)`` training batches.
        test_data_loader: Iterable of ``(x, y)`` batches used only for logging.
        train_steps: Total number of optimizer steps.
        log_interval: Evaluate/log every this many steps; must be positive.
        eval_steps: Number of batches ``evaluate_model`` averages over.
        device: Target for ``model.to`` and for each batch.

    Raises:
        ValueError: if ``log_interval`` is not positive (the original raised
            ``ZeroDivisionError`` from the modulo).
    """
    if log_interval <= 0:
        raise ValueError(f"log_interval must be positive, got {log_interval}")

    model.to(device)
    model.train()
    train_iter = iter(train_data_loader)
    for step in tqdm(range(train_steps), desc="Training"):
        try:
            x, y = next(train_iter)
        except StopIteration:
            # Finite loader ran out before train_steps: start a new pass.
            train_iter = iter(train_data_loader)
            x, y = next(train_iter)
        x, y = x.to(device), y.to(device)
        train_step(model, optimizer, loss_fn, x, y)

        # step is 0-based. Log once after the very first step as a baseline
        # (the original tested `== 1`, i.e. after the second step), then
        # every log_interval steps.
        if step == 0 or (step + 1) % log_interval == 0:
            model.eval()
            train_loss, test_loss = evaluate_model(
                model,
                eval_steps,
                loss_fn,
                train_data_loader,
                test_data_loader,
                device,
            )
            tqdm.write(
                f"Epoch {step + 1}: Train loss: {train_loss:.3f}, Test loss: {test_loss:.3f}"
            )
            model.train()