From 0c7cb76fb0a8910afcb96f1c7457082f11b15e4d Mon Sep 17 00:00:00 2001 From: Lutz Roeder Date: Sat, 21 Dec 2024 11:08:14 -0800 Subject: [PATCH] Update python.js (#1061) --- source/python.js | 46 +++--- source/pytorch.js | 354 ++++++++++++++++++---------------------------- test/models.json | 11 +- 3 files changed, 171 insertions(+), 240 deletions(-) diff --git a/source/python.js b/source/python.js index f0167d7cc7..f970dd65ae 100644 --- a/source/python.js +++ b/source/python.js @@ -8809,8 +8809,8 @@ python.Execution = class { }); this.registerType('torch.Value', class { constructor(node) { - this._unique = node && node._next_unique ? node._next_unique++ : node._graph._next_unique++; // remove always node - this._node = node && node._next_unique ? null : node; + this._unique = node._graph._next_unique++; + this._node = node; this._uses = []; } unique() { @@ -9162,13 +9162,13 @@ python.Execution = class { 'init2(__torch__.torch.classes._nnapi.Compilation self, Tensor serialized_model_tensor, Tensor[] parameter_buffers, int compilation_preference, bool relax_f32_to_f16) -> NoneType', 'run(__torch__.torch.classes._nnapi.Compilation self, Tensor[] inputs, Tensor[] outputs) -> NoneType' ] }, - { name: '__torch__.torch.classes.quantized.Conv2dPackedParamsBase' }, - { name: '__torch__.torch.classes.quantized.Conv3dPackedParamsBase' }, - { name: '__torch__.torch.classes.quantized.LinearPackedParamsBase' }, - { name: '__torch__.torch.classes.rnn.CellParamsBase' }, - { name: '__torch__.torch.classes.xnnpack.Conv2dOpContext' }, - { name: '__torch__.torch.classes.xnnpack.LinearOpContext' }, - { name: '__torch__.torch.classes.xnnpack.TransposeConv2dOpContext' }, + { name: '__torch__.torch.classes.quantized.Conv2dPackedParamsBase', attributes: 'Tensor weight, Tensor bias, int[] stride, int[] padding, int[] dilation, int groups' }, + { name: '__torch__.torch.classes.quantized.Conv3dPackedParamsBase', attributes: 'Tensor weight, Tensor bias, int[] stride, int[] padding, int[] dilation, int groups' }, + { name: '__torch__.torch.classes.quantized.LinearPackedParamsBase', attributes: 'Tensor weight, Tensor? bias' }, + { name: '__torch__.torch.classes.rnn.CellParamsBase', attributes: 'str type, Tensor[] tensors, float[] doubles, int[] longs, __torch__.torch.classes.quantized.LinearPackedParamsBase[] packed_params' }, + { name: '__torch__.torch.classes.xnnpack.Conv2dOpContext', attributes: 'Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, int groups, int[] output_min, int[] output_max' }, + { name: '__torch__.torch.classes.xnnpack.LinearOpContext', attributes: 'Tensor weight, Tensor bias, int[] output_min, int[] output_max' }, + { name: '__torch__.torch.classes.xnnpack.TransposeConv2dOpContext', attributes: 'Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] output_padding, int[] dilation, int groups, int[] output_min, int[] output_max' }, ]; for (const known_type of known_types) { const prefix = new torch.jit.QualifiedName(known_type.name); @@ -9179,6 +9179,13 @@ python.Execution = class { const fn = new torch.jit.BuiltinOpFunction(name, schema); type.addMethod(fn); } + if (known_type.attributes) { + const schema = new torch.FunctionSchema(`(${known_type.attributes}) -> ()`); + for (const arg of schema.arguments) { + type.addAttribute(arg.name, arg.real_type); + } + } + this._compilation_unit.register_type(type); } if (this._reader.has_record('model.json')) { @@ -9186,20 +9193,27 @@ python.Execution = class { } const constants = this.readArchive('constants'); for (let i = 0; i < constants.length; i++) { - execution.builtins.CONSTANTS[`c${i}`] = constants[i]; + let val = constants[i]; + if (val && val.__class__ && val.__class__.__module__.startsWith('__torch__.torch.classes.')) { + const type = this._source_importer.resolveType(`${val.__class__.__module__}.${val.__class__.__name__}`); + const obj = torch.ScriptObject.create(type); + obj._ivalue = val; + val = obj; + } + execution.builtins.CONSTANTS[`c${i}`] = val; } const obj = this.readArchive('data'); - const convertModule = (obj) => { + const convertObject = (obj) => { if (obj.__class__) { const name = `${obj.__class__.__module__}.${obj.__class__.__name__}`; const type = this._source_importer.loadType(new torch.jit.QualifiedName(name)); - const module = new torch.ScriptModule(type, this._compilation_unit); + const module = type.is_module() ? new torch.ScriptModule(type, this._compilation_unit) : new torch.ScriptObject(type); for (let i = 0; i < type.numAttributes(); i++) { const k = type.getAttributeName(i); const t = type.getAttribute(i); const v = obj[k]; - if (t.is_module()) { - module.__setattr__(k, convertModule(v)); + if (t instanceof torch.ClassType) { + module.__setattr__(k, convertObject(v)); } else { if (t instanceof torch.TensorType && v && v.__class__ && v instanceof torch.Tensor === false && v.__class__.__module__ === '__torch__.torch.classes.quantized') { const name = `${v.__class__.__module__}.${v.__class__.__name__}`; @@ -9217,7 +9231,7 @@ python.Execution = class { } throw new python.Error('Module class not found.'); }; - return convertModule(obj); + return convertObject(obj); } LEGACY_deserialize() { const execution = this._compilation_unit.execution; @@ -9740,7 +9754,7 @@ python.Execution = class { if (!this.forward) { return null; } - execution.traceAttr = false; + execution.traceAttr = true; const args = []; if (!execution.traceAttr) { args.push(this); // self diff --git a/source/pytorch.js b/source/pytorch.js index fe67c477e3..0bd2d43cec 100644 --- a/source/pytorch.js +++ b/source/pytorch.js @@ -105,66 +105,64 @@ pytorch.Graph = class { } } } - const queue = [module]; - while (queue.length > 0) { - const module = queue.shift(); - const children = module.named_children(); - for (const [key, obj] of children) { - obj.__parent__ = module; - obj.__name__ = obj.__name__ || key; - queue.push(obj); - const type = obj._c._type(); - for (let i = 0; i < type.numAttributes(); i++) { - const k = type.getAttributeName(i); - const v = obj.__getattr__(k); - if (pytorch.Utility.isObject(v)) { - initializers.set(v, v); + const param_node = graph.param_node(); + const self = param_node && param_node.outputs().length > 0 && param_node.outputs()[0].type() === module._c._type() ? param_node.outputs()[0] : null; + if (self) { + const getattr = (value) => { + if (value.value === undefined) { + const node = value.node(); + if (node.kind() === 'prim::GetAttr') { + const [input] = node.inputs(); + getattr(input); + if (input.value !== undefined) { + const name = node.s('name'); + value.value = input.value.__getattr__(name); + value.identifier = input.identifier ? `${input.identifier}.${name}` : name; + } + } + if (node === param_node && value === param_node.outputs()[0]) { + value.value = module; + value.identifier = ''; } } - } - for (const buffer of module.buffers()) { - buffer.__parent__ = module; - if (buffer.storage() && !buffer.__origin__ && (buffer.__count__ === undefined || buffer.__count__ === 1)) { - initializers.set(buffer, new pytorch.Tensor(buffer.name, buffer)); - } - } - for (const parameter of module.parameters()) { - parameter.__parent__ = module; - if (parameter.storage() && !parameter.__origin__ && (parameter.__count__ === undefined || parameter.__count__ === 1)) { - initializers.set(parameter, new pytorch.Tensor(parameter.name, parameter)); + }; + for (const node of graph.nodes()) { + for (const input of node.inputs()) { + getattr(input, node); } } - for (const [key, obj] of children) { - if (key !== '__module__' && key !== '__name__' && key !== '__class__' && key !== '__parent__') { - if (!Array.isArray(obj) && obj === Object(obj)) { - if (pytorch.Utility.isTensor(obj)) { - const parameter = obj; - parameter.__parent__ = module; - if (parameter.storage() && !parameter.__origin__) { - if (parameter.__count__ === undefined || parameter.__count__ === 1) { - initializers.set(parameter, new pytorch.Tensor(parameter.name, parameter)); - } - } - } else if (pytorch.Utility.isObject(obj)) { - if (obj.__count__ === undefined || obj.__count__ === 1) { - initializers.set(obj, obj); - } - queue.push(obj); - } else if (obj instanceof torch.Value || obj instanceof torch.Node) { - continue; - } else if (obj && obj.__class__) { - obj.__parent__ = module; - obj.__name__ = obj.__name__ || key; - queue.push(obj); + const delattr = (value) => { + for (const use of Array.from(value.uses())) { + const node = use.user; + if (node.kind() === 'prim::GetAttr') { + for (const output of node.outputs()) { + delattr(output); } + node.destroy(); } } + }; + delattr(param_node.outputs()[0], ''); + } + for (const node of graph.nodes()) { + if (node.kind() === 'prim::Constant') { + const kind = node.kindOf('value'); + const value = node[kind]('value'); + for (const output of node.outputs()) { + output.identifier = output.debugName(); + output.value = value; + } + node.destroy(); } } - for (const value of graph.inputs()) { - const identifier = pytorch.Utility.unique(value); - const name = value.debugName() || identifier; - this.inputs.push(new pytorch.Argument(name, [values.map(identifier)])); + for (const v of graph.inputs()) { + if (self.uses().length === 0 && v === self) { + continue; + } + const identifier = pytorch.Utility.unique(v); + const name = v.debugName() || identifier; + const value = values.map(identifier); + this.inputs.push(new pytorch.Argument(name, [value])); } for (const value of graph.outputs()) { const identifier = pytorch.Utility.unique(value); @@ -209,23 +207,6 @@ pytorch.Graph = class { } this.nodes.push(new pytorch.Node(execution, metadata, null, null, node, initializers, values)); } - if (module) { - const queue = [module]; - while (queue.length > 0) { - const module = queue.pop(); - if (module) { - const modules = Array.from(module.children()); - queue.push(...modules.reverse()); - if (!module.__hide__ && module.named_parameters().size > 0) { - for (const [name] of module.named_children()) { - module.__delattr__(name); - } - const node = new pytorch.Node(execution, metadata, null, null, module, initializers, values); - this.nodes.push(node); - } - } - } - } } else if (torch && module instanceof torch.export.exported_program.ExportedProgram && module.graph) { const exported_program = module; const graph = exported_program.graph; @@ -379,7 +360,6 @@ pytorch.Node = class { this.inputs = []; this.outputs = []; this.metadata = []; - let module = null; if (torch && obj instanceof torch.Node) { const node = obj; const sourceRange = node.sourceRange(); @@ -405,6 +385,7 @@ pytorch.Node = class { case 's': value = node.s(name); type = 'string'; break; case 'i': value = node.i(name); type = 'int64'; break; case 'f': value = node.f(name); type = 'float32'; break; + case 't': value = node.t(name); type = 'tensor'; break; case 'ss': value = node.ss(name); type = 'string[]'; break; case 'tys': value = node.tys(name).map((ty) => pytorch.Utility.toType(ty)); type = 'type[]'; break; case 'ival': value = node.ival(name); break; @@ -417,59 +398,15 @@ pytorch.Node = class { const attribute = new pytorch.Argument(name, value, type); this.attributes.push(attribute); } - let match = true; - let count = 0; - for (const input of inputs) { - const value = input.value; - let values = []; - if (pytorch.Utility.isObject(value)) { - values = Object.values(value); - } else if (pytorch.Utility.isTensor(value)) { - values = [value]; - } else if (Array.isArray(value) && value.every((value) => pytorch.Utility.isTensor(value))) { - values = value; - } else if (input instanceof torch.Value && input.type() instanceof torch.ListType && input.type().getElementType() instanceof torch.TensorType) { - if (input.node() && - input.node().kind() === 'prim::ListConstruct' && - input.uses().length === 1 && - input.node().inputs().every((input) => pytorch.Utility.isTensor(input.value))) { - values = input.node().inputs().map((input) => input.value); - } - } - for (const value of values) { - const parameter = initializers.get(value); - if (parameter) { - if (value.__parent__ && (module === null || module === value.__parent__)) { - module = value.__parent__; - count++; - } else if (value.__name__ && value.__name__.startsWith('CONSTANTS.c')) { - count++; - } else { - match = false; - break; - } + const mapTensor = (value) => { + if (value.identifier && pytorch.Utility.isTensor(value.value)) { + const identifier = value.identifier; + if (!values.has(identifier)) { + const tensor = new pytorch.Tensor(identifier, value.value); + values.set(identifier, new pytorch.Value(identifier, null, null, tensor)); } + return values.map(identifier); } - if (!match) { - break; - } - } - if (module) { - const tensors = new Map(); - for (const [name, value] of module.named_parameters()) { - tensors.set(name, value); - } - for (const [name, value] of module.named_buffers()) { - tensors.set(name, value); - } - tensors.delete('num_batches_tracked'); - if (tensors.size === count && match) { - module.__hide__ = true; - } else { - module = null; - } - } - const mapTensor = (value) => { let initializer = null; let identifier = value.hasDebugName() ? `%${value.debugName().toString()}` : `%${value.unique().toString()}`; if (value.value) { @@ -502,11 +439,20 @@ pytorch.Node = class { } else if (array && Array.isArray(obj) && obj.every((obj) => initializers.has(obj))) { const node = obj.map((obj) => new pytorch.Node(execution, metadata, name, type.qualified_name(), obj, initializers, values)); argument = new pytorch.Argument(name, node, 'object[]'); - } else { + } else if (array && input.node().kind() === 'prim::ListConstruct' && input.uses().length === 1 && input.node().inputs().every((input) => input.value)) { + const node = input.node().inputs().map((input) => new pytorch.Node(execution, metadata, name, null, input.value, initializers, values)); + argument = new pytorch.Argument(name, node, 'object[]'); + } else if (input.value === undefined) { const identifier = pytorch.Utility.unique(input); const value = values.map(identifier); argument = new pytorch.Argument(name, [value]); + } else { + const node = new pytorch.Node(execution, metadata, null, null, input.value, initializers, values); + argument = new pytorch.Argument(name, node, 'object'); } + } else if ((input.type() instanceof torch.TensorType || (input.type() instanceof torch.OptionalType && input.type().getElementType() instanceof torch.TensorType)) && pytorch.Utility.isTensor(input.value)) { + const value = mapTensor(input); + argument = new pytorch.Argument(name, [value]); } else if (input instanceof torch.Value && !pytorch.Utility.isTensor(input.value)) { if (input.node() === null && input.value !== undefined) { if (Array.isArray(input.value) && input.value.every((value) => pytorch.Utility.isTensor(value))) { @@ -522,10 +468,7 @@ pytorch.Node = class { if (pytorch.Utility.isTensor(value.value)) { return mapTensor(value); } - if (value.uses().length === 1 && value.node().kind() === 'prim::Constant') { - return getAttribute(value.node(), 'value')[1]; - } - if (value.uses().length === 1 && value.node() === input.node() && value.value !== undefined) { + if (value.uses().length === 1 && value.value !== undefined) { return value.value; } const identifier = pytorch.Utility.unique(value); @@ -539,6 +482,12 @@ pytorch.Node = class { } } else if (input.type() instanceof torch.StringType && typeof input.value === 'string') { argument = new pytorch.Argument(name, input.value, 'string'); + } else if (input.type() instanceof torch.BoolType && typeof input.value === 'boolean') { + argument = new pytorch.Argument(name, input.value, 'boolean'); + } else if (input.type() instanceof torch.IntType && typeof input.value === 'number') { + argument = new pytorch.Argument(name, input.value, 'int64'); + } else if (input.type() instanceof torch.FloatType && typeof input.value === 'number') { + argument = new pytorch.Argument(name, input.value, 'float32'); } else if (input.node() && input.uses().length === 1 && input.node().kind() === 'prim::Constant') { let [type, value] = getAttribute(input.node(), 'value'); const valueType = input.type(); @@ -713,7 +662,10 @@ pytorch.Node = class { throw new pytorch.Error(`Unsupported node operation '${obj.op}'.`); } } else { - if (torch && obj instanceof torch.jit._script.RecursiveScriptModule && obj._c && obj._c.qualified_name) { + if (torch && obj instanceof torch.ScriptObject) { + type = obj._type().qualified_name(); + obj = obj._ivalue; + } else if (torch && obj instanceof torch.jit._script.RecursiveScriptModule && obj._c && obj._c.qualified_name) { type = obj._c._type(); const target = { _modules: obj._modules, @@ -766,7 +718,7 @@ pytorch.Node = class { const inputs = new Map(Array.isArray(this.type.inputs) ? this.type.inputs.map((input) => [input.name, input]) : []); const list = obj instanceof Map ? Array.from(obj) : Object.entries(obj); for (const [name, value] of list) { - if (name === '__class__' || name === '__parent__' || name === '__name__') { + if (name === '__class__' || name === '__name__') { continue; } else if (pytorch.Utility.isInstance(value, 'collections.OrderedDict') && value instanceof Map && value.size === 0) { continue; @@ -891,15 +843,6 @@ pytorch.Node = class { } } } - if (module && module.__name__) { - this.name = module.__name__; - while (module.__parent__) { - module = module.__parent__; - if (module.__name__) { - this.name = `${module.__name__}.${this.name}`; - } - } - } } }; @@ -1676,6 +1619,7 @@ pytorch.Execution = class extends python.Execution { } } this._graph = this.invoke('torch.Graph', []); + this._constants = new Map(); this._values = new Map(); } @@ -1693,36 +1637,6 @@ pytorch.Execution = class extends python.Execution { return this._graph; } - variable(obj, node) { - const torch = this.torch; - if (this._values.has(obj)) { - return this._values.get(obj); - } - let value = null; - if (node) { - value = node.addOutput(); - } else if (obj instanceof torch.Value) { - value = obj; - } else { - value = new torch.Value(node ? node : this.graph); - } - if (pytorch.Utility.isTensor(obj)) { - value.value = obj; - value.setType(torch.TensorType.get()); - if (typeof obj !== 'string' && typeof obj !== 'number') { - this._values.set(obj, value); - } - if (pytorch.Utility.isTensor(obj)) { - obj.__variable__ = pytorch.Utility.unique(value); - } - } - if (typeof obj === 'string') { - value.value = obj; - value.setType(torch.StringType.get()); - } - return value; - } - resolve(name) { const index = name.lastIndexOf('.'); const memberName = index === -1 ? name : name.substring(index + 1, name.length); @@ -1818,6 +1732,15 @@ pytorch.Execution = class extends python.Execution { return this._graph.create(kind, n_outputs).setSourceRange(loc); } + value(expr, context, typehint) { + const torch = this.torch; + const value = this.expression(expr, context, typehint); + if (value instanceof torch.Value) { + return value; + } + return this.constant(value); + } + expression(expr, context, typehint) { if (!this.trace) { return super.expression(expr, context); @@ -1833,6 +1756,7 @@ pytorch.Execution = class extends python.Execution { } case 'Constant': { if (expr.value === true || expr.value === false) { + // debugger; return this._graph.insertConstant(expr.value); } break; @@ -1842,7 +1766,7 @@ pytorch.Execution = class extends python.Execution { if (target instanceof ast.Name) { let value = this.expression(expr.value, context); if (typeof value === 'string' || typeof value === 'boolean' || typeof value === 'number') { - value = this._graph.insertConstant(value); + value = this.constant(value); } else if (typeof value !== 'object' && value !== undefined) { throw new pytorch.Error(`Unsupported assignment value type '${typeof value}'.`); } @@ -1944,23 +1868,19 @@ pytorch.Execution = class extends python.Execution { return node.output(); } if (func instanceof ast.Name && func.id === 'unchecked_cast') { - let value = this.expression(expr.args[1], context); - if (value instanceof torch.Value === false) { // remove - value = this.variable(value); - } + const value = this.value(expr.args[1], context); const type = this.type(expr.args[0]); return this._graph.insertUncheckedCast(value, type); } if (func instanceof ast.Name && func.id === 'isinstance') { - const value = this.expression(expr.args[0], context); + const value = this.value(expr.args[0], context); let [, types] = expr.args; if (types instanceof ast.Tuple) { types = types.elts.map((expr) => this.type(expr)); } else { types = [this.type(types)]; } - const v = this.variable(value); // remove - const node = this._graph.createIsInstance(v, types); + const node = this._graph.createIsInstance(value, types); this._graph.insertNode(node); return node.output(); } @@ -2031,6 +1951,12 @@ pytorch.Execution = class extends python.Execution { break; } case 'Attribute': { + if (expr.value instanceof ast.Name && expr.value.id === 'CONSTANTS') { + const constant = this.builtins[expr.value.id][expr.attr]; + const value = this._graph.insertConstant(constant); + value.setDebugName(`${expr.value.id}.${expr.attr}`); + return value; + } const target = this.target(expr.value, context); const attr = expr.attr; if (target instanceof torch.Value && target.type() instanceof torch.ClassType) { @@ -2042,17 +1968,17 @@ pytorch.Execution = class extends python.Execution { } case 'List': { const list = expr.elts.map((item) => this.expression(item, context)); - if (/* list.length > 0 && */ list.every((item) => item instanceof torch.Value || pytorch.Utility.isTensor(item) || Number.isInteger(item) || typeof item === 'string' || item === null)) { + if (/* list.length > 0 && */ list.every((item) => item instanceof torch.Value || pytorch.Utility.isTensor(item) || typeof item === 'number' || typeof item === 'string' || item === null)) { const values = []; let item_type = null; for (const item of list) { let value = null; if (item instanceof torch.Value) { value = item; - } else if (Number.isInteger(item) || typeof item === 'string' || item === null) { + } else if (typeof item === 'number' || typeof item === 'string' || item === null) { value = this._graph.insertConstant(item); } else if (pytorch.Utility.isTensor(item)) { - value = this.variable(item, null); + value = item; } else { throw new pytorch.Error('Unsupported list item type.'); } @@ -2077,7 +2003,7 @@ pytorch.Execution = class extends python.Execution { if (elt instanceof torch.Value) { value = elt; } else if (pytorch.Utility.isTensor(elt)) { - value = this.variable(elt, null); + throw new pytorch.Error(); } else if (elt === null || Number.isInteger(elt) || typeof elt === 'number' || typeof elt === 'boolean' || typeof elt === 'string') { value = this._graph.insertConstant(elt); } else { @@ -2096,18 +2022,16 @@ pytorch.Execution = class extends python.Execution { let keyType = null; let valueType = null; for (let i = 0; i < expr.keys.length; i++) { - const key = this.expression(expr.keys[i], context); - const keyValue = this.variable(key, null); - if (!keyType || keyType.isSubtypeOf(keyValue.type())) { - keyType = keyValue.type(); + const key = this.value(expr.keys[i], context); + if (!keyType || keyType.isSubtypeOf(key.type())) { + keyType = key.type(); } - keys.push(keyValue); - const value = this.expression(expr.values[i], context); - const valueValue = this.variable(value, null); - if (!valueType || valueType.isSubtypeOf(valueValue.type())) { - valueType = valueValue.type(); + keys.push(key); + const value = this.value(expr.values[i], context); + if (!valueType || valueType.isSubtypeOf(value.type())) { + valueType = value.type(); } - values.push(valueValue); + values.push(value); } const key_type = typehint ? typehint.getKeyType() : keyType; const value_type = typehint ? typehint.getValueType() : valueType; @@ -2659,13 +2583,20 @@ pytorch.Execution = class extends python.Execution { throw new pytorch.Error(`Unsupported type expression '${expr.type}'.`); } + constant(constant) { + if (!this._constants.has(constant)) { + const value = this._graph.insertConstant(constant); + this._constants.set(constant, value); + } + return this._constants.get(constant); + } + call(target, name, args, context, location) { if (!this.trace) { return super.call(target, name, args, context); } const ast = this.ast; const torch = this.torch; - const builtins = this.builtins; if (name === '__new__') { const identifier = this.identifier(target); if (identifier) { @@ -2712,8 +2643,7 @@ pytorch.Execution = class extends python.Execution { const return_field_names = [schema.returns[0].name]; const return_types = [schema.returns[0].real_type]; const inputs = [moduleTarget]; - for (const arg of evalArgs) { - const value = this.variable(arg); + for (const value of evalArgs) { inputs.push(value); } const matchedSchema = new torch.jit.MatchedSchema(inputs, return_types, return_field_names, name); @@ -2725,8 +2655,7 @@ pytorch.Execution = class extends python.Execution { const identifier = `${prefix}.${name}`; const type = this._resolver.resolveType(identifier); if (type instanceof torch.TupleType) { - const evalArgs = args.map((expression) => this.expression(expression, context)); - const values = evalArgs.map((arg) => this.variable(arg)); + const values = args.map((expression) => this.value(expression, context)); const node = this._graph.createTuple(values, type); node.setSourceRange(location); this._graph.insertNode(node); @@ -2737,8 +2666,7 @@ pytorch.Execution = class extends python.Execution { this._graph.insertNode(node); node.s_('name', name); const evalArgs = args.map((expression) => this.expression(expression, context)); - for (const arg of evalArgs) { - const value = this.variable(arg); + for (const value of evalArgs) { node.addInput(value); } return node.output(); @@ -2800,21 +2728,11 @@ pytorch.Execution = class extends python.Execution { if (v instanceof torch.Value) { input = v; match = true; - } else { - const values = []; - for (const arg of v || []) { - const tensor = arg; - if (tensor) { - tensor.__count__ = (tensor.__count__ || 0) + 1; - } - const value = this.variable(tensor); - value.setType(torch.TensorType.get()); - values.push(value); - } - const node = this._graph.createList(torch.TensorType.get(), values); - this._graph.insertNode(node); - input = node.output(); + } else if (v === null) { + input = this.constant(v); match = true; + } else { + throw new pytorch.Error(); } } else { if (optional) { @@ -2834,14 +2752,11 @@ pytorch.Execution = class extends python.Execution { if (v instanceof torch.Value) { input = v; match = true; - } else { - const value = this.variable(v); - value.value = v; - if (!value.type() && v instanceof builtins.dict) { - value.setType(type); - } - input = value; + } else if (v === null || typeof v === 'number' || typeof v === 'string' || typeof v === 'boolean') { + input = this.constant(v); match = true; + } else { + throw new pytorch.Error(); } } if (match) { @@ -2916,10 +2831,11 @@ pytorch.Execution = class extends python.Execution { } if (v instanceof torch.Value) { node.addInput(v); - } else { - const value = this.variable(v); - value.value = v; + } else if (v === null || typeof v === 'number' || typeof v === 'string' || typeof v === 'boolean') { + const value = this.constant(v); node.addInput(value); + } else { + throw new pytorch.Error(); } } } diff --git a/test/models.json b/test/models.json index 73929c250f..613f6fc331 100644 --- a/test/models.json +++ b/test/models.json @@ -5307,7 +5307,7 @@ "target": "alexnet_traced.pt.zip", "source": "https://github.com/lutzroeder/netron/files/6096602/alexnet_traced.pt.zip", "format": "TorchScript v1.6", - "assert": "model.graphs[0].nodes.length == 28", + "assert": "model.graphs[0].nodes.length == 25", "link": "https://github.com/lutzroeder/netron/issues/281" }, { @@ -5500,7 +5500,7 @@ "target": "cruise_go_vehicle_model.pt", "source": "https://raw.githubusercontent.com/ApolloAuto/apollo/master/modules/prediction/data/cruise_go_vehicle_model.pt", "format": "TorchScript v1.0", - "assert": "model.graphs[0].nodes.length == 73", + "assert": "model.graphs[0].nodes.length == 63", "link": "https://github.com/ApolloAuto/apollo" }, { @@ -5531,7 +5531,6 @@ "target": "deeplabv3_scripted.pt", "source": "https://github.com/lutzroeder/netron/files/5604999/deeplabv3_scripted.pt.zip[deeplabv3_scripted.pt]", "format": "TorchScript v1.6", - "assert": "model.graphs[0].nodes.length == 478", "link": "https://github.com/lutzroeder/netron/issues/630" }, { @@ -5757,7 +5756,6 @@ "target": "lane_scanning_vehicle_model.pt", "source": "https://raw.githubusercontent.com/ApolloAuto/apollo/master/modules/prediction/data/lane_scanning_vehicle_model.pt", "format": "TorchScript v1.0", - "assert": "model.graphs[0].nodes.length == 121", "link": "https://github.com/ApolloAuto/apollo" }, { @@ -6262,6 +6260,7 @@ "target": "resnet18.pt", "source": "https://github.com/lutzroeder/netron/files/5212015/resnet18.zip[resnet18.pt]", "format": "TorchScript v1.6", + "assert": "model.graphs[0].nodes[1].inputs[1].value.inputs[5].value == 1", "link": "https://github.com/lutzroeder/netron/issues/559" }, { @@ -6598,7 +6597,7 @@ "target": "TestSerialization.test_lstm.traced.pt", "source": "https://github.com/user-attachments/files/16121906/TestSerialization.test_lstm.traced.pt.zip[TestSerialization.test_lstm.traced.pt]", "format": "TorchScript v1.6", - "assert": "model.graphs[0].nodes.length == 10", + "assert": "model.graphs[0].nodes[5].inputs[2].value[0].inputs[0].value == 'quantized_dynamic'", "link": "https://github.com/lutzroeder/netron/issues/1067" }, { @@ -6642,6 +6641,7 @@ "target": "traced_online_obs_enc.pt", "source": "https://raw.githubusercontent.com/ApolloAuto/apollo/master/modules/prediction/data/traced_online_obs_enc.pt", "format": "TorchScript v1.0", + "assert": "model.graphs[0].nodes[46].inputs[1].value[0] == 1", "link": "https://github.com/ApolloAuto/apollo" }, { @@ -6819,6 +6819,7 @@ "target": "yolox_m.torchscript.pt", "source": "https://github.com/lutzroeder/netron/files/15031984/yolox_m.torchscript.pt.zip[yolox_m.torchscript.pt]", "format": "TorchScript v1.6", + "assert": "model.graphs[0].nodes.length == 441", "link": "https://github.com/lutzroeder/netron/issues/842" }, {