# main.py (forked from XiaobingSuper/examples)
from __future__ import print_function
import argparse
from math import log10
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from model import Net
from data import get_training_set, get_test_set
# ---------------------------------------------------------------------------
# Training settings (command-line options)
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser(description='PyTorch Super Res Example')
parser.add_argument(
    '--upscale_factor', type=int, required=True,
    help="super resolution upscale factor",
)
parser.add_argument(
    '--batchSize', type=int, default=64,
    help='training batch size',
)
parser.add_argument(
    '--testBatchSize', type=int, default=10,
    help='testing batch size',
)
parser.add_argument(
    '--nEpochs', type=int, default=2,
    help='number of epochs to train for',
)
parser.add_argument(
    '--lr', type=float, default=0.01,
    help='Learning Rate. Default=0.01',
)
parser.add_argument(
    '--cuda', action='store_true',
    help='use cuda?',
)
parser.add_argument(
    '--threads', type=int, default=4,
    help='number of threads for data loader to use',
)
parser.add_argument(
    '--seed', type=int, default=123,
    help='random seed to use. Default=123',
)
opt = parser.parse_args()
print(opt)

# Abort early when --cuda was requested but torch reports no usable GPU.
if opt.cuda:
    if not torch.cuda.is_available():
        raise Exception("No GPU found, please run without --cuda")

# Seed torch's RNG before the datasets and the model are created below.
torch.manual_seed(opt.seed)

# All tensors and the model are moved to this device.
if opt.cuda:
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# ---------------------------------------------------------------------------
# Data: only the training loader shuffles its samples.
# ---------------------------------------------------------------------------
print('===> Loading datasets')
train_set = get_training_set(opt.upscale_factor)
test_set = get_test_set(opt.upscale_factor)
training_data_loader = DataLoader(
    dataset=train_set,
    num_workers=opt.threads,
    batch_size=opt.batchSize,
    shuffle=True,
)
testing_data_loader = DataLoader(
    dataset=test_set,
    num_workers=opt.threads,
    batch_size=opt.testBatchSize,
    shuffle=False,
)

# ---------------------------------------------------------------------------
# Model, loss (mean squared error) and optimizer (Adam)
# ---------------------------------------------------------------------------
print('===> Building model')
model = Net(upscale_factor=opt.upscale_factor).to(device)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=opt.lr)
def train(epoch):
    """Run one pass over ``training_data_loader`` and print loss statistics.

    For every batch the gradients are zeroed, the ``criterion`` loss of the
    model output against the target is back-propagated, and ``optimizer``
    takes one step. One progress line is printed per batch; after the last
    batch the mean of the per-batch losses is printed.

    Args:
        epoch: Epoch number; used only in the printed messages.
    """
    running_loss = 0
    for step, batch in enumerate(training_data_loader, 1):
        # Element 0 of a batch is the model input, element 1 the target.
        inputs = batch[0].to(device)
        targets = batch[1].to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)

        # Read the scalar once; it feeds both the running sum and the log line.
        batch_loss = loss.item()
        running_loss += batch_loss

        loss.backward()
        optimizer.step()

        print("===> Epoch[{}]({}/{}): Loss: {:.4f}".format(
            epoch, step, len(training_data_loader), batch_loss))

    mean_loss = running_loss / len(training_data_loader)
    print("===> Epoch {} Complete: Avg. Loss: {:.4f}".format(epoch, mean_loss))
def test():
    """Evaluate ``model`` on ``testing_data_loader`` and print the mean PSNR.

    Runs under ``torch.no_grad()``. For each batch the ``criterion`` value
    (MSE) is converted to PSNR as ``10 * log10(1 / mse)``; the per-batch
    values are summed and divided by the number of batches.

    A batch whose MSE is exactly zero (prediction equals target) is counted
    as infinite PSNR instead of raising ``ZeroDivisionError``.

    Raises:
        ZeroDivisionError: If ``testing_data_loader`` yields no batches.
    """
    # NOTE(review): the model is not switched to eval mode here — confirm Net
    # has no dropout / batch-norm layers, otherwise wrap with model.eval().
    total_psnr = 0
    with torch.no_grad():
        for batch in testing_data_loader:
            inputs, targets = batch[0].to(device), batch[1].to(device)
            prediction = model(inputs)
            mse = criterion(prediction, targets).item()
            if mse == 0:
                # Perfect reconstruction: 1 / mse would divide by zero.
                psnr = float("inf")
            else:
                # The numerator 1 presumably reflects a peak signal value of
                # 1.0 — TODO confirm against the dataset's value range.
                psnr = 10 * log10(1 / mse)
            total_psnr += psnr
    print("===> Avg. PSNR: {:.4f} dB".format(total_psnr / len(testing_data_loader)))
def checkpoint(epoch):
    """Save ``model`` to ``model_epoch_<epoch>.pth`` and print the path.

    The whole model object (not just a state dict) is passed to
    ``torch.save``; the file name is relative, so it lands in the current
    working directory.

    Args:
        epoch: Epoch number embedded in the output file name.
    """
    out_path = "model_epoch_{}.pth".format(epoch)
    torch.save(obj=model, f=out_path)
    print("Checkpoint saved to {}".format(out_path))
# Driver: for epochs 1..opt.nEpochs, train, then evaluate, then write a
# checkpoint for that epoch.
for epoch in range(1, 1 + opt.nEpochs):
    train(epoch)
    test()
    checkpoint(epoch)