-
Notifications
You must be signed in to change notification settings - Fork 0
/
test_AE_Mogi.py
171 lines (148 loc) · 6.2 KB
/
test_AE_Mogi.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
import argparse
import json
import torch
from tqdm import tqdm
import data_loader.data_loaders as module_data
import model.loss as module_loss
import model.metric as module_metric
import model.model as module_arch
from parse_config import ConfigParser
import pandas as pd
import numpy as np
from rtm_torch.rtm import RTM
import os
# Absolute directory containing this script; used to resolve the station-info
# config file and the output CSV path relative to the script location.
CURRENT_DIR = os.path.dirname(os.path.realpath(__file__))
def main(config):
    """Evaluate a trained model on the test split and dump an analysis CSV.

    Restores the checkpoint given by ``config.resume``, runs the model over
    the test data loader, logs the aggregate loss/metrics, and writes one row
    per sample containing model outputs, targets, latent codes (plus initial
    outputs and bias terms for ``AE_Mogi_corr``) to
    ``<checkpoint>_testset_analyzer.csv`` next to the checkpoint.

    Args:
        config: parsed run configuration (ConfigParser) holding the
            architecture, data-loader, loss/metric, and trainer settings.
    """
    logger = config.get_logger('test')

    # setup data_loader instance for the test split (no shuffling so rows in
    # the output CSV line up with the dataset order)
    data_loader = getattr(module_data, config['data_loader']['type_test'])(
        config['data_loader']['data_dir_test'],
        batch_size=512,
        shuffle=False,
        validation_split=0.0,
        num_workers=2
    )

    # build model architecture
    model = config.init_obj('arch', module_arch)
    logger.info(model)

    # get function handles of loss and metrics
    loss_fn = getattr(module_loss, config['loss_test'])
    metric_fns = [getattr(module_metric, met) for met in config['metrics']]

    logger.info('Loading checkpoint: {} ...'.format(config.resume))
    # Resolve the device first so map_location can remap a GPU-trained
    # checkpoint onto a CPU-only machine (plain torch.load would fail there).
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    checkpoint = torch.load(config.resume, map_location=device)
    state_dict = checkpoint['state_dict']
    if config['n_gpu'] > 1:
        model = torch.nn.DataParallel(model)
    model.load_state_dict(state_dict)

    # prepare model for testing
    model = model.to(device)
    model.eval()

    total_loss = 0.0
    total_metrics = torch.zeros(len(metric_fns))
    data_key = config['trainer']['input_key']
    target_key = config['trainer']['output_key']

    # names of the latent columns in the output CSV
    if config['arch']['type'] == 'VanillaAE':
        # use number of the latent codes as their names
        ATTRS = [str(i + 1) for i in range(config['arch']['args']['hidden_dim'])]
    else:
        # Mogi-source parameters produced by the AE_Mogi decoders
        ATTRS = ['xcen', 'ycen', 'd', 'dV']

    analyzer = {}
    # use a context manager so the station-info file is closed deterministically
    with open(os.path.join(CURRENT_DIR, 'configs/station_info.json')) as f:
        station_info = json.load(f)
    # one displacement column per (direction, station) pair
    GPS = []
    for direction in ['ux', 'uy', 'uz']:
        for station in station_info.keys():
            GPS.append(f'{direction}_{station}')

    with torch.no_grad():
        for batch_idx, data_dict in enumerate(data_loader):
            data = data_dict[data_key].to(device)
            target = data_dict[target_key].to(device)
            if config['arch']['type'] in ['AE_Mogi', 'AE_Mogi_corr']:
                # model returns a tuple; last entry is the reconstruction,
                # second entry is the latent parameter dict
                outputs = model(data)
                output = outputs[-1]
                latent = outputs[1]
                if config['arch']['type'] == 'AE_Mogi_corr':
                    # uncorrected output; bias is the learned correction term
                    init_output = outputs[2]
                    bias = output - init_output
            elif config['arch']['type'] == 'VanillaAE':
                output = model(data)
                latent = model.encode(data)

            # also record the uncorrected output and bias for the corr model
            if config['arch']['type'] in ['AE_Mogi_corr']:
                data_concat(analyzer, 'init_output', init_output)
                data_concat(analyzer, 'bias', bias)

            if config['arch']['type'] == 'VanillaAE':
                assert len(ATTRS) == latent.shape[1], \
                    "latent shape does not match"
            else:
                assert ATTRS == list(latent.keys()), "latent keys do not match"
                # latent is a dictionary of parameters, convert it to a tensor
                latent = torch.stack([latent[k] for k in latent.keys()], dim=1)

            data_concat(analyzer, 'output', output)
            data_concat(analyzer, 'target', target)
            data_concat(analyzer, 'latent', latent)
            data_concat(analyzer, 'date', data_dict['date'])

            # computing loss, metrics on test set
            loss = loss_fn(output, target)
            batch_size = data.shape[0]
            total_loss += loss.item() * batch_size
            for i, metric in enumerate(metric_fns):
                total_metrics[i] += metric(output, target) * batch_size

    n_samples = len(data_loader.sampler)
    log = {'loss': total_loss / n_samples}
    log.update({
        met.__name__: total_metrics[i].item() / n_samples
        for i, met in enumerate(metric_fns)
    })
    logger.info(log)

    # save the analyzer to csv using pandas
    columns = []
    for k in ['output', 'target', 'latent']:
        if k != 'latent':
            columns += [k + '_' + b for b in GPS]
        else:
            columns += [k + '_' + b for b in ATTRS]

    # hstack the columns we want to save (order must match `columns`)
    data = torch.hstack((
        analyzer['output'],
        analyzer['target'],
        analyzer['latent']
    ))
    if config['arch']['type'] in ['AE_Mogi_corr']:
        columns += ['init_output_' + b for b in GPS]
        columns += ['bias_' + b for b in GPS]
        data = torch.hstack((
            data,
            analyzer['init_output'],
            analyzer['bias']
        ))
    data = data.cpu().numpy()
    df = pd.DataFrame(columns=columns, data=data)
    df['date'] = analyzer['date']
    # compute the output path once instead of rebuilding it for the log line
    csv_path = os.path.join(
        CURRENT_DIR,
        str(config.resume).split('.pth')[0] + '_testset_analyzer.csv')
    df.to_csv(csv_path, index=False)
    logger.info('Analyzer saved to {}'.format(csv_path))
def data_concat(analyzer: dict, key: str, data) -> None:
    """Accumulate a batch of *data* into ``analyzer[key]``.

    On first use the value is stored as-is; afterwards tensors are
    concatenated along dim 0 and lists are extended.

    Args:
        analyzer: accumulator mapping names to tensors or lists.
        key: entry to update.
        data: ``torch.Tensor`` batch or ``list`` batch to append.

    Raises:
        TypeError: if an existing entry is updated with an unsupported type
            (previously this case silently dropped the batch).
    """
    if key not in analyzer:
        analyzer[key] = data
    # isinstance (not type ==) so Tensor subclasses like nn.Parameter work
    elif isinstance(data, torch.Tensor):
        analyzer[key] = torch.cat((analyzer[key], data), dim=0)
    elif isinstance(data, list):
        analyzer[key] = analyzer[key] + data
    else:
        raise TypeError(
            f"data_concat: unsupported type {type(data).__name__!r} "
            f"for key {key!r}")
if __name__ == '__main__':
    # CLI: config file, checkpoint to resume from, and GPU device indices.
    cli = argparse.ArgumentParser(description='PyTorch Template')
    cli.add_argument('-c', '--config', type=str, default=None,
                     help='config file path (default: None)')
    cli.add_argument('-r', '--resume', type=str, default=None,
                     help='path to latest checkpoint (default: None)')
    cli.add_argument('-d', '--device', type=str, default=None,
                     help='indices of GPUs to enable (default: all)')
    main(ConfigParser.from_args(cli))