From cee1b0046a4da17b79eb4f50c971627cd977d795 Mon Sep 17 00:00:00 2001 From: Lutz Roeder Date: Sun, 20 Oct 2024 21:19:03 -0700 Subject: [PATCH] Update pytorch.js (#1061) --- source/python.js | 31 +- source/pytorch-metadata.json | 29 ++ source/pytorch.js | 573 +++++++++++++++++++++++------------ test/models.json | 4 +- tools/pytorch_script.py | 20 ++ 5 files changed, 450 insertions(+), 207 deletions(-) diff --git a/source/python.js b/source/python.js index d496b2de3c..95b7696a5e 100644 --- a/source/python.js +++ b/source/python.js @@ -4377,6 +4377,10 @@ python.Execution = class { } throw new python.Error(`Schema '${op_name}.${overload_name}' not found.`); }); + this.registerFunction('torch._C._jit_get_schemas_for_operator', (op_name) => { + const registry = torch._C._get_registry(); + return registry.getAllOperatorsFor(op_name).map((op) => op.schema()); + }); this.registerFunction('torch._C._jit_get_operation', (op_name) => { const registry = torch._C._get_registry(); const sortedOps = registry.getAllOperatorsFor(op_name); @@ -6152,14 +6156,14 @@ python.Execution = class { }); this.registerType('torch.Type', class { - constructor(kind, name) { + constructor(kind, annotation_str) { this._kind = kind; - if (name) { - this._name = name; + if (annotation_str) { + this._annotation_str = annotation_str; } } - static get(kind, name) { - return new torch.Type(kind, name); + static get(kind, annotation_str) { + return new torch.Type(kind, annotation_str); } kind() { return this._kind; @@ -6171,8 +6175,8 @@ python.Execution = class { throw new python.Error(`Not implemented '${this.kind()}'.`); } str() { - if (this._kind === 'VarType' && this._name) { - return this._name; + if (this._kind === 'VarType' && this._annotation_str) { + return this._annotation_str; } else if (this._kind === 'ScalarTypeType') { return 'ScalarType'; } else if (this._kind === 'QSchemeType') { @@ -6722,6 +6726,7 @@ python.Execution = class { case 't': case 't1': case 't2': case 'tVal': return torch.Type.get('VarType', value); case 'Any': return torch.AnyType.get(); case 'AnyEnumType': return torch.Type.get('AnyEnumType'); + case 'Dimname': return torch.StringType.get(); case 'QScheme': return torch.Type.get('QSchemeType'); case 'Stream': return torch.StreamObjType.get(); case 'Storage': return torch.Type.get('Storage'); @@ -7036,7 +7041,7 @@ python.Execution = class { }); this.registerType('torch.FunctionSchema', class { constructor(name, overload_name, args, returns, is_vararg, is_varret) { - let index = name.indexOf('('); + const index = name.indexOf('('); if (index === -1) { this._name = name; this._overload_name = overload_name; @@ -7046,15 +7051,15 @@ python.Execution = class { this._is_varret = is_varret; } else { const value = name.substring(0, index).trim(); - this._buffer = name.substring(index, name.length); - index = value.indexOf('.'); - if (index === -1) { + const dot = value.indexOf('.'); + if (dot === -1) { this._name = value; this._overload_name = ''; } else { - this._name = value.substring(0, index); - this._overload_name = value.substring(index + 1, value.length); + this._name = value.substring(0, dot); + this._overload_name = value.substring(dot + 1, value.length); } + this._buffer = name.substring(index, name.length); } } static parse(schema) { diff --git a/source/pytorch-metadata.json b/source/pytorch-metadata.json index 0f6f6d009f..26e5d7f76a 100755 --- a/source/pytorch-metadata.json +++ b/source/pytorch-metadata.json @@ -738,6 +738,16 @@ { "type": "Tensor" } ] }, + { + "name": "aten::__is__(t1 self, t2 obj) -> bool", + "inputs": [ + { "name": "self", "type": "t1" }, + { "name": "obj", "type": "t2" } + ], + "outputs": [ + { "type": "boolean" } + ] + }, { "name": "aten::__isnot__(t1 self, t2 obj) -> bool", "inputs": [ @@ -5084,6 +5094,25 @@ { "type": "Tensor" } ] }, + { + "name": "aten::device(str a) -> Device", + "inputs": [ + { "name": "a", "type": "string" } + ], + "outputs": [ + { "type": "Device" } + ] + }, + { + "name": "aten::device.with_index(str type, int index) -> Device", + "inputs": [ + { "name": "type", "type": "string" }, + { "name": "index", "type": "int64" } + ], + "outputs": [ + { "type": "Device" } + ] + }, { "name": "aten::diag(Tensor self, int diagonal=0) -> Tensor", "inputs": [ diff --git a/source/pytorch.js b/source/pytorch.js index 7f707c057c..28baf928a5 100644 --- a/source/pytorch.js +++ b/source/pytorch.js @@ -371,33 +371,6 @@ pytorch.Node = class { type.name = name.indexOf('::') === -1 ? name : name.split('::').pop().split('.')[0]; return type; }; - const createAttribute = (metadata, name, value) => { - let visible = true; - let type = 'attribute'; - metadata = name === 'training' ? { type: 'boolean', visible: false } : metadata; - if (metadata) { - if (metadata.type) { - type = metadata.type; - } - if (metadata.visible === false) { - visible = false; - } else if (metadata.default !== undefined) { - if (Array.isArray(value)) { - if (Array.isArray(metadata.default)) { - visible = value.length !== metadata.default || !value.every((item, index) => item === metadata.default[index]); - } else { - visible = !value.every((item) => item === metadata.default); - } - } else { - visible = value !== metadata.default; - } - } - } - if (Array.isArray(value) && value.length > 0 && value.every((obj) => obj && obj.__class__ && obj.__class__.__module__ && obj.__class__.__module__.startsWith('torch.nn'))) { - value = '?'; - } - return new pytorch.Argument(name, value, type, visible); - }; let module = null; if (pytorch.Utility.isInstance(obj, 'torch.Node')) { const node = obj; @@ -485,13 +458,13 @@ pytorch.Node = class { type = type.getElementType(); } let argument = null; - if (arg && pytorch.Utility.isInstance(arg.real_type, 'torch.ClassType')) { + if (type && pytorch.Utility.isInstance(type, 'torch.ClassType')) { const obj = input.value; if (!array && initializers.has(obj)) { const node = new pytorch.Node(metadata, name, type.qualified_name(), obj, initializers, values); argument = new pytorch.Argument(name, node, 'object'); } else if (array && Array.isArray(obj) && obj.every((obj) => initializers.has(obj))) { - const node = obj.map((obj) => new pytorch.Node(metadata, name, type, obj, initializers, values)); + const node = obj.map((obj) => new pytorch.Node(metadata, name, type.qualified_name(), obj, initializers, values)); argument = new pytorch.Argument(name, node, 'object[]'); } else { const identifier = input.unique().toString(); @@ -797,6 +770,33 @@ pytorch.Node = class { const argument = new pytorch.Argument(name, node, 'object', visible); this.inputs.push(argument); } else { + const createAttribute = (metadata, name, value) => { + let visible = true; + let type = 'attribute'; + metadata = name === 'training' ? { type: 'boolean', visible: false } : metadata; + if (metadata) { + if (metadata.type) { + type = metadata.type; + } + if (metadata.visible === false) { + visible = false; + } else if (metadata.default !== undefined) { + if (Array.isArray(value)) { + if (Array.isArray(metadata.default)) { + visible = value.length !== metadata.default || !value.every((item, index) => item === metadata.default[index]); + } else { + visible = !value.every((item) => item === metadata.default); + } + } else { + visible = value !== metadata.default; + } + } + } + if (Array.isArray(value) && value.length > 0 && value.every((obj) => obj && obj.__class__ && obj.__class__.__module__ && obj.__class__.__module__.startsWith('torch.nn'))) { + value = '?'; + } + return new pytorch.Argument(name, value, type, visible); + }; const argument = createAttribute(metadata.attribute(type, name), name, value); this.inputs.push(argument); } @@ -2923,7 +2923,7 @@ pytorch.jit.Execution = class extends pytorch.Execution { } let type = parameter.type; let optional = false; - if (parameter.type.endsWith('?')) { + if (type.endsWith('?')) { type = parameter.type.substring(0, parameter.type.length - 1); optional = true; } @@ -3483,6 +3483,144 @@ pytorch.jit.Execution = class extends pytorch.Execution { return result[0]; } + isNativeType(obj, type) { + const torch = this.torch; + switch (type.str()) { + case 'Tensor': + return !Array.isArray(obj) && (pytorch.Utility.isTensor(obj) || obj === null || + (obj instanceof torch.Value && obj.type() instanceof torch.TensorType) || + (obj instanceof torch.Value && obj.type() instanceof torch.OptionalType && obj.type().getElementType() instanceof torch.TensorType)); + case 'Tensor[]': + return (Array.isArray(obj) && obj.length > 0 && obj.every((tensor) => pytorch.Utility.isTensor(tensor) || tensor === null || (tensor instanceof torch.Value && tensor.type() instanceof torch.TensorType))) || + (obj instanceof torch.Value && obj.type() instanceof torch.ListType && obj.type().getElementType() instanceof torch.TensorType); + case 'Scalar': + return (obj !== null && (obj !== Object(obj) || obj instanceof Number)) || + (pytorch.Utility.isTensor(obj) && Array.isArray(obj.size()) && obj.size().length === 0) || + (obj instanceof torch.Value && (obj.type() instanceof torch.IntType || obj.type() instanceof torch.FloatType || obj.type() instanceof torch.NumberType)); + case 'bool': + return obj === true || obj === false || (pytorch.Utility.isInstance(obj, 'torch.Value') && obj.type() instanceof torch.BoolType); + case 'bool[]': + if (Array.isArray(obj) && obj.every((item) => item === true || item === false)) { + return true; + } + if (obj instanceof torch.Value && obj.type() instanceof torch.ListType && obj.type().getElementType() instanceof torch.BoolType) { + return true; + } + return false; + case 'SymInt': + case 'int': + return Number.isInteger(obj) || typeof obj === 'bigint' || + (typeof obj === 'number' && isNaN(obj)) || (obj instanceof Number) || + (obj instanceof torch.Value && obj.type() instanceof torch.IntType) || + (obj instanceof torch.Value && obj.type() instanceof torch.OptionalType && obj.type().getElementType() instanceof torch.IntType); + case 'SymInt[]': + case 'SymInt[2]': + case 'SymInt[3]': + case 'SymInt[4]': + case 'SymInt[5]': + case 'SymInt[6]': + if (Array.isArray(obj) && obj.every((item) => this.isNativeType(item, torch.SymIntType.get()) || item === undefined || (item.__class__ === 'number' && isNaN(item)))) { + return true; + } + if (obj instanceof torch.Value && obj.type() instanceof torch.ListType && obj.type().getElementType() instanceof torch.IntType) { + return true; + } + return false; + case 'SymInt[1]': + return this.isNativeType(obj, torch.IntType.get()) || this.isNativeType(obj, torch.ListType.get(torch.IntType.get())); + case 'int[]': + case 'int[2]': + case 'int[3]': + return (Array.isArray(obj) && obj.every((item) => this.isNativeType(item, torch.IntType.get()) || item === undefined || (item.__class__ === 'number' && isNaN(item))) || + (obj instanceof torch.Value && obj.type() instanceof torch.ListType && obj.type().getElementType() instanceof torch.IntType)) || + (obj instanceof torch.Value && obj.type() instanceof torch.OptionalType && obj.type().getElementType() instanceof torch.ListType && obj.type().getElementType().getElementType() instanceof torch.IntType); + case 'int[1]': + case 'float': + return obj !== null && (typeof obj === 'number' || obj instanceof Number) || + (pytorch.Utility.isInstance(obj, 'torch.Value') && pytorch.Utility.isInstance(obj.type(), 'torch.FloatType')); + case 'float[]': + if (Array.isArray(obj) && obj.every((item) => (typeof item === 'number' || item instanceof Number) && !isNaN(item))) { + return true; + } + if (pytorch.Utility.isInstance(obj, 'torch.Value') && pytorch.Utility.isInstance(obj.type(), 'torch.ListType') && (pytorch.Utility.isInstance(obj.type().getElementType(), 'torch.IntType') || pytorch.Utility.isInstance(obj.type().getElementType(), 'torch.FloatType'))) { + return true; + } + return false; + case 'str': + return obj === null || typeof obj === 'string' || + (obj instanceof torch.Value && obj.type() instanceof torch.StringType); + case 'str[]': + return (Array.isArray(obj) && obj.every((item) => item === null || typeof item === 'string')) || + (obj instanceof torch.Value && obj.type() instanceof torch.ListType && obj.type().getElementType() instanceof torch.StringType); + case 'str[][]': + return Array.isArray(obj) && obj.every((item) => Array.isArray(item) && item.every((item) => typeof item === 'string')); + case 'Layout': + case 'ScalarType': + case 'MemoryFormat': + return Number.isInteger(obj) || obj === null || + (obj instanceof torch.Value && obj.type() instanceof torch.IntType) || + (obj instanceof torch.Value && obj.type() instanceof torch.OptionalType && obj.type().getElementType() instanceof torch.IntType); + case 'Dimname': + return obj === null || (typeof obj === 'string' || obj instanceof String); + case 'Dimname[]': + return Array.isArray(obj) && obj.every((item) => item === null || typeof item === 'string'); + case 'Device': + return obj === null || obj === Object(obj); + case 't[]': + return Array.isArray(obj) || + (obj instanceof torch.Value && obj.type() instanceof torch.ListType) || + (obj instanceof torch.Value && obj.type() instanceof torch.OptionalType && obj.type().getElementType() instanceof torch.ListType); + case 't': + return true; + case 'AnyEnumType': + return false; + case 'complex': + return obj instanceof torch.Value && obj.type() instanceof torch.ComplexType; + case 'Any[]': + if (Array.isArray(obj)) { + return true; + } + if (obj instanceof torch.Value && obj.type() instanceof torch.ListType) { + return true; + } + return false; + case 't1': + case 't2': + return true; + default: { + if (type instanceof torch.ClassType && + obj && obj.__class__ && obj.__class__.__module__ && obj.__class__.__name__) { + return type.qualified_name() === `${obj.__class__.__module__}.${obj.__class__.__name__}`; + } + if (type instanceof torch.TupleType) { + throw new pytorch.Error('Not implemented.'); + /* + if (obj instanceof torch.Value && obj.type() instanceof torch.ListType && obj.type().getElementType() instanceof torch.TupleType) { + const elements = obj.type().getElementType().elements(); + if (elements.length === 2) { + if (pytorch.Utility.toType(elements[0]) === match[1]) { + return true; + } + } + } + return false; + */ + } + if (type instanceof torch.DictType) { + if (obj instanceof torch.Value && obj.type() instanceof torch.DictType) { + if ((type.getKeyType().kind() === 'VarType' || type.getKeyType().str() === obj.type().getKeyType().str()) || + (type.getValueType().kind() === 'VarType' || type.getValueType().str() === obj.type().getValueType().str())) { + return true; + } + } + return false; + } + // throw new pytorch.Error(`Unknown type '${type}'.`); + return true; + } + } + } + isType(obj, type) { const torch = this.torch; switch (type) { @@ -3491,8 +3629,8 @@ pytorch.jit.Execution = class extends pytorch.Execution { (obj instanceof torch.Value && obj.type() instanceof torch.TensorType) || (obj instanceof torch.Value && obj.type() instanceof torch.OptionalType && obj.type().getElementType() instanceof torch.TensorType)); case 'Tensor[]': - return Array.isArray(obj) && obj.length > 0 && - obj.every((tensor) => pytorch.Utility.isTensor(tensor) || tensor === null || (tensor instanceof torch.Value && tensor.type() instanceof torch.TensorType)); + return (Array.isArray(obj) && obj.length > 0 && obj.every((tensor) => pytorch.Utility.isTensor(tensor) || tensor === null || (tensor instanceof torch.Value && tensor.type() instanceof torch.TensorType))) || + (obj instanceof torch.Value && obj.type() instanceof torch.ListType && obj.type().getElementType() instanceof torch.TensorType); case 'Scalar': return (obj !== null && (obj !== Object(obj) || obj instanceof Number)) || (pytorch.Utility.isTensor(obj) && Array.isArray(obj.size()) && obj.size().length === 0) || @@ -3550,7 +3688,8 @@ pytorch.jit.Execution = class extends pytorch.Execution { } return false; case 'string[]': - return Array.isArray(obj) && obj.every((item) => item === null || typeof item === 'string'); + return (Array.isArray(obj) && obj.every((item) => item === null || typeof item === 'string')) || + (obj instanceof torch.Value && obj.type() instanceof torch.ListType && obj.type().getElementType() instanceof torch.StringType); case 'string[][]': return Array.isArray(obj) && obj.every((item) => Array.isArray(item) && item.every((item) => typeof item === 'string')); case 'Layout': @@ -3627,27 +3766,27 @@ pytorch.jit.Execution = class extends pytorch.Execution { const torch = this.torch; const type = name ? `${moduleName}.${name}` : moduleName; // https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/native_functions.yaml - let overloads = null; + let op_name = null; if (type.startsWith('torch.')) { - overloads = this._types.get(`aten::${type.substring(6)}`); - } else if (type.startsWith('ops.prim.')) { - overloads = this._types.get(`prim::${type.substring(9)}`); + op_name = `aten::${type.substring(6)}`; + } else if (type.startsWith('ops.')) { + op_name = type.substring(4).replace('.', '::'); } else if (type === 'int') { - overloads = this._types.get(`aten::Int`); + op_name = 'aten::Int'; } else if (type === 'str') { - overloads = this._types.get(`aten::str`); + op_name = 'aten::str'; } else if (type === 'bool') { - overloads = this._types.get(`aten::Bool`); + op_name = 'aten::Bool'; } else if (type === 'float') { - overloads = this._types.get(`aten::Float`); + op_name = 'aten::Float'; } else if (type === 'complex') { - overloads = this._types.get(`aten::Complex`); - } else if (type.startsWith('ops.') && !type.startsWith('ops.prim.')) { - const path = type.split('.'); - if (path.length === 3) { - overloads = this._types.get(`${path[1]}::${path[2]}`); - } - if (!overloads) { + op_name = 'aten::Complex'; + } + this.native = false; + if (this.native && op_name) { + const overloads = torch._C._jit_get_schemas_for_operator(op_name); + /* + if (!overloads && type.startsWith('ops.') && !type.startsWith('ops.prim')) { const module = this.import(moduleName); if (!module || !module[name]) { const metadata = {}; @@ -3674,53 +3813,83 @@ pytorch.jit.Execution = class extends pytorch.Execution { overloads = [metadata]; } } - } - if (!overloads) { - if (type.startsWith('aten::') || type.startsWith('prim::')) { - throw new pytorch.Error(`Unknown function '${type}'.`); - } - return null; - } - overloads = Array.isArray(overloads) ? overloads : [overloads]; - const evalArgs = args.map((argument) => { - if (argument.type === '=' && argument.target && argument.target.type === 'id') { - argument = argument.expression; + */ + if (!overloads) { + if (type.startsWith('aten::') || type.startsWith('prim::')) { + throw new pytorch.Error(`Unknown function '${type}'.`); + } + return null; } - return this.expression(argument, context); - }); - const matches = []; - for (const schema of overloads) { - const copyArgs = Array.prototype.slice.call(args); - const copyEvalArgs = Array.prototype.slice.call(evalArgs); - const parameters = Array.prototype.slice.call(schema.inputs || []).concat(Array.prototype.slice.call(schema.attributes || [])); - let next = false; - let kwarg_only = false; - while (copyEvalArgs.length > 0) { - if (parameters.length <= 0) { - next = !schema.name.startsWith('_caffe2::') && !schema.is_vararg; - break; + const evalArgs = args.map((argument) => { + if (argument.type === '=' && argument.target && argument.target.type === 'id') { + argument = argument.expression; } - if (copyArgs.every((arg) => arg.type === '=' && arg.target && arg.target.type === 'id') && - parameters.every((parameter) => parameter.type !== 'Tensor' && parameter.type !== 'Tensor[]')) { - const map = new Map(parameters.map((parameter) => [parameter.name, parameter])); - while (copyArgs.length > 0) { - const argument = copyArgs.shift(); - const arg = copyEvalArgs.shift(); - const parameter = map.get(argument.target.value); - if (!parameter) { + return this.expression(argument, context); + }); + const matches = []; + for (const schema of overloads) { + const parameters = schema.arguments || []; + let next = false; + let kwarg_only = false; + let position = 0; + let index = 0; + while (position < evalArgs.length) { + if (index >= parameters.length) { + next = !schema.name.startsWith('_caffe2::') && !schema.is_vararg; + break; + } + const arg = parameters[index]; + if (arg.kwarg_only) { + break; + } + index++; + const value = evalArgs[position]; + let type = arg.real_type; + let optional = false; + if (type instanceof torch.OptionalType) { + type = type.getElementType(); + optional = true; + } + if (optional === true && + (type instanceof torch.FloatType || type instanceof torch.BoolType || type instanceof torch.IntType || type instanceof torch.ComplexType || type.kind() === 'ScalarTypeType' || type instanceof torch.DeviceObjType || type.kind() === 'LayoutKind') && + value instanceof torch.Value && value.type() instanceof torch.NoneType) { + position++; + } else if (!this.isNativeType(value, type) && value !== null) { + if (optional) { + continue; + } + next = true; + break; + } else if (args[position].type === '=') { + next = true; + break; + } else { + position++; + } + } + if (next) { + continue; + } + if (args.every((arg, index) => index < position || (arg.type === '=' && arg.target && arg.target.type === 'id'))) { + const params = new Map(parameters.slice(index).map((a) => [a.name, a])); + while (position < args.length) { + const value = evalArgs[position]; + const arg = params.get(args[position].target.value); + position++; + if (!arg) { next = true; break; } - if (parameter.kwarg_only) { + if (arg.kwarg_only) { kwarg_only = true; } - let type = parameter.type; + let type = arg.real_type; let optional = false; - if (parameter.type.endsWith('?')) { - type = parameter.type.substring(0, parameter.type.length - 1); + if (type instanceof torch.OptionalType) { + type = type.getElementType(); optional = true; } - if (!this.isType(arg, type)) { + if (!this.isNativeType(value, type)) { if (optional) { continue; } @@ -3728,134 +3897,154 @@ pytorch.jit.Execution = class extends pytorch.Execution { break; } } - continue; } if (next) { - break; + continue; } - const parameter = parameters.shift(); - if (parameter.kwarg_only) { - kwarg_only = true; + if (position < evalArgs.length && !schema.is_vararg && !schema.name.startsWith('_caffe2::')) { + continue; } - const [argument] = copyEvalArgs; - /* if (type === 'Tensor' || (type === 'Scalar' && pytorch.Utility.isTensor(argument))) { - if (Array.isArray(argument) || (!pytorch.Utility.isTensor(argument) && argument !== null && argument !== undefined)) { - if (optional) { - continue; - } - next = true; - } else { - copyArgs.shift(); - copyEvalArgs.shift(); + if (!kwarg_only && parameters.slice(index).some((arg) => !arg.has_default_value())) { + continue; + } + matches.push(schema); + } + if (matches.length > 1) { + const keys = new Map([['IntType', 1], ['FloatType', 2], ['TensorType', 3], ['NumberType', 4]]); + matches.sort((a, b) => { + let keyA = keys.get(a.arguments[0].real_type.kind()) || 5; + let keyB = keys.get(b.arguments[0].real_type.kind()) || 5; + if (keyA === keyB && a.arguments.length > 1 && b.arguments.length > 1) { + keyA = keys.get(a.arguments[1].real_type.kind()) || 5; + keyB = keys.get(b.arguments[1].real_type.kind()) || 5; + } + return keyA - keyB; + }); + } + if (matches.length === 0) { + throw new pytorch.Error(`Unknown function '${op_name}'.`); + } + // return [matches[0], evalArgs]; + } + let overloads = this._types.get(op_name); + if (!overloads && type.startsWith('ops.') && !type.startsWith('ops.prim')) { + const module = this.import(moduleName); + if (!module || !module[name]) { + const metadata = {}; + metadata.name = type; + metadata.inputs = []; + metadata.outputs = []; + for (let i = 0; i < args.length; i++) { + const input = {}; + let argument = args[i]; + input.name = i.toString(); + if (argument.type === '=' && argument.target && argument.target.type === 'id') { + input.name = this.expression(argument.target, context); + argument = argument.expression; } - } else */ - let type = parameter.type; + const obj = this.expression(argument, context); + input.type = pytorch.Utility.getType(obj); + metadata.inputs.push(input); + } + const count = context.target.length > 0 ? context.target[context.target.length - 1].length : 0; + for (let i = 0; i < count; i++) { + metadata.outputs.push({ name: '', type: '' }); + } + this._metadata.add(type, metadata); + overloads = [metadata]; + } + } + if (!overloads) { + if (type.startsWith('aten::') || type.startsWith('prim::')) { + throw new pytorch.Error(`Unknown function '${type}'.`); + } + return null; + } + const evalArgs = args.map((argument) => { + if (argument.type === '=' && argument.target && argument.target.type === 'id') { + argument = argument.expression; + } + return this.expression(argument, context); + }); + const matches = []; + for (const schema of overloads) { + const parameters = schema.inputs || []; + let next = false; + let kwarg_only = false; + let position = 0; + let index = 0; + while (position < evalArgs.length) { + if (index >= parameters.length) { + next = !schema.name.startsWith('_caffe2::') && !schema.is_vararg; + break; + } + const arg = parameters[index]; + if (arg.kwarg_only) { + break; + } + index++; + const value = evalArgs[position]; + let type = arg.type; let optional = false; - if (parameter.type.endsWith('?')) { - type = parameter.type.substring(0, parameter.type.length - 1); + if (type.endsWith('?')) { + type = arg.type.substring(0, arg.type.length - 1); optional = true; } if (optional === true && (type === 'float32' || type === 'boolean' || type === 'int64' || type === 'complex' || type === 'ScalarType' || type === 'Device' || type === 'Layout') && - argument instanceof torch.Value && argument.type() instanceof torch.NoneType) { - copyArgs.shift(); - copyEvalArgs.shift(); - } else if (type === 'Tensor[]') { - const [argument] = copyEvalArgs; - if ((argument instanceof torch.Value && pytorch.Utility.toType(argument.type()) === 'Tensor[]') || - (Array.isArray(argument) && argument.every((item) => pytorch.Utility.isTensor(item) || item === null || (item instanceof torch.Value && item.type() instanceof torch.TensorType)))) { - copyArgs.shift(); - copyEvalArgs.shift(); - } else { - if (optional) { - continue; - } - next = true; + value instanceof torch.Value && value.type() instanceof torch.NoneType) { + position++; + } else if (!this.isType(value, type) && value !== null) { + if (optional) { + continue; } - /* } else if (type === 't[]') { - if (!Array.isArray(argument) && (argument instanceof torch.Value === false || argument.type() instanceof torch.ListType === false)) { - if (optional) { - continue; - } - next = true; - } else { - copyArgs.shift(); - copyEvalArgs.shift(); - }*/ + next = true; + break; + } else if (args[position].type === '=' && args[position].target.value !== arg.name) { + next = true; + break; } else { - const [arg] = copyArgs; - if (!this.isType(argument, type) && argument !== null) { + position++; + } + } + if (next) { + continue; + } + if (args.every((arg, index) => index < position || (arg.type === '=' && arg.target && arg.target.type === 'id'))) { + const params = new Map(parameters.slice(index).map((a) => [a.name, a])); + while (position < args.length) { + const value = evalArgs[position]; + const arg = params.get(args[position].target.value); + position++; + if (!arg) { + next = true; + break; + } + if (arg.kwarg_only) { + kwarg_only = true; + } + let type = arg.type; + let optional = false; + if (type.endsWith('?')) { + type = arg.type.substring(0, arg.type.length - 1); + optional = true; + } + if (!this.isType(value, type)) { if (optional) { continue; } next = true; - } else if (arg.type === '=') { - next = true; - // throw new pytorch.Error('Expected named argument.'); - } else { - copyArgs.shift(); - copyEvalArgs.shift(); + break; } } - if (next) { - break; - } } if (next) { continue; } - if (!kwarg_only && parameters.some((parameter) => parameter.default === undefined)) { + if (position < evalArgs.length && !schema.is_vararg && !schema.name.startsWith('_caffe2::')) { continue; } - for (let i = 0; i < schema.outputs.length; i++) { - const parameter = schema.outputs[i]; - switch (parameter.type) { - case 'Scalar': - case 'Tensor': - case 'Tensor[]': - case 'float32': - case 'float32[]': - case 'int64': - case 'int64[]': - case 'Device': - case 'boolean': - case 'boolean[]': - case 't': - case 't[]': - case 'complex': - case 'complex[]': - case 'string': - case 'string[]': - case 'Dict(string, Tensor)': - case 'Dict(Tensor, t)': - case 'Dict(boolean, t)': - case 'Dict(complex, t)': - case 'Dict(float32, t)': - case 'Dict(int64, t)': - case 'Dict(string, t)': - case 'Dict(Tensor, tVal)': - case 'Dict(boolean, tVal)': - case 'Dict(complex, tVal)': - case 'Dict(float32, tVal)': - case 'Dict(int64, tVal)': - case 'Dict(string, tVal)': - case '(string, t)[]': - case 'Any': - break; - case '__torch__.torch.classes.xnnpack.LinearOpContext': - case '__torch__.torch.classes.xnnpack.Conv2dOpContext': - case '__torch__.torch.classes.xnnpack.TransposeConv2dOpContext': - case '__torch__.torch.classes.rnn.CellParamsBase': - case '__torch__.torch.classes.quantized.LinearPackedParamsBase': - case '__torch__.torch.classes.quantized.Conv2dPackedParamsBase': - case '__torch__.torch.classes.quantized.Conv3dPackedParamsBase': - break; - default: { - throw new pytorch.Error(`Unknown return type '${parameter.type}'.`); - } - } - } - if (next) { + if (!kwarg_only && parameters.slice(index).some((parameter) => parameter.default === undefined)) { continue; } matches.push(schema); @@ -3872,10 +4061,10 @@ pytorch.jit.Execution = class extends pytorch.Execution { return keyA - keyB; }); } - if (matches.length > 0) { - return [matches[0], evalArgs]; + if (matches.length === 0) { + throw new pytorch.Error(`Unknown function '${type}'.`); } - throw new pytorch.Error(`Unknown function '${type}'.`); + return [matches[0], evalArgs]; } block(statements, context) { diff --git a/test/models.json b/test/models.json index 48de312d80..ce6c1da9eb 100644 --- a/test/models.json +++ b/test/models.json @@ -6281,7 +6281,7 @@ "target": "TestSerialization.test_lstm.traced.pt", "source": "https://github.com/user-attachments/files/16121906/TestSerialization.test_lstm.traced.pt.zip[TestSerialization.test_lstm.traced.pt]", "format": "TorchScript v1.6", - "assert": "model.graphs[0].nodes.length == 9", + "assert": "model.graphs[0].nodes.length == 10", "link": "https://github.com/lutzroeder/netron/issues/1067" }, { @@ -6289,7 +6289,7 @@ "target": "TFModel_traced_eager_quant.pt", "source": "https://github.com/lutzroeder/netron/files/10867120/TFModel_traced_eager_quant.pt.zip[TFModel_traced_eager_quant.pt]", "format": "TorchScript v1.6", - "assert": "model.graphs[0].nodes.length == 46", + "assert": "model.graphs[0].nodes.length == 51", "link": "https://github.com/lutzroeder/netron/issues/1067" }, { diff --git a/tools/pytorch_script.py b/tools/pytorch_script.py index 981acc889f..2e3647cbd7 100644 --- a/tools/pytorch_script.py +++ b/tools/pytorch_script.py @@ -430,6 +430,15 @@ def _write_metadata(value): 'aten::__and__.int(int a, int b) -> int', 'aten::__and__.Scalar(Tensor self, Scalar other) -> Tensor', 'aten::__and__.Tensor(Tensor self, Tensor other) -> Tensor', + 'aten::__contains__.Tensor(Dict(Tensor, t) dict, Tensor key) -> bool', + 'aten::__contains__.bool(Dict(bool, t) dict, bool key) -> bool', + 'aten::__contains__.complex(Dict(complex, t) dict, complex key) -> bool', + 'aten::__contains__.float(Dict(float, t) dict, float key) -> bool', + 'aten::__contains__.float_list(float[] l, float item) -> bool', + 'aten::__contains__.int(Dict(int, t) dict, int key) -> bool', + 'aten::__contains__.int_list(int[] l, int item) -> bool', + 'aten::__contains__.str(Dict(str, t) dict, str key) -> bool', + 'aten::__contains__.str_list(str[] l, str item) -> bool', 'aten::__getitem__.Dict_bool(Dict(bool, t) self, bool key) -> t(*)', 'aten::__getitem__.Dict_complex(Dict(complex, t) self, complex key) -> t(*)', 'aten::__getitem__.Dict_float(Dict(float, t) self, float key) -> t(*)', @@ -438,6 +447,7 @@ def _write_metadata(value): 'aten::__getitem__.Dict_Tensor(Dict(Tensor, t) self, Tensor key) -> t(*)', 'aten::__getitem__.str(str s, int index) -> str', 'aten::__getitem__.t(t[](a) list, int idx) -> t(*)', + 'aten::__is__(t1 self, t2 obj) -> bool', 'aten::_native_batch_norm_legit(Tensor input, Tensor? weight, Tensor? bias, Tensor(a!) running_mean, Tensor(b!) running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor)', 'aten::_native_batch_norm_legit.no_stats(Tensor input, Tensor? weight, Tensor? bias, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor)', 'aten::_native_batch_norm_legit.no_stats_out(Tensor input, Tensor? weight, Tensor? bias, bool training, float momentum, float eps, *, Tensor(a!) out, Tensor(b!) save_mean, Tensor(c!) save_invstd) -> (Tensor(a!), Tensor(b!), Tensor(c!))', @@ -447,6 +457,13 @@ def _write_metadata(value): 'aten::_native_batch_norm_legit_no_training.out(Tensor input, Tensor? weight, Tensor? bias, Tensor running_mean, Tensor running_var, float momentum, float eps, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!))', 'aten::_native_multi_head_attention(Tensor query, Tensor key, Tensor value, int embed_dim, int num_head, Tensor qkv_weight, Tensor qkv_bias, Tensor proj_weight, Tensor proj_bias, Tensor? mask=None, bool need_weights=True, bool average_attn_weights=True, int? mask_type=None) -> (Tensor, Tensor)', 'aten::_native_multi_head_attention.out(Tensor query, Tensor key, Tensor value, int embed_dim, int num_head, Tensor qkv_weight, Tensor qkv_bias, Tensor proj_weight, Tensor proj_bias, Tensor? mask=None, bool need_weights=True, bool average_attn_weights=True, int? mask_type=None, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!))', + 'aten::_set_item.Tensor(Dict(Tensor, t)(a!) l, Tensor(b -> *) idx, t(c -> *) v) -> ()', + 'aten::_set_item.bool(Dict(bool, t)(a!) l, bool(b -> *) idx, t(c -> *) v) -> ()', + 'aten::_set_item.complex(Dict(complex, t)(a!) l, complex(b -> *) idx, t(c -> *) v) -> ()', + 'aten::_set_item.float(Dict(float, t)(a!) l, float(b -> *) idx, t(c -> *) v) -> ()', + 'aten::_set_item.int(Dict(int, t)(a!) l, int(b -> *) idx, t(c -> *) v) -> ()', + 'aten::_set_item.str(Dict(str, t)(a!) l, str(b -> *) idx, t(c -> *) v) -> ()', + 'aten::_set_item.t(t[](a!) l, int idx, t(b -> *) el) -> t[](a!)', 'aten::add(Scalar a, Scalar b) -> Scalar', 'aten::add.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor', 'aten::add.Scalar_out(Tensor self, Scalar other, Scalar alpha=1, *, Tensor(a!) out) -> Tensor(a!)', @@ -554,6 +571,8 @@ def _write_metadata(value): 'aten::Complex.Tensor_int(Tensor x, int y) -> complex', 'aten::Complex.Tensor_Tensor(Tensor a, Tensor b) -> complex', 'aten::ComplexImplicit(Tensor a) -> complex', + 'aten::device(str a) -> Device', + 'aten::device.with_index(str type, int index) -> Device', 'aten::dict.bool((bool, tVal)[] inputs) -> Dict(bool, tVal)', 'aten::dict.complex((complex, tVal)[] inputs) -> Dict(complex, tVal)', 'aten::dict.Dict_bool(Dict(bool, t)(a) self) -> Dict(bool, t)', @@ -893,6 +912,7 @@ def _write_metadata(value): 'aten::values.str(Dict(str, t) self) -> t[](*)', 'aten::values.Tensor(Dict(Tensor, t) self) -> t[](*)', 'aten::values(Tensor(a) self) -> Tensor(a)', + 'aten::warn(str message, int stacklevel=2) -> ()', 'prim::abs.complex(complex a) -> float', 'prim::abs.float(float a) -> float', 'prim::abs.int(int a) -> int',