forked from deepdrive/deepdrive
-
Notifications
You must be signed in to change notification settings - Fork 0
/
tf_utils.py
92 lines (68 loc) · 2.36 KB
/
tf_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import tensorflow as tf
import time
from tensorflow.python.client import timeline
import numpy as np
import logs
# Module logger plus the two feed points for the preprocessing graphs.
# Both placeholders are float64 with no fixed shape. The depth graph only
# works for rank-2 input, because _heatmap_op transposes with a fixed
# 3-axis permutation after stacking the channels.
log = logs.get_log(__name__)
IMAGE = tf.placeholder(dtype=tf.float64)  # fed to _image_op
DEPTH = tf.placeholder(dtype=tf.float64)  # fed to _depth_op
def _image_op(x):
    """Build ops that turn a float image tensor into displayable uint8.

    Applies a gamma curve with exponent 0.45 (roughly 1/2.2), clamps the
    result to [0, 1], scales it to [0, 255] and casts to ``tf.uint8``.
    Every step is elementwise, so the output has the same shape as ``x``.

    Args:
        x: float tensor (in this module, the ``IMAGE`` placeholder).
            NOTE(review): negative inputs give NaN under the fractional
            power before clamping; presumably inputs are non-negative.

    Returns:
        A ``tf.uint8`` tensor with the same shape as ``x``.
    """
    gamma_corrected = x ** 0.45
    clamped = tf.clip_by_value(gamma_corrected, 0, 1)
    return tf.cast(clamped * 255., tf.uint8)
def _depth_op(x):
    """Build ops that render a depth tensor as a uint8 RGB heatmap.

    The input is first compressed with an inverse cube root
    (``x ** -(1/3)``), then min-max normalized to [0, 1] and colorized.

    Args:
        x: float depth tensor (in this module, the ``DEPTH`` placeholder).
            It must be rank 2 for the heatmap transpose to succeed.
            NOTE(review): for positive depths the inverse cube root makes
            nearer (smaller) values larger; a depth of 0 becomes inf.
            Presumably depths are strictly positive.

    Returns:
        A ``tf.uint8`` tensor of shape ``(H, W, 3)``.
    """
    inverse_cube_root = x ** -(1 / 3.)
    return _heatmap_op(_normalize_op(inverse_cube_root))
def _normalize_op(x):
    """Build ops that min-max scale ``x`` over all of its elements.

    Computes ``(x - min(x)) / (max(x) - min(x))`` using the global min and
    max of the tensor, so finite inputs land in [0, 1].

    If every element is equal, the range is zero and the plain formula
    would produce NaN (0 / 0). In that case the divisor is replaced by 1,
    so the result is all zeros instead.

    Args:
        x: float tensor of any shape.

    Returns:
        A float tensor with the same shape and dtype as ``x``.
    """
    amin = tf.reduce_min(x)
    amax = tf.reduce_max(x)
    value_range = amax - amin
    # Guard against a zero range. Downstream code casts the result to uint8,
    # and a NaN cast to an integer type has no well-defined value.
    safe_range = tf.where(value_range > 0, value_range,
                          tf.ones_like(value_range))
    return (x - amin) / safe_range
def _heatmap_op(x):
    """Colorize a normalized rank-2 tensor as a uint8 RGB heatmap.

    ``x`` is expected in [0, 1], as produced by ``_normalize_op``.
    - Red rises linearly with ``x``.
    - Blue falls linearly with ``x``.
    - Green is a triangle that peaks at ``x == 0.5`` and reaches 0 at both
      ends.

    Args:
        x: float tensor of shape ``(H, W)``. The fixed transpose permutation
            below requires rank 2.

    Returns:
        A ``tf.uint8`` tensor of shape ``(H, W, 3)``.
    """
    channels = [
        x,                              # red
        1.0 - tf.abs(0.5 - x) * 2.,     # green
        1. - x,                         # blue
    ]
    # tf.stack puts the channel axis first, giving (3, H, W); move it last.
    rgb = tf.transpose(tf.stack(channels), perm=(1, 2, 0))
    return tf.cast(rgb * 255, tf.uint8)
# Build both preprocessing graphs once, at import time. The preprocess_*
# helpers only feed the placeholders and evaluate these tensors.
image_op = _image_op(x=IMAGE)
depth_op = _depth_op(x=DEPTH)
def preprocess_image(image, sess, trace=False):
    """Run the gamma-correct / uint8 image graph on ``image``.

    Args:
        image: array-like value fed into the ``IMAGE`` placeholder.
        sess: session used to evaluate ``image_op``.
        trace: if true, ``_run_op`` runs with full tracing and writes a
            timeline JSON file.

    Returns:
        The evaluated ``image_op`` (uint8 values, same shape as ``image``).
    """
    return _run_op(sess, image_op, IMAGE, image,
                   trace=trace, op_name='preprocess_image')
def preprocess_depth(depth, sess, trace=False):
    """Run the depth-to-heatmap graph on ``depth``.

    Args:
        depth: array-like value fed into the ``DEPTH`` placeholder. It must
            be 2-D for the heatmap transpose to succeed.
        sess: session used to evaluate ``depth_op``.
        trace: if true, ``_run_op`` runs with full tracing and writes a
            timeline JSON file.

    Returns:
        The evaluated ``depth_op`` (uint8 values, shape ``(H, W, 3)``).
    """
    return _run_op(sess, depth_op, DEPTH, depth,
                   trace=trace, op_name='preprocess_depth')
def _run_op(sess, op, X, x, trace=False, op_name='tf_op',
            timeline_path='timeline.json'):
    """Evaluate ``op`` with ``x`` fed into placeholder ``X`` and log timing.

    Only the ``sess.run`` call is timed. When tracing, the timeline is built
    and written after the timer stops, so file I/O is not counted as op
    time.

    Args:
        sess: session used to run the op.
        op: tensor or op to evaluate.
        X: placeholder that ``x`` is fed into.
        x: value fed for ``X``.
        trace: if true, run with ``RunOptions.FULL_TRACE`` and write the
            step stats as Chrome-trace-format JSON to ``timeline_path``.
        op_name: label used in the debug timing message.
        timeline_path: file the trace JSON is written to when ``trace`` is
            true. It is overwritten on every traced call.

    Returns:
        Whatever ``sess.run(op, ...)`` returns.
    """
    run_kwargs = {}
    run_metadata = None
    if trace:
        run_metadata = tf.RunMetadata()
        run_kwargs = {
            'options': tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE),
            'run_metadata': run_metadata,
        }

    start = time.time()
    ret = sess.run(op, feed_dict={X: x}, **run_kwargs)
    elapsed_ms = (time.time() - start) * 1000

    if trace:
        # Convert the collected step stats into a Chrome trace and save it.
        tl = timeline.Timeline(run_metadata.step_stats)
        with open(timeline_path, 'w') as f:
            f.write(tl.generate_chrome_trace_format())

    log.debug('%r took %rms', op_name, elapsed_ms)
    return ret
def _main():
    """Run each preprocessing op a few times on random 227x227 inputs.

    The debug log then shows the per-call timing from ``_run_op``.
    Presumably the repetition is meant to separate first-call (warm-up)
    cost from steady-state cost.
    """
    import logging
    import sys

    h = w = 227
    runs = 3
    # `log` is used like a logging.Logger (log.debug). A Logger has no
    # basicConfig() or DEBUG attribute, so configure through the logging
    # module itself.
    # NOTE(review): if logs.get_log sets its own level above DEBUG, the
    # timing messages will still be filtered; confirm.
    logging.basicConfig(
        level=logging.DEBUG,
        stream=sys.stdout,
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    with tf.Session() as sess:
        for _ in range(runs):
            preprocess_image(np.random.rand(h, w, 3), sess)
        for _ in range(runs):
            preprocess_depth(np.random.rand(h, w), sess)


if __name__ == '__main__':
    _main()