forked from Shobhit20/Image-Captioning
-
Notifications
You must be signed in to change notification settings - Fork 0
/
test_test.py
135 lines (112 loc) · 4.57 KB
/
test_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
import pytest
from tensorflow.python.client import device_lib
import keras.backend as K
import os, subprocess, atexit
from test_mod import process_caption,generate_captions
from test import text
from train import train
from SceneDesc import scenedesc
import encode_image as ei
# Sample image used by the caption-generation tests.
test_img = os.path.join('Flickr8K_Data', '997722733_0cb5439472.jpg')
sd = scenedesc()

# Build the captioning model and try to load pre-computed weights.
# preload_model gates the skipif markers on the tests that need a trained model.
try:
    model = sd.create_model(ret_model=True)
    model.load_weights(os.path.join('Output', 'Weights.h5'))
except IOError:
    print('no pre-computed weights found; skipping testing tests')
    preload_model = False
else:
    preload_model = True
    print('pre-computed weights loaded; model created')

# Decide whether the training test can run: TensorFlow must list a GPU device
# AND keras must report at least one GPU. Keras versions lacking the private
# helper raise AttributeError, which is treated as "no GPU".
try:
    tf_available_devices = str(device_lib.list_local_devices())
    k_gpus = K.tensorflow_backend._get_available_gpus()
except AttributeError:
    print('Could not detect whether any GPU is visible to keras - most likely due to an outdated version of keras. Skipping training tests!')
    has_gpu = False
else:
    has_gpu = 'GPU' in tf_available_devices and len(k_gpus) > 0
    if has_gpu:
        print('GPUs visible to keras!')
    else:
        print('No GPUs visible to keras - skipping training tests!')
def test_process_caption():
    '''
    Check that test_mod.process_caption strips the <start> and <end>
    tokens from a raw generated caption and leaves the words untouched.
    '''
    raw_caption = '<start> A black dog is running after a white dog in the snow . <end>'
    expected = 'A black dog is running after a white dog in the snow .'
    assert process_caption(sd, raw_caption) == expected
@pytest.mark.skipif(not preload_model, reason='no pre-computed weights found')
def test_generate_captions():
    '''
    Check that test_mod.generate_captions yields a non-empty caption for the
    sample image (beam search with beam_size=3).

    The caption depends on whichever pre-computed weights are loaded, so it is
    not compared with a reference string. Instead it is printed when the
    interpreter exits (registered via atexit) for manual inspection.
    '''
    image_features = ei.encodings(ei.model_gen(), test_img)
    generated = generate_captions(sd, model, image_features, beam_size=3)

    def _print_generated():
        print('The model generated the caption: ' + generated)

    atexit.register(_print_generated)
    assert len(generated) > 0
@pytest.mark.skipif(not preload_model, reason='no pre-computed weights found')
def test_text_creation():
    '''
    Smoke-test the caption generation wrapper test.text on the sample image.

    The test passes as long as test.text runs without raising. Once it has,
    a note for the user is registered with atexit and printed when the
    interpreter exits.
    '''
    text(test_img)

    notice = 'The model generated a caption that you should hear (and read, if you run pytest with the -s flag). This caption should be the same as the caption you can read below, and you should also have heard it read out.\n'

    def _print_notice():
        print(notice)

    atexit.register(_print_notice)
def _stash_file(path):
    '''
    Move *path* aside to ``path + '.swp'`` so a training run cannot clobber it.

    Returns True if a file was moved (and must be restored later), or False
    if *path* does not exist. Raises OSError if the file exists but cannot
    be renamed.
    '''
    if not os.path.exists(path):
        return False
    # os.rename raises OSError on failure, unlike subprocess.call(['mv', ...]),
    # which only returns a non-zero exit code that was never checked.
    os.rename(path + '', path + '.swp')
    return True


def _restore_file(path, error_message):
    '''
    Move ``path + '.swp'`` back to *path*, replacing any freshly trained file
    there (POSIX rename semantics). Prints *error_message* if the move fails.
    '''
    try:
        os.rename(path + '.swp', path)
    except OSError:
        print(error_message)


@pytest.mark.skipif(not has_gpu, reason='no gpu visible to keras')
def test_training():
    '''
    Wherein we test whether we can run the training pipeline for a single
    epoch, and afterwards check that the model and weights files have been
    written and are nonzero in size.

    Any pre-computed Output/Model.h5 and Output/Weights.h5 are first moved
    aside to '<name>.swp' and are always moved back when the test ends,
    whether it passes, fails or is skipped. If a pre-computed file exists
    but cannot be moved aside, the test is skipped rather than risk
    overwriting it.
    '''
    # (path, message if it cannot be moved aside, message if it cannot be moved back)
    artifacts = (
        (os.path.join('Output', 'Model.h5'),
         'Found pre-computed model, but could not copy it in temporary location. Aborting test.',
         'Could not copy back model file! Possibly it has been overwritten...'),
        (os.path.join('Output', 'Weights.h5'),
         'Found pre-computed weights, but could not copy them in temporary location. Aborting test.',
         'Could not copy back weights file! Possibly it has been overwritten...'),
    )
    stashed = []  # (path, restore_error_message) for every file moved aside
    try:
        for path, stash_error, restore_error in artifacts:
            try:
                moved = _stash_file(path)
            except OSError:
                # pytest.skip raises, so training never runs over an unprotected file.
                pytest.skip(stash_error)
            if moved:
                stashed.append((path, restore_error))

        # Train for one epoch, then check the new files exist and are non-empty.
        train(1)
        for path, _, _ in artifacts:
            assert os.path.exists(path)
            assert os.path.getsize(path) > 0
    finally:
        # Put pre-computed files back even if training or an assertion failed.
        for path, restore_error in stashed:
            _restore_file(path, restore_error)