-
Notifications
You must be signed in to change notification settings - Fork 0
/
lens_main.py
126 lines (115 loc) · 4.46 KB
/
lens_main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
import datetime
import os
import random
import sys

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from torch.utils.data.dataloader import DataLoader
from torchvision import transforms

import ground_based_dataset as gbd
import resnet_ssl_model as rsm
from data_transforms import Log10, Clamp, AugmentTranslate, WhitenInput
from run_loop import SNTGRunLoop
from learning_rate_update import learning_rate_update
# Dataset root differs between the Linux training box and the Windows
# workstation; trained models and plots go into a ``saved_model`` subfolder.
path = (
    '/home/mingx/datasets/'
    if sys.platform == 'linux'
    else 'C:\\Users\\miles\\Documents\\dataset'
)
save_path = os.path.join(path, 'saved_model')
# Hyper-parameters for this run. The whole mapping is passed to SNTGRunLoop,
# and individual entries size the dataset splits, the loaders, the translation
# augmentation and the learning-rate schedule.
# NOTE(review): rampup_length + rampdown_length (80 + 50) exceeds num_epochs
# (100), so the two phases presumably overlap - confirm this is intended.
# NOTE(review): unsup_wght is 0.0 - presumably this switches off the
# unsupervised loss term; confirm against SNTGRunLoop.
train_params = dict(
    n_data=5120,
    num_classes=2,
    batch_size=128,
    n_eval_data=1024,
    test_offset=10000,
    test_len=1000,
    run_eval=True,
    run_test=False,
    num_epochs=100,
    rampup_length=80,
    rampdown_length=50,
    learning_rate=0.003,
    pred_decay=0.6,
    embed=True,
    embed_coeff=0.2,
    adam_beta1=0.9,
    rd_beta1_target=0.5,
    augment_mirror=True,
    augment_translation=2,
    unsup_wght=0.0,
    whiten_inputs='norm',  # norm, zca
    polyak_decay=0.999,
    unsup_wght_scale=1.0,
)
# Seed every RNG the run may draw from (Python, NumPy, torch CPU/CUDA) so
# repeated runs are reproducible; augmentation code may use any of them.
SEED = 770715
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
has_cuda = torch.cuda.is_available()
if has_cuda:
    torch.cuda.manual_seed_all(SEED)
    # Restrict cuDNN to deterministic kernels and disable its autotuner,
    # which may select different algorithms from one run to the next.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
# Input pipelines: every split is whitened and clamped to [1e-9, 100]; the
# training split additionally gets random translation augmentation.
# NOTE(review): 101 is passed to AugmentTranslate next to the shift amount;
# presumably the image side length - confirm against data_transforms.
train_composed = transforms.Compose(
    [WhitenInput(), Clamp(1e-9, 100),
     AugmentTranslate(train_params['augment_translation'], 101)])
test_composed = transforms.Compose([WhitenInput(), Clamp(1e-9, 100)])

# Pinned (page-locked) host memory only speeds up host-to-GPU copies, so
# enable it exactly when CUDA is available (this was previously inverted).
ground_train_dataset = gbd.GroundBasedDataset(
    path, length=train_params['n_data'], transform=train_composed)
ground_train_loader = DataLoader(
    ground_train_dataset, batch_size=train_params['batch_size'],
    shuffle=True, pin_memory=has_cuda)

# Evaluation samples start at offset n_data, right after the training
# samples (assuming GroundBasedDataset's default offset is 0 - confirm).
ground_eval_dataset = gbd.GroundBasedDataset(
    path, offset=train_params['n_data'], length=train_params['n_eval_data'],
    transform=test_composed)
ground_eval_loader = DataLoader(
    ground_eval_dataset, batch_size=train_params['batch_size'],
    shuffle=False, pin_memory=has_cuda)

# The whole test split is served as a single batch (batch_size == test_len).
ground_test_dataset = gbd.GroundBasedDataset(
    path, offset=train_params['test_offset'], length=train_params['test_len'],
    transform=test_composed)
ground_test_loader = DataLoader(
    ground_test_dataset, batch_size=train_params['test_len'],
    pin_memory=has_cuda)
# Build the network, its learning-rate schedule and the SNTG run loop, then
# train. train() returns six histories (train/eval/EMA-eval losses and
# accuracies) that are tabulated and plotted further down.
# NOTE(review): the meaning of the argument 4 is not visible here - confirm.
ssl_lens_net = rsm.SNTGModel(4)

# learning_rate_update takes its schedule settings positionally; pull them
# out of train_params in the order it expects.
lr_fn = learning_rate_update(
    *(train_params[key] for key in (
        'rampup_length',
        'rampdown_length',
        'learning_rate',
        'adam_beta1',
        'rd_beta1_target',
        'num_epochs',
    )),
    scale=train_params['unsup_wght_scale'],
)

rnssl_run_loop = SNTGRunLoop(
    ssl_lens_net,
    ground_train_loader,
    train_params,
    lr_fn,
    eval_loader=ground_eval_loader,
    test_loader=ground_test_loader,
    has_cuda=has_cuda,
)
(train_losses, train_accs,
 eval_losses, eval_accs,
 ema_eval_losses, ema_eval_accs) = rnssl_run_loop.train()
# Test-set evaluation is not run here (rnssl_run_loop.test() was disabled).
# Collect the histories returned by train() into one table, one row per
# recorded epoch. The epoch axis is sized from the recorded history rather
# than train_params['num_epochs'], so the table still builds if the run loop
# records a different number of epochs than configured.
result = pd.DataFrame(
    data={'epoch': np.arange(len(train_losses)),
          'train loss': train_losses, 'evaluation loss': eval_losses,
          'ema evaluation loss': ema_eval_losses,
          'train accuracy': train_accs, 'evaluation accuracy': eval_accs,
          'ema evaluation accuracy': ema_eval_accs})
# Plot loss (left) and accuracy (right) curves for training and evaluation,
# then save the figure and the trained weights under save_path. The EMA
# curves are available as the 'ema evaluation ...' columns but are not drawn.
fig, (ax1, ax2) = plt.subplots(ncols=2, nrows=1, figsize=(20, 10))
result.plot(x='epoch', y='train loss', color='C1', ax=ax1)
result.plot(x='epoch', y='evaluation loss', color='C2', ax=ax1)
result.plot(x='epoch', y='train accuracy', color='C1', ax=ax2)
result.plot(x='epoch', y='evaluation accuracy', color='C2', ax=ax2)

# makedirs(exist_ok=True) avoids the check-then-create race and also creates
# any missing parent directories.
os.makedirs(save_path, exist_ok=True)

# Take the timestamp once so the plot and the weights file always share the
# same suffix, even if the minute rolls over between the two saves.
timestamp = datetime.datetime.now().isoformat(
    '-', timespec='minutes').replace(':', '-')

plot_file_name = 'ground_based' + timestamp + '.png'
fig.savefig(os.path.join(save_path, plot_file_name))
plt.close(fig)

arg_file_name = 'ground_based' + timestamp + '.pth'
torch.save(ssl_lens_net.state_dict(), os.path.join(save_path, arg_file_name))

# To reload the saved weights later:
#   test_net = rsm.SNTGModel(4)
#   test_net.load_state_dict(torch.load(os.path.join(save_path, arg_file_name)))
#   test_net.eval()