diff --git a/proto/model.proto b/proto/model.proto index f0243a9..8fb7b37 100644 --- a/proto/model.proto +++ b/proto/model.proto @@ -6,6 +6,8 @@ message Weights { repeated uint32 shape = 3; string type = 4; bytes data = 5; + float quantize_min = 6; + float quantize_max = 7; } message Model { diff --git a/python/encoder.py b/python/encoder.py index 5754d32..37c28d3 100644 --- a/python/encoder.py +++ b/python/encoder.py @@ -6,7 +6,21 @@ import model_pb2 -class Encoder(object): +def quantize_arr(arr): + """Quantization based on linear rescaling over min/max range. + """ + min_val, max_val = np.min(arr), np.max(arr) + if max_val - min_val > 0: + quantized = np.round(255 * (arr - min_val) / (max_val - min_val)) + else: + quantized = np.zeros(arr.shape) + quantized = quantized.astype(np.uint8) + min_val = min_val.astype(np.float32) + max_val = max_val.astype(np.float32) + return quantized, min_val, max_val + + +class Encoder: """Encoder class. Takes as input a Keras model saved in hdf5 format that includes the model architecture with the weights. @@ -19,11 +33,12 @@ class Encoder(object): See https://keras.io/getting-started/faq/#savingloading-whole-models-architecture-weights-optimizer-state """ - def __init__(self, hdf5_model_filepath, name): + def __init__(self, hdf5_model_filepath, name, quantize): if not hdf5_model_filepath: raise Exception('hdf5_model_filepath must be provided.') self.hdf5_model_filepath = hdf5_model_filepath self.name = name + self.quantize = quantize self.create_model() @@ -55,8 +70,15 @@ def serialize(self): w.layer_name = layer_name w.weight_name = weight_name w.shape.extend(list(weight_value.shape)) - w.type = 'float32' - w.data = weight_value.astype(np.float32).tobytes() + if quantize: + w.type = 'uint8' + quantized, min_val, max_val = quantize_arr(weight_value) + w.data = quantized.astype(np.uint8).tobytes() + w.quantize_min = min_val + w.quantize_max = max_val + else: + w.type = 'float32' + w.data = weight_value.astype(np.float32).tobytes() hdf5_file.close() @@ -73,8 +95,10 @@ def save(self): if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument('hdf5_model_filepath') - parser.add_argument('--name', type=str, required=False, + parser.add_argument('-n', '--name', type=str, required=False, help='model name (defaults to filename without extension if not provided)') + parser.add_argument('-q', '--quantize', action='store_true', required=False, + help='quantize weights to 8-bit unsigned int') args = parser.parse_args() hdf5_model_filepath = args.hdf5_model_filepath @@ -84,6 +108,8 @@ def save(self): else: name = os.path.splitext(os.path.basename(hdf5_model_filepath))[0] - encoder = Encoder(hdf5_model_filepath, name) + quantize = args.quantize + + encoder = Encoder(hdf5_model_filepath, name, quantize) encoder.serialize() encoder.save() diff --git a/python/model_pb2.py b/python/model_pb2.py index 1c9e666..726e3ef 100644 --- a/python/model_pb2.py +++ b/python/model_pb2.py @@ -19,7 +19,7 @@ name='model.proto', package='', syntax='proto3', - serialized_pb=_b('\n\x0bmodel.proto\"]\n\x07Weights\x12\x12\n\nlayer_name\x18\x01 \x01(\t\x12\x13\n\x0bweight_name\x18\x02 \x01(\t\x12\r\n\x05shape\x18\x03 \x03(\r\x12\x0c\n\x04type\x18\x04 \x01(\t\x12\x0c\n\x04\x64\x61ta\x18\x05 \x01(\x0c\"\x80\x01\n\x05Model\x12\n\n\x02id\x18\x01 \x01(\t\x12\x0c\n\x04name\x18\x02 \x01(\t\x12\x15\n\rkeras_version\x18\x03 \x01(\t\x12\x0f\n\x07\x62\x61\x63kend\x18\x04 \x01(\t\x12\x14\n\x0cmodel_config\x18\x05 \x01(\t\x12\x1f\n\rmodel_weights\x18\x06 \x03(\x0b\x32\x08.Weightsb\x06proto3') + serialized_pb=_b('\n\x0bmodel.proto\"\x89\x01\n\x07Weights\x12\x12\n\nlayer_name\x18\x01 \x01(\t\x12\x13\n\x0bweight_name\x18\x02 \x01(\t\x12\r\n\x05shape\x18\x03 \x03(\r\x12\x0c\n\x04type\x18\x04 \x01(\t\x12\x0c\n\x04\x64\x61ta\x18\x05 \x01(\x0c\x12\x14\n\x0cquantize_min\x18\x06 \x01(\x02\x12\x14\n\x0cquantize_max\x18\x07 \x01(\x02\"\x80\x01\n\x05Model\x12\n\n\x02id\x18\x01 \x01(\t\x12\x0c\n\x04name\x18\x02 \x01(\t\x12\x15\n\rkeras_version\x18\x03 \x01(\t\x12\x0f\n\x07\x62\x61\x63kend\x18\x04 \x01(\t\x12\x14\n\x0cmodel_config\x18\x05 \x01(\t\x12\x1f\n\rmodel_weights\x18\x06 \x03(\x0b\x32\x08.Weightsb\x06proto3') ) @@ -67,6 +67,20 @@ message_type=None, enum_type=None, containing_type=None, is_extension=False, extension_scope=None, options=None), + _descriptor.FieldDescriptor( + name='quantize_min', full_name='Weights.quantize_min', index=5, + number=6, type=2, cpp_type=6, label=1, + has_default_value=False, default_value=float(0), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + _descriptor.FieldDescriptor( + name='quantize_max', full_name='Weights.quantize_max', index=6, + number=7, type=2, cpp_type=6, label=1, + has_default_value=False, default_value=float(0), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), ], extensions=[ ], @@ -79,8 +93,8 @@ extension_ranges=[], oneofs=[ ], - serialized_start=15, - serialized_end=108, + serialized_start=16, + serialized_end=153, ) @@ -145,8 +159,8 @@ extension_ranges=[], oneofs=[ ], - serialized_start=111, - serialized_end=239, + serialized_start=156, + serialized_end=284, ) _MODEL.fields_by_name['model_weights'].message_type = _WEIGHTS diff --git a/src/Model.js b/src/Model.js index ae22d63..1cc5147 100644 --- a/src/Model.js +++ b/src/Model.js @@ -371,15 +371,24 @@ export default class Model { } const { data, shape, type } = weightDef - if (type !== 'float32') { - throw new Error(`[Model] Only float32 weights supported for now.`) - } // need to make a copy of underlying ArrayBuffer const buf = new ArrayBuffer(data.byteLength) - new Uint8Array(buf).set(new Uint8Array(data.buffer, data.byteOffset, data.byteLength)) - - return new Tensor(new Float32Array(buf), shape) + const arr = new Uint8Array(buf) + arr.set(new Uint8Array(data.buffer, data.byteOffset, data.byteLength)) + + if (type === 'uint8') { + // weights are quantized + const { quantizeMin, quantizeMax } = weightDef + const unquantized = new Float32Array(arr) + for (let i = 0, len = unquantized.length; i < len; i++) { + unquantized[i] *= (quantizeMax - quantizeMin) / 255 + unquantized[i] += quantizeMin + } + return new Tensor(unquantized, shape) + } else { + return new Tensor(new Float32Array(buf), shape) + } }) layer.setWeights(weights) diff --git a/src/proto.js b/src/proto.js index 9de5c4f..6df6d0d 100644 --- a/src/proto.js +++ b/src/proto.js @@ -18,6 +18,8 @@ export const Weights = $root.Weights = (() => { * @property {Array.} [shape] Weights shape * @property {string} [type] Weights type * @property {Uint8Array} [data] Weights data + * @property {number} [quantizeMin] Weights quantizeMin + * @property {number} [quantizeMax] Weights quantizeMax */ /** @@ -75,6 +77,22 @@ export const Weights = $root.Weights = (() => { */ Weights.prototype.data = $util.newBuffer([]); + /** + * Weights quantizeMin. + * @member {number}quantizeMin + * @memberof Weights + * @instance + */ + Weights.prototype.quantizeMin = 0; + + /** + * Weights quantizeMax. + * @member {number}quantizeMax + * @memberof Weights + * @instance + */ + Weights.prototype.quantizeMax = 0; + /** * Creates a new Weights instance using the specified properties. * @function create @@ -113,6 +131,10 @@ export const Weights = $root.Weights = (() => { writer.uint32(/* id 4, wireType 2 =*/34).string(message.type); if (message.data != null && message.hasOwnProperty("data")) writer.uint32(/* id 5, wireType 2 =*/42).bytes(message.data); + if (message.quantizeMin != null && message.hasOwnProperty("quantizeMin")) + writer.uint32(/* id 6, wireType 5 =*/53).float(message.quantizeMin); + if (message.quantizeMax != null && message.hasOwnProperty("quantizeMax")) + writer.uint32(/* id 7, wireType 5 =*/61).float(message.quantizeMax); return writer; }; @@ -169,6 +191,12 @@ export const Weights = $root.Weights = (() => { case 5: message.data = reader.bytes(); break; + case 6: + message.quantizeMin = reader.float(); + break; + case 7: + message.quantizeMax = reader.float(); + break; default: reader.skipType(tag & 7); break; @@ -223,6 +251,12 @@ export const Weights = $root.Weights = (() => { if (message.data != null && message.hasOwnProperty("data")) if (!(message.data && typeof message.data.length === "number" || $util.isString(message.data))) return "data: buffer expected"; + if (message.quantizeMin != null && message.hasOwnProperty("quantizeMin")) + if (typeof message.quantizeMin !== "number") + return "quantizeMin: number expected"; + if (message.quantizeMax != null && message.hasOwnProperty("quantizeMax")) + if (typeof message.quantizeMax !== "number") + return "quantizeMax: number expected"; return null; }; @@ -256,6 +290,10 @@ export const Weights = $root.Weights = (() => { $util.base64.decode(object.data, message.data = $util.newBuffer($util.base64.length(object.data)), 0); else if (object.data.length) message.data = object.data; + if (object.quantizeMin != null) + message.quantizeMin = Number(object.quantizeMin); + if (object.quantizeMax != null) + message.quantizeMax = Number(object.quantizeMax); return message; }; @@ -279,6 +317,8 @@ export const Weights = $root.Weights = (() => { object.weightName = ""; object.type = ""; object.data = options.bytes === String ? "" : []; + object.quantizeMin = 0; + object.quantizeMax = 0; } if (message.layerName != null && message.hasOwnProperty("layerName")) object.layerName = message.layerName; @@ -293,6 +333,10 @@ export const Weights = $root.Weights = (() => { object.type = message.type; if (message.data != null && message.hasOwnProperty("data")) object.data = options.bytes === String ? $util.base64.encode(message.data, 0, message.data.length) : options.bytes === Array ? Array.prototype.slice.call(message.data) : message.data; + if (message.quantizeMin != null && message.hasOwnProperty("quantizeMin")) + object.quantizeMin = options.json && !isFinite(message.quantizeMin) ? String(message.quantizeMin) : message.quantizeMin; + if (message.quantizeMax != null && message.hasOwnProperty("quantizeMax")) + object.quantizeMax = options.json && !isFinite(message.quantizeMax) ? String(message.quantizeMax) : message.quantizeMax; return object; };