-
Notifications
You must be signed in to change notification settings - Fork 15
/
serve.py
44 lines (34 loc) · 1.43 KB
/
serve.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import tensorflow as tf
import os
# Export the most recent training checkpoint found in SAVE_PATH as a
# SavedModel under SERVE_PATH (./serve/<MODEL_NAME>/<VERSION>) so it can be
# loaded by a model server.
SAVE_PATH = './save'
MODEL_NAME = 'test'
# Bump VERSION for every new export: SavedModelBuilder raises if the export
# directory already exists.
VERSION = 1
SERVE_PATH = './serve/{}/{}'.format(MODEL_NAME, VERSION)

# Path prefix of the newest checkpoint in SAVE_PATH, or None if there is none.
checkpoint = tf.train.latest_checkpoint(SAVE_PATH)
if checkpoint is None:
    # Fail with a clear message instead of a TypeError on `None + '.meta'`.
    raise FileNotFoundError('No checkpoint found in {}'.format(SAVE_PATH))

tf.reset_default_graph()
with tf.Session() as sess:
    # Rebuild the saved graph structure in the (fresh) default graph.
    saver = tf.train.import_meta_graph(checkpoint + '.meta')
    graph = tf.get_default_graph()

    # Load the trained weights from the checkpoint. Running
    # tf.global_variables_initializer() here instead would overwrite every
    # variable with its initial value and export an untrained model.
    # import_meta_graph returns None when the graph has no variables.
    if saver is not None:
        saver.restore(sess, checkpoint)

    # Look up the input placeholder and output tensor by their graph names.
    inputs = graph.get_tensor_by_name('inputs:0')
    predictions = graph.get_tensor_by_name('prediction/Sigmoid:0')

    # Describe the tensors for the serving signature.
    model_input = tf.saved_model.utils.build_tensor_info(inputs)
    model_output = tf.saved_model.utils.build_tensor_info(predictions)

    # A single "predict" signature mapping 'inputs' -> 'outputs'.
    signature_definition = tf.saved_model.signature_def_utils.build_signature_def(
        inputs={'inputs': model_input},
        outputs={'outputs': model_output},
        method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME)

    builder = tf.saved_model.builder.SavedModelBuilder(SERVE_PATH)
    # Save the graph plus the restored variable values, tagged for serving,
    # with the signature registered under the default serving key.
    builder.add_meta_graph_and_variables(
        sess, [tf.saved_model.tag_constants.SERVING],
        signature_def_map={
            tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
                signature_definition
        })
    # Write the SavedModel to disk so a model server can serve it.
    builder.save()