From 0ffe6bcc204a697e50e1db4e5611ce496e60b70f Mon Sep 17 00:00:00 2001 From: Lutz Roeder Date: Mon, 16 Dec 2024 01:12:42 -0500 Subject: [PATCH] Update onnx.js (#6) --- source/onnx.js | 245 ++++++----------------------------------------- test/models.json | 14 +-- 2 files changed, 35 insertions(+), 224 deletions(-) diff --git a/source/onnx.js b/source/onnx.js index 2f25f44cac..e0bff8b606 100644 --- a/source/onnx.js +++ b/source/onnx.js @@ -1562,6 +1562,11 @@ onnx.ProtoReader = class { return new onnx.ProtoReader(context, 'text', 'model'); } } + const obj = context.peek('json'); + if (obj && (obj.ir_version === undefined && obj.producer_name === undefined && !Array.isArray(obj.opset_import) && !Array.isArray(obj.metadata_props)) && + (obj.irVersion !== undefined || obj.producerName !== undefined || Array.isArray(obj.opsetImport) || Array.isArray(obj.metadataProps) || (Array.isArray(obj.graph) && Array.isArray(obj.graph.node)))) { + return new onnx.ProtoReader(context, 'json', 'model'); + } return undefined; } @@ -1588,6 +1593,17 @@ onnx.ProtoReader = class { } break; } + case 'json': { + try { + const obj = this.context.read('json'); + this.model = onnx.proto.ModelProto.decodeJson(obj); + this.format = `ONNX${this.model.ir_version ? ` v${this.model.ir_version}` : ''}`; + } catch (error) { + const message = error && error.message ? error.message : error.toString(); + throw new onnx.Error(`File JSON format is not onnx.ModelProto (${message.replace(/\.$/, '')}).`); + } + break; + } case 'binary': { switch (this.type) { case 'tensor': { @@ -1877,7 +1893,8 @@ onnx.JsonReader = class { static open(context) { const obj = context.peek('json'); - if (obj && (obj.irVersion !== undefined || obj.ir_Version !== undefined || (obj.graph && Array.isArray(obj.graph.node)))) { + if (obj && obj.framework === undefined && obj.graph && + (obj.ir_version !== undefined || obj.producer_name !== undefined || Array.isArray(obj.opset_import))) { return new onnx.JsonReader(obj); } return null; @@ -1886,7 +1903,6 @@ onnx.JsonReader = class { constructor(obj) { this.name = 'onnx.json'; this.model = obj; - this._attributeTypes = new Map(Object.entries(onnx.AttributeType)); } async read() { @@ -1894,137 +1910,13 @@ onnx.JsonReader = class { this.format = `ONNX JSON${this.model.ir_version ? ` v${this.model.ir_version}` : ''}`; } - _tensor_shape(value) { - if (Array.isArray(value.dim)) { - for (const dimension of value.dim) { - if (dimension.dimValue !== undefined) { - dimension.dim_value = parseInt(dimension.dimValue, 10); - delete dimension.dimValue; - } else if (dimension.dimParam !== undefined) { - dimension.dim_param = dimension.dimParam; - delete dimension.dimParam; - } - } - } - return value; - } - - _tensor_type(value) { - if (value.elemType !== undefined) { - value.elem_type = value.elemType; - delete value.elemType; - } - if (value.shape) { - value.shape = this._tensor_shape(value.shape); - } - return value; - } - - _optional_type(value) { - if (value.elemType !== undefined) { - value.elem_type = this._type(value.elemType); - delete value.elemType; - } - return value; - } - - _sequence_type(value) { - if (value.elemType !== undefined) { - value.elem_type = this._type(value.elemType); - delete value.elemType; - } - return value; - } - - _map_type(value) { - if (value.keyType !== undefined) { - value.key_type = value.keyType; - delete value.keyType; - } - if (value.valueType !== undefined) { - value.value_type = this._type(value.valueType); - delete value.valueType; - } - return value; - } - - _sparse_tensor_type(value) { - if (value.elemType !== undefined) { - value.elem_type = value.elemType; - delete value.elemType; - } - if (value.shape) { - value.shape = this._tensor_shape(value.shape); - } - return value; - } - - _type(value) { - if (value.tensorType) { - value.tensor_type = this._tensor_type(value.tensorType); - delete value.tensorType; - } else if (value.tensor_type) { - value.tensor_type = this._tensor_type(value.tensor_type); - } else if (value.sequenceType) { - value.sequence_type = this._sequence_type(value.sequenceType); - delete value.sequenceType; - } else if (value.sequence_type) { - value.sequence_type = this._sequence_type(value.sequenceType); - } else if (value.optionalType !== undefined) { - value.optional_type = this._optional_type(value.optionalType); - delete value.optionalType; - } else if (value.optional_type) { - value.optional_type = this._optional_type(value.optionalType); - } else if (value.mapType) { - value.map_type = this._map_type(value.mapType); - delete value.mapType; - } else if (value.map_type) { - value.map_type = this._map_type(value.mapType); - } else if (value.sparseTensorType) { - value.sparse_tensor_type = this._sparse_tensor_type(value.sparseTensorType); - delete value.sparseTensorType; - } else if (value.sparse_tensor_type) { - value.sparse_tensor_type = this._sparse_tensor_type(value.sparseTensorType); - } else if (Object.keys(value).length > 0) { - throw new onnx.Error(`Unsupported ONNX JSON type '${JSON.stringify(Object.keys(value))}'.`); - } - return value; - } - _tensor(value) { - if (value.dataType !== undefined) { - value.data_type = value.dataType; - delete value.dataType; - } - value.dims = Array.isArray(value.dims) ? value.dims.map((dim) => parseInt(dim, 10)) : []; + value.dims = Array.isArray(value.dims) ? value.dims : []; if (value.raw_data !== undefined) { - if (value.raw_data && value.raw_data instanceof Uint8Array === false && - value.raw_data.type === 'Buffer' && Array.isArray(value.raw_data.data)) { + if (value.raw_data && value.raw_data instanceof Uint8Array === false && value.raw_data.type === 'Buffer' && Array.isArray(value.raw_data.data)) { value.data_location = onnx.DataLocation.DEFAULT; value.raw_data = new Uint8Array(value.raw_data.data); } - } else if (value.rawData !== undefined) { - value.data_location = onnx.DataLocation.DEFAULT; - const data = atob(value.rawData); - const length = data.length; - const array = new Uint8Array(length); - for (let i = 0; i < length; i++) { - array[i] = data[i].charCodeAt(0); - } - value.raw_data = array; - delete value.rawData; - } else if (Array.isArray(value.floatData)) { - value.data_location = onnx.DataLocation.DEFAULT; - value.float_data = value.floatData; - delete value.floatData; - } else if (Array.isArray(value.int32Data)) { - value.data_location = onnx.DataLocation.DEFAULT; - value.int32_data = value.int32Data; - delete value.int32Data; - } else if (Array.isArray(value.int64Data)) { - value.data_location = onnx.DataLocation.DEFAULT; - value.int64_data = value.int64Data.map((value) => parseInt(value, 10)); - delete value.int64Data; } else if ((Array.isArray(value.float_data) && value.float_data.length > 0) || (Array.isArray(value.int32_data) && value.int32_data.length > 0) || (Array.isArray(value.int64_data) && value.int64_data.length > 0)) { @@ -2042,13 +1934,7 @@ onnx.JsonReader = class { } _attribute(value) { - if (value.type && this._attributeTypes.has(value.type)) { - value.type = this._attributeTypes.get(value.type); - } - if (value.refAttrName) { - value.ref_attr_name = value.refAttrName; - delete value.refAttrName; - } else if (value.ref_attr_name) { + if (value.ref_attr_name) { value.ref_attr_name = value.ref_attr_name.toString(); } else if (value.type === onnx.AttributeType.FLOATS || (Array.isArray(value.floats) && value.floats.length > 0)) { value.floats = value.floats.map((value) => parseFloat(value)); @@ -2060,24 +1946,18 @@ onnx.JsonReader = class { value.tensors = value.tensors.map((value) => this._tensor(value)); } else if (value.type === onnx.AttributeType.GRAPHS || (Array.isArray(value.graphs) && value.graphs.length > 0)) { value.graphs = value.graphs.map((value) => this._graph(value)); - } else if (value.type === onnx.AttributeType.SPARSE_TENSORS || (Array.isArray(value.sparseTensors) && value.sparseTensors.length > 0)) { - value.sparse_tensors = value.sparseTensors.map((item) => this._sparse_tensor(item)); - delete value.sparseTensors; } else if (value.type === onnx.AttributeType.SPARSE_TENSORS || (Array.isArray(value.sparse_tensors) && value.sparse_tensors.length > 0)) { value.sparse_tensors = value.sparse_tensors.map((item) => this._sparse_tensor(item)); } else if (value.type === onnx.AttributeType.FLOAT || value.f !== undefined) { - value.f = parseFloat(value.f); + // continue } else if (value.type === onnx.AttributeType.INT || value.i !== undefined) { - value.i = parseInt(value.i, 10); + // continue } else if (value.type === onnx.AttributeType.STRING || value.s !== undefined) { value.s = atob(value.s); } else if (value.type === onnx.AttributeType.TENSOR || value.t !== undefined) { value.t = this._tensor(value.t); } else if (value.type === onnx.AttributeType.GRAPH || value.g !== undefined) { value.g = this._graph(value.g); - } else if (value.type === onnx.AttributeType.SPARSE_TENSOR || value.sparseTensor !== undefined) { - value.sparse_tensor = this._sparse_tensor(value.sparseTensor); - delete value.sparseTensor; } else if (value.type === onnx.AttributeType.SPARSE_TENSOR || value.sparse_tensor !== undefined) { value.sparse_tensor = this._sparse_tensor(value.sparse_tensor); } else { @@ -2087,43 +1967,18 @@ onnx.JsonReader = class { } _node(value) { - if (value.opType !== undefined) { - value.op_type = value.opType; - delete value.opType; - } value.input = Array.isArray(value.input) ? value.input : []; value.output = Array.isArray(value.output) ? value.output : []; value.attribute = Array.isArray(value.attribute) ? value.attribute.map((value) => this._attribute(value)) : []; return value; } - _value_info(value) { - value.type = this._type(value.type); - return value; - } - - _operator_set(value) { - value.version = parseInt(value.version, 10); - return value; - } - _graph(value) { value.node = value.node.map((value) => this._node(value)); value.initializer = Array.isArray(value.initializer) ? value.initializer.map((value) => this._tensor(value)) : []; - if (Array.isArray(value.sparseInitializer) && value.sparseInitializer.length > 0) { - value.sparse_initializer = value.sparseInitializer.map((item) => this._sparse_tensor(item)); - delete value.sparseInitializer; - } else if (Array.isArray(value.sparse_initializer) && value.sparse_initializer.length > 0) { - value.sparse_initializer = value.sparseInitializer.map((item) => this._sparse_tensor(item)); - } - if (Array.isArray(value.valueInfo) && value.valueInfo.length > 0) { - value.value_info = value.valueInfo.map((item) => this._value_info(item)); - delete value.valueInfo; - } else if (Array.isArray(value.value_info) && value.value_info.length > 0) { - value.value_info = value.value_info.map((item) => this._value_info(item)); - } - value.input = Array.isArray(value.input) ? value.input.map((value) => this._value_info(value)) : []; - value.output = Array.isArray(value.output) ? value.output.map((value) => this._value_info(value)) : []; + value.sparse_initializer = Array.isArray(value.sparse_initializer) ? value.sparse_initializer.map((item) => this._sparse_tensor(item)) : []; + value.input = Array.isArray(value.input) ? value.input : []; + value.output = Array.isArray(value.output) ? value.output : []; return value; } @@ -2132,57 +1987,13 @@ onnx.JsonReader = class { value.input = Array.isArray(value.input) ? value.input : []; value.output = Array.isArray(value.output) ? value.output : []; value.attribute = Array.isArray(value.attribute) ? value.attribute : []; - if (Array.isArray(value.attributeProto) && value.attributeProto.length > 0) { - value.attribute_proto = value.attributeProto.map((value) => this._attribute(value)); - delete value.attributeProto; - } else if (Array.isArray(value.attribute_proto) && value.attribute_proto.length > 0) { - value.attribute_proto = value.attribute_proto.map((value) => this._attribute(value)); - } - if (value.docString) { - value.doc_string = value.docString; - delete value.docString; - } + value.attribute_proto = Array.isArray(value.attribute_proto) ? value.attribute_proto.map((value) => this._attribute(value)) : []; return value; } _model(value) { - if (value.irVersion !== undefined) { - value.ir_version = parseInt(value.irVersion, 10); - delete value.irVersion; - } - if (value.version !== undefined) { - value.version = parseInt(value.version, 10); - } - if (value.producerName) { - value.producer_name = value.producerName; - delete value.producerName; - } - if (value.producerVersion) { - value.producer_version = value.producerVersion; - delete value.producerVersion; - } - if (value.modelVersion) { - value.model_version = parseInt(value.modelVersion, 10); - delete value.modelVersion; - } - if (value.docString) { - value.doc_string = value.docString; - delete value.docString; - } value.graph = this._graph(value.graph); - if (Array.isArray(value.opsetImport) && value.opsetImport.length > 0) { - value.opset_import = value.opsetImport.map((item) => this._operator_set(item)); - delete value.opsetImport; - } else if (Array.isArray(value.opset_import) && value.opset_import.length > 0) { - value.opset_import = value.opset_import.map((item) => this._operator_set(item)); - } - if (Array.isArray(value.metadataProps)) { - value.metadata_props = value.metadataProps; - delete value.metadataProps; - } - if (Array.isArray(value.functions)) { - value.functions = value.functions.map((item) => this._function(item)); - } + value.functions = Array.isArray(value.functions) ? value.functions.map((item) => this._function(item)) : []; return value; } }; diff --git a/test/models.json b/test/models.json index 3e4062321a..597a143018 100644 --- a/test/models.json +++ b/test/models.json @@ -4074,7 +4074,7 @@ "type": "onnx", "target": "candy.json.zip", "source": "https://github.com/lutzroeder/netron/files/12329067/candy.json.zip", - "format": "ONNX JSON v3", + "format": "ONNX v3", "assert": "model.graphs[0].nodes[2].attributes[1].visible == false", "tags": "validation", "link": "https://github.com/lutzroeder/netron/issues/6" @@ -4254,7 +4254,7 @@ "type": "onnx", "target": "gather.json", "source": "https://github.com/lutzroeder/netron/files/12306625/gather.json.zip[gather.json]", - "format": "ONNX JSON v6", + "format": "ONNX v6", "link": "https://github.com/lutzroeder/netron/issues/6" }, { @@ -4374,7 +4374,7 @@ "type": "onnx", "target": "issue_1138.json", "source": "https://github.com/lutzroeder/netron/files/12343742/issue_1138.json.zip[issue_1138.json]", - "format": "ONNX JSON v9", + "format": "ONNX v9", "link": "https://github.com/lutzroeder/netron/issues/1138" }, { @@ -4484,7 +4484,7 @@ "type": "onnx", "target": "nms_base_component.json", "source": "https://github.com/lutzroeder/netron/files/12306626/nms_base_component.json.zip[nms_base_component.json]", - "format": "ONNX JSON v8", + "format": "ONNX v8", "link": "https://github.com/lutzroeder/netron/issues/6" }, { @@ -4507,7 +4507,7 @@ "type": "onnx", "target": "optional_type.json", "source": "https://github.com/lutzroeder/netron/files/12329086/optional_type.json.zip[optional_type.json]", - "format": "ONNX JSON v8", + "format": "ONNX v8", "link": "https://github.com/lutzroeder/netron/issues/6" }, { @@ -4600,7 +4600,7 @@ "type": "onnx", "target": "sparse_initializer_as_output.json", "source": "https://github.com/lutzroeder/netron/files/12444489/sparse_initializer_as_output.json.zip[sparse_initializer_as_output.json]", - "format": "ONNX JSON v7", + "format": "ONNX v7", "assert": "model.graphs[0].outputs[0].value[0].type.layout == 'sparse'", "tags": "validation", "link": "https://github.com/lutzroeder/netron/issues/741" @@ -4769,7 +4769,7 @@ "type": "onnx", "target": "zipmap_int64float.json", "source": "https://github.com/lutzroeder/netron/files/12329104/zipmap_int64float.json.zip[zipmap_int64float.json]", - "format": "ONNX JSON v3", + "format": "ONNX v3", "link": "https://github.com/lutzroeder/netron/issues/6" }, {