Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix #447: Changed format for exporting models to MetaGraphDef for Tensorflow #455

Open
wants to merge 9 commits into
base: master
Choose a base branch
from
4 changes: 2 additions & 2 deletions example/tensorflow/code_template/tensorflow_sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@
# Read the protobuf text and build a tf.GraphDef
with open(model_file_name, 'r') as model_file:
model_protobuf = text_format.Parse(model_file.read(),
tf.GraphDef())
tf.MetaGraphDef())

# Import the GraphDef built above into the default graph
tf.import_graph_def(model_protobuf)
tf.train.import_meta_graph(model_protobuf)

# You can now add operations on top of the imported graph
13 changes: 8 additions & 5 deletions ide/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,19 +311,22 @@ def isProcessPossible(layerId):
json_str = json_str.strip("'<>() ").replace('\'', '\"')
lrnLayer = imp.load_source('LRN', BASE_DIR + '/keras_app/custom_layers/lrn.py')

# clear clutter from previous graph built by keras to avoid duplicates
K.clear_session()

model = model_from_json(json_str, {'LRN': lrnLayer.LRN})

sess = K.get_session()
tf.train.write_graph(sess.graph.as_graph_def(add_shapes=True), output_fld,
output_file + '.pbtxt', as_text=True)
tf.train.export_meta_graph(
abhigyan7 marked this conversation as resolved.
Show resolved Hide resolved
os.path.join(output_fld, output_file + '.meta'),
as_text=True)

Channel(reply_channel).send({
'text': json.dumps({
'result': 'success',
'action': 'ExportNet',
'id': 'randomId',
'name': randomId + '.pbtxt',
'url': '/media/' + randomId + '.pbtxt',
'name': randomId + '.meta',
'url': '/media/' + randomId + '.meta',
'customLayers': custom_layers_response
})
})
Expand Down
6 changes: 3 additions & 3 deletions tensorflow_app/views/export_graphdef.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,10 @@ def export_to_tensorflow(request):
randomId = response['randomId']
customLayers = response['customLayers']
os.chdir(BASE_DIR + '/tensorflow_app/views/')
os.system('KERAS_BACKEND=tensorflow python json2pbtxt.py -input_file ' +
os.system('KERAS_BACKEND=tensorflow python json2meta.py -input_file ' +
randomId + '.json -output_file ' + randomId)
return JsonResponse({'result': 'success',
'id': randomId,
'name': randomId + '.pbtxt',
'url': '/media/' + randomId + '.pbtxt',
'name': randomId + '.meta',
'url': '/media/' + randomId + '.meta',
'customLayers': customLayers})
46 changes: 42 additions & 4 deletions tensorflow_app/views/import_graphdef.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import numpy as np
import tensorflow as tf
from google.protobuf import text_format
from tensorflow.core.framework import graph_pb2
from django.views.decorators.csrf import csrf_exempt
from django.http import JsonResponse
import math
Expand Down Expand Up @@ -125,6 +124,46 @@ def get_padding(node, layer, session, input_layer_name, input_layer_dim):
return int(pad_h), int(pad_w)


def get_graph_def_from(model_protobuf):
"""
Parses and returns a GraphDef from input protobuf.

Args:
model_protobuf: a binary or text protobuf message.

Returns:
a tf.GraphDef object with the GraphDef from model_protobuf

Raises:
ValueError: if a GraphDef cannot be parsed from model_protobuf
"""
try:
meta_graph_def = text_format.Merge(model_protobuf, tf.MetaGraphDef())
graph_def = meta_graph_def.graph_def
return graph_def
except (text_format.ParseError, UnicodeDecodeError):
# not a valid text metagraphdef
pass
try:
graph_def = text_format.Merge(model_protobuf, tf.GraphDef())
return graph_def
except (text_format.ParseError, UnicodeDecodeError):
pass
try:
graph_def = tf.GraphDef()
graph_def.ParseFromString(model_protobuf)
return graph_def
except Exception:
pass
try:
meta_graph_def = tf.MetaGraphDef()
meta_graph_def.ParseFromString(model_protobuf)
return meta_graph_def.graph_def
except Exception:
pass
raise ValueError('Invalid model protobuf')


@csrf_exempt
def import_graph_def(request):
if request.method == 'POST':
Expand All @@ -151,15 +190,14 @@ def import_graph_def(request):
return JsonResponse({'result': 'error', 'error': 'No GraphDef model found'})

tf.reset_default_graph()
graph_def = graph_pb2.GraphDef()
d = {}
order = []
input_layer_name = ''
input_layer_dim = []

try:
text_format.Merge(config, graph_def)
except Exception:
graph_def = get_graph_def_from(config)
except ValueError:
return JsonResponse({'result': 'error', 'error': 'Invalid GraphDef'})

tf.import_graph_def(graph_def, name='')
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
parser.add_argument('-input_file', action="store",
dest='input_file', type=str, default='model.json')
parser.add_argument('-output_file', action="store",
dest='output_file', type=str, default='model.pbtxt')
dest='output_file', type=str, default='model.meta')
args = parser.parse_args()
input_file = args.input_file
output_file = args.output_file
Expand All @@ -30,6 +30,6 @@
lrn = imp.load_source('LRN', BASE_DIR + '/keras_app/custom_layers/lrn.py')
model = model_from_json(json_str, {'LRN': lrn.LRN})

sess = K.get_session()
tf.train.write_graph(sess.graph.as_graph_def(add_shapes=True), output_fld,
output_file + '.pbtxt', as_text=True)
tf.train.export_meta_graph(
os.path.join(output_fld, output_file + '.meta'),
as_text=True)
28 changes: 28 additions & 0 deletions tests/unit/tensorflow_app/test_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,3 +125,31 @@ def test_custom_lrn_tf_import(self):
response = self.client.post(reverse('tf-import'), {'file': model_file})
response = json.loads(response.content)
self.assertEqual(response['result'], 'success')


class ExportMetaGraphTest(unittest.TestCase):
def setUp(self):
self.client = Client()

def test_tf_export(self):
model_file = open(os.path.join(settings.BASE_DIR, 'example/keras',
'AlexNet.json'), 'r')
response = self.client.post(reverse('keras-import'), {'file': model_file})
response = json.loads(response.content)
net = get_shapes(response['net'])
response = self.client.post(reverse('tf-export'), {'net': json.dumps(net),
'net_name': ''})
response = json.loads(response.content)
self.assertEqual(response['result'], 'success')


class ImportMetaGraphTest(unittest.TestCase):
def setUp(self):
self.client = Client()

def test_tf_import(self):
model_file = open(os.path.join(settings.BASE_DIR, 'tests/unit/tensorflow_app',
'vgg16_import_test.meta'), 'r')
response = self.client.post(reverse('tf-import'), {'file': model_file})
response = json.loads(response.content)
self.assertEqual(response['result'], 'success')
Loading