diff --git a/README.md b/README.md
index f1ce6be..269ba25 100644
--- a/README.md
+++ b/README.md
@@ -13,6 +13,16 @@ The output consists of two files:
 1. A data file (in NumPy's native format) containing the model's learned parameters.
 2. A Python class that constructs the model's graph.
 
+Run `pack-pb.py` to pack the generated code and data into a single TensorFlow `.pb` file.
+
+Before packing, you need to generate the code and data files first.
+
+For example, to convert the Caffe AlexNet model to a TensorFlow `.pb` file:
+
+1. `./convert.py ~/caffe/models/bvlc_alexnet/deploy.prototxt --caffemodel ~/caffe/models/bvlc_alexnet/bvlc_alexnet.caffemodel --data-output-path=AlexNet.npy --code-output-path=AlexNet.py`
+2. `python pack-pb.py --model AlexNet`
+
+
 ### Examples
 
 See the [examples](examples/) folder for more details.
diff --git a/kaffe/caffe/caffepb.py b/kaffe/caffe/caffe_pb2.py
similarity index 100%
rename from kaffe/caffe/caffepb.py
rename to kaffe/caffe/caffe_pb2.py
diff --git a/kaffe/caffe/resolver.py b/kaffe/caffe/resolver.py
index b9580a7..0986aa2 100644
--- a/kaffe/caffe/resolver.py
+++ b/kaffe/caffe/resolver.py
@@ -14,7 +14,7 @@ def import_caffe(self):
             self.caffe = caffe
         except ImportError:
             # Fall back to the protobuf implementation
-            from . import caffepb
+            from . import caffe_pb2 as caffepb
             self.caffepb = caffepb
             show_fallback_warning()
         if self.caffe:
diff --git a/kaffe/tensorflow/network.py b/kaffe/tensorflow/network.py
index 6f3b153..2750f5e 100644
--- a/kaffe/tensorflow/network.py
+++ b/kaffe/tensorflow/network.py
@@ -130,11 +130,11 @@ def conv(self,
                 output = convolve(input, kernel)
             else:
                 # Split the input into groups and then convolve each of them independently
-                input_groups = tf.split(3, group, input)
-                kernel_groups = tf.split(3, group, kernel)
+                input_groups = tf.split(axis=3, num_or_size_splits=group, value=input)
+                kernel_groups = tf.split(axis=3, num_or_size_splits=group, value=kernel)
                 output_groups = [convolve(i, k) for i, k in zip(input_groups, kernel_groups)]
                 # Concatenate the groups
-                output = tf.concat(3, output_groups)
+                output = tf.concat(axis=3, values=output_groups)
             # Add the biases
             if biased:
                 biases = self.make_var('biases', [c_o])
@@ -177,7 +177,7 @@ def lrn(self, input, radius, alpha, beta, name, bias=1.0):
 
     @layer
     def concat(self, inputs, axis, name):
-        return tf.concat(concat_dim=axis, values=inputs, name=name)
+        return tf.concat(axis=axis, values=inputs, name=name)
 
     @layer
     def add(self, inputs, name):
diff --git a/pack-pb.py b/pack-pb.py
new file mode 100644
index 0000000..bfb4fba
--- /dev/null
+++ b/pack-pb.py
@@ -0,0 +1,46 @@
+#!/usr/bin/env python
+"""Pack a converted model's code and weights into a TensorFlow .pb file."""
+
+import argparse
+import importlib
+import sys
+import tensorflow as tf
+from tensorflow.python.framework.graph_util import convert_variables_to_constants
+
+
+def convert(model, output, shape):
+    """Freeze the network `model` and write it to `<model>.pb`.
+
+    The network class `model` is imported from the module of the same name
+    (the file produced by convert.py's --code-output-path) and its weights
+    are loaded from `<model>.npy`. `output` names the graph's output node
+    and `shape` is the shape of the 'data' input placeholder.
+    """
+    # Resolve the class dynamically so any generated model can be packed.
+    net_class = getattr(importlib.import_module(model), model)
+    data_node = tf.placeholder(tf.float32, shape)
+    net = net_class({'data': data_node})
+    with tf.Session() as sess:
+        net.load(data_path=model + '.npy', session=sess)
+        graph = convert_variables_to_constants(sess, sess.graph_def, [output])
+        tf.train.write_graph(graph, '.', model + '.pb', as_text=False)
+
+
+def main():
+    """Parse command-line options and pack the requested model."""
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--model", default="LeNet", help="model name")
+    parser.add_argument("--output", default="prob", help="output name")
+    parser.add_argument("--input_height", type=int, default=227, help="input height")
+    parser.add_argument("--input_width", type=int, default=227, help="input width")
+    parser.add_argument("--input_channel", type=int, default=3, help="input channel")
+    parser.add_argument("--input_batch", type=int, default=1, help="input batch")
+    args = parser.parse_args()
+
+    # The placeholder is laid out as (batch, height, width, channel).
+    shape = (args.input_batch, args.input_height, args.input_width, args.input_channel)
+    convert(args.model, args.output, shape)
+
+
+if __name__ == '__main__':
+    main()