# main_qa_VA_eval.py
# Evaluation driver for NExT-OOD-VA multiple-choice video QA.
# (Web-scrape residue — page chrome and a column of display line numbers —
# removed so the file parses as Python.)
from videoqa_GCS import *
from dataloader import sample_loader_VA as sample_loader
from build_vocab import Vocabulary
from utils import *
import argparse
import eval_mc_VA
# Limit intra-op CPU parallelism so runs are reproducible and don't oversubscribe cores.
NUM_THREADS = 1
torch.set_num_threads(NUM_THREADS)
def main(args):
    """Run model selection on the val split, then evaluate the best epoch on test.

    Args:
        args: parsed CLI namespace (mode, bert, checkpoint, N, gin, delta,
            lambda1, lambda2, epoch).

    Returns:
        (results_objs, sorted_results, model_type) where ``results_objs`` maps
        'e<epoch>'/'test' to per-object results from ``eval_mc_VA.main``, and
        ``sorted_results`` is the val accuracies sorted ascending with the test
        entry appended last.

    Raises:
        FileNotFoundError: if no checkpoint file matches a requested epoch.
    """
    mode = args.mode
    if mode == 'train':
        batch_size = 64
        num_worker = 1
    else:
        batch_size = 64  # you may need to change to a number that is divisible by the size of test/val set, e.g., 4
        num_worker = 2

    model_type = 'GCS'  # (GCS, EVQA, STVQA, CoMem, HME, HGA)
    spatial = (model_type == 'STVQA')  # only STVQA consumes spatial features

    # Paths are currently identical in both branches; kept so they can diverge
    # per model type if needed.
    if spatial:
        # STVQA
        video_feature_path = 'dataset/feats/'
        video_feature_cache = 'dataset/feats/cache/'
    else:
        video_feature_cache = 'dataset/feats/cache/'
        video_feature_path = 'dataset/feats/'

    dataset = 'NExT-OOD'
    sample_list_path = 'dataset/{}/'.format(dataset)
    vocab = pkload('dataset/{}/vocab.pkl'.format(dataset))
    glove_embed = 'dataset/{}/glove_embed.npy'.format(dataset)

    use_bert = args.bert  # True for BERT embeddings, otherwise GloVe
    model_prefix = args.checkpoint
    checkpoint_path = 'models/' + args.checkpoint
    # makedirs(exist_ok=True) replaces the exists()/mkdir pair: no TOCTOU race,
    # and it also creates missing parent directories.
    os.makedirs(checkpoint_path, exist_ok=True)

    vis_step = 106  # visual/logging step interval
    lr_rate = 5e-5 if use_bert else 1e-4
    epoch_num = 50

    data_loader = sample_loader.QALoader(batch_size, num_worker, video_feature_path, video_feature_cache,
                                         sample_list_path, vocab, use_bert, model_type, args.N, True, False)
    train_loader, val_loader, test_loader = data_loader.run(mode=mode)

    vqa = VideoQA(vocab, train_loader, val_loader, test_loader, glove_embed, use_bert, checkpoint_path, model_type,
                  model_prefix, vis_step, lr_rate, batch_size, epoch_num, args.gin, args.delta, args.lambda1,
                  args.lambda2)

    results = {}
    results_objs = {}
    if args.epoch > 0:
        # Evaluate a single, explicitly requested epoch.
        start = args.epoch
        end = args.epoch + 1
    else:
        start = 1  # model selection over all epochs
        # NOTE(review): range(start, 50) never evaluates epoch 50 even though
        # epoch_num is 50 — confirm whether ``end = epoch_num + 1`` was intended.
        end = 50

    for epoch in range(start, end):
        # Checkpoint filenames appear to be '-'-separated with the epoch at
        # split index 2 (presumably '<prefix>-<tag>-<epoch>-...'; verify against
        # the training code that writes them).
        model_file = None
        for file in os.listdir(checkpoint_path):
            if file.split('-')[2] == str(epoch):
                model_file = file
                break
        if model_file is None:
            # Previously this silently reused the prior epoch's file (or raised
            # UnboundLocalError on the first epoch); fail with a clear message.
            raise FileNotFoundError(
                'no checkpoint for epoch {} in {}'.format(epoch, checkpoint_path))

        if mode == 'val':
            result_file = f'results/VA_{model_type}-{model_prefix}-{mode}2.json'
            vqa.predict(model_file, result_file)
            balance_results_objs, balance_results_avg = eval_mc_VA.main(result_file, mode, args.N)
            results_objs['e' + str(epoch)] = balance_results_objs
            results[str(epoch)] = round(balance_results_avg, 2)

    # Sort by accuracy; keys are strings, so ties break lexicographically on
    # the epoch string rather than numerically.
    sorted_results = sorted(results.items(), key=lambda kv: (kv[1], kv[0]))
    best = sorted_results[-1]

    # Evaluate the selected checkpoint on the test split.
    # NOTE(review): the epoch is matched at split index 3 here but index 2
    # above — confirm which index actually holds the epoch in the filename.
    model_file = None
    for file in os.listdir(checkpoint_path):
        if file.split('-')[3] == str(best[0]):
            model_file = file
            break
    if model_file is None:
        raise FileNotFoundError(
            'no checkpoint for best epoch {} in {}'.format(best[0], checkpoint_path))

    mode = 'test'
    result_file = f'results/VA_{model_type}-{model_prefix}-{mode}2.json'
    vqa.predict_test(model_file, result_file)
    balance_results_test_objs, balance_results_test_avg = eval_mc_VA.main(result_file, mode, args.N)
    results_objs['test'] = balance_results_test_objs
    sorted_results.append(('test', round(balance_results_test_avg, 2)))
    return results_objs, sorted_results, model_type
if __name__ == "__main__":
    # Fix seeds for reproducibility.
    torch.backends.cudnn.enabled = False
    torch.manual_seed(666)
    torch.cuda.manual_seed(666)
    # NOTE(review): benchmark=True has no effect while cudnn is disabled above —
    # confirm which setting is intended.
    torch.backends.cudnn.benchmark = True

    parser = argparse.ArgumentParser()
    parser.add_argument('--gpu', dest='gpu', type=int,
                        default=0, help='gpu device id')
    parser.add_argument('--mode', dest='mode', type=str,
                        default='train', help='train or val')
    parser.add_argument('--bert', dest='bert', action='store_true',
                        help='use bert or glove')
    parser.add_argument('--checkpoint', dest='checkpoint', type=str,
                        default='ck', help='checkpoint name')
    parser.add_argument('--N', dest='N', type=str,
                        default='all', help='construction parameter')
    parser.add_argument('--gin', dest='gin', type=int,
                        default=3, help='Layer number of GIN')
    parser.add_argument('--delta', dest='delta', type=float,
                        default=0.5, help='parameter in GIN')
    parser.add_argument('--lambda1', dest='lambda1', type=float,
                        default=1.0, help='lambda1')
    parser.add_argument('--lambda2', dest='lambda2', type=float,
                        default=2.0, help='lambda2')
    parser.add_argument('--epoch', dest='epoch', type=int,
                        default=-1, help='epoch model for evaluation')
    args = parser.parse_args()

    if args.N == 'all':
        # Run every dataset-construction variant and report per-variant plus
        # averaged results.
        N = ['N1', 'N2', 'N5']
        results_N = []
        for n in N:
            # (removed unused ``start = time.time()`` — its value was never read,
            # and ``time`` is only available via the star import)
            args.N = n
            results_objs, results, model_type = main(args)
            results_N.append(results)
            print('--------NExT-OOD-VA with', n, '---------')
            print(results)
            print('VA OOD best val:', results[-2], ', test result:', results[-1])

        val = []
        test = []
        print('======================================================================')
        for idx, results in enumerate(results_N):
            print('-----------------NExT-OOD-VA with', N[idx], '-------------')
            print(results)
            # results[-2] is the best (epoch, val_acc) pair; results[-1] is
            # ('test', test_acc) appended by main().
            print('VA OOD best val:', results[-2], ', test result:', results[-1])
            val.append(results[-2][1])
            test.append(results[-1][1])
        print('======================================================================')
        # sum/len instead of hard-coded three-way average.
        print('Val avg:', sum(val) / len(val), ' Test avg:', sum(test) / len(test))
    else:
        results_objs, results, model_type = main(args)
        print('---------NExT-OOD-VA with', args.N, '------')
        print(results)
        print('*** NExT-OOD-VA best val:', results[-2], ', test result:', results[-1])