diff --git a/source/python.js b/source/python.js index d496b2de3c..3485cb71cb 100644 --- a/source/python.js +++ b/source/python.js @@ -4377,6 +4377,10 @@ python.Execution = class { } throw new python.Error(`Schema '${op_name}.${overload_name}' not found.`); }); + this.registerFunction('torch._C._jit_get_schemas_for_operator', (op_name) => { + const registry = torch._C._get_registry(); + return registry.getAllOperatorsFor(op_name).map((op) => op.schema()); + }); this.registerFunction('torch._C._jit_get_operation', (op_name) => { const registry = torch._C._get_registry(); const sortedOps = registry.getAllOperatorsFor(op_name); @@ -7036,7 +7040,7 @@ python.Execution = class { }); this.registerType('torch.FunctionSchema', class { constructor(name, overload_name, args, returns, is_vararg, is_varret) { - let index = name.indexOf('('); + const index = name.indexOf('('); if (index === -1) { this._name = name; this._overload_name = overload_name; @@ -7046,15 +7050,15 @@ python.Execution = class { this._is_varret = is_varret; } else { const value = name.substring(0, index).trim(); - this._buffer = name.substring(index, name.length); - index = value.indexOf('.'); - if (index === -1) { + const dot = value.indexOf('.'); + if (dot === -1) { this._name = value; this._overload_name = ''; } else { - this._name = value.substring(0, index); - this._overload_name = value.substring(index + 1, value.length); + this._name = value.substring(0, dot); + this._overload_name = value.substring(dot + 1, value.length); } + this._buffer = name.substring(index, name.length); } } static parse(schema) { diff --git a/source/pytorch.js b/source/pytorch.js index 7f707c057c..e730f3ca26 100644 --- a/source/pytorch.js +++ b/source/pytorch.js @@ -3491,8 +3491,8 @@ pytorch.jit.Execution = class extends pytorch.Execution { (obj instanceof torch.Value && obj.type() instanceof torch.TensorType) || (obj instanceof torch.Value && obj.type() instanceof torch.OptionalType && obj.type().getElementType() instanceof torch.TensorType)); case 'Tensor[]': - return Array.isArray(obj) && obj.length > 0 && - obj.every((tensor) => pytorch.Utility.isTensor(tensor) || tensor === null || (tensor instanceof torch.Value && tensor.type() instanceof torch.TensorType)); + return (Array.isArray(obj) && obj.length > 0 && obj.every((tensor) => pytorch.Utility.isTensor(tensor) || tensor === null || (tensor instanceof torch.Value && tensor.type() instanceof torch.TensorType))) || + (obj instanceof torch.Value && obj.type() instanceof torch.ListType && obj.type().getElementType() instanceof torch.TensorType); case 'Scalar': return (obj !== null && (obj !== Object(obj) || obj instanceof Number)) || (pytorch.Utility.isTensor(obj) && Array.isArray(obj.size()) && obj.size().length === 0) || @@ -3627,52 +3627,52 @@ pytorch.jit.Execution = class extends pytorch.Execution { const torch = this.torch; const type = name ? `${moduleName}.${name}` : moduleName; // https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/native_functions.yaml - let overloads = null; + let op_name = null; if (type.startsWith('torch.')) { - overloads = this._types.get(`aten::${type.substring(6)}`); - } else if (type.startsWith('ops.prim.')) { - overloads = this._types.get(`prim::${type.substring(9)}`); + op_name = `aten::${type.substring(6)}`; + } else if (type.startsWith('ops.')) { + op_name = type.substring(4).replace('.', '::'); } else if (type === 'int') { - overloads = this._types.get(`aten::Int`); + op_name = 'aten::Int'; } else if (type === 'str') { - overloads = this._types.get(`aten::str`); + op_name = 'aten::str'; } else if (type === 'bool') { - overloads = this._types.get(`aten::Bool`); + op_name = 'aten::Bool'; } else if (type === 'float') { - overloads = this._types.get(`aten::Float`); + op_name = 'aten::Float'; } else if (type === 'complex') { - overloads = this._types.get(`aten::Complex`); - } else if (type.startsWith('ops.') && !type.startsWith('ops.prim.')) { - const path = type.split('.'); - if (path.length === 3) { - overloads = this._types.get(`${path[1]}::${path[2]}`); - } - if (!overloads) { - const module = this.import(moduleName); - if (!module || !module[name]) { - const metadata = {}; - metadata.name = type; - metadata.inputs = []; - metadata.outputs = []; - for (let i = 0; i < args.length; i++) { - const input = {}; - let argument = args[i]; - input.name = i.toString(); - if (argument.type === '=' && argument.target && argument.target.type === 'id') { - input.name = this.expression(argument.target, context); - argument = argument.expression; - } - const obj = this.expression(argument, context); - input.type = pytorch.Utility.getType(obj); - metadata.inputs.push(input); - } - const count = context.target.length > 0 ? context.target[context.target.length - 1].length : 0; - for (let i = 0; i < count; i++) { - metadata.outputs.push({ name: '', type: '' }); + op_name = 'aten::Complex'; + } + this.native = true; + if (this.native && op_name) { + // const overloads = torch._C._jit_get_schemas_for_operator(op_name); + } + let overloads = this._types.get(op_name); + if (!overloads && type.startsWith('ops.') && !type.startsWith('ops.prim')) { + const module = this.import(moduleName); + if (!module || !module[name]) { + const metadata = {}; + metadata.name = type; + metadata.inputs = []; + metadata.outputs = []; + for (let i = 0; i < args.length; i++) { + const input = {}; + let argument = args[i]; + input.name = i.toString(); + if (argument.type === '=' && argument.target && argument.target.type === 'id') { + input.name = this.expression(argument.target, context); + argument = argument.expression; } - this._metadata.add(type, metadata); - overloads = [metadata]; + const obj = this.expression(argument, context); + input.type = pytorch.Utility.getType(obj); + metadata.inputs.push(input); } + const count = context.target.length > 0 ? context.target[context.target.length - 1].length : 0; + for (let i = 0; i < count; i++) { + metadata.outputs.push({ name: '', type: '' }); + } + this._metadata.add(type, metadata); + overloads = [metadata]; } } if (!overloads) { @@ -3681,7 +3681,6 @@ pytorch.jit.Execution = class extends pytorch.Execution { } return null; } - overloads = Array.isArray(overloads) ? overloads : [overloads]; const evalArgs = args.map((argument) => { if (argument.type === '=' && argument.target && argument.target.type === 'id') { argument = argument.expression; @@ -3690,172 +3689,81 @@ pytorch.jit.Execution = class extends pytorch.Execution { }); const matches = []; for (const schema of overloads) { - const copyArgs = Array.prototype.slice.call(args); - const copyEvalArgs = Array.prototype.slice.call(evalArgs); - const parameters = Array.prototype.slice.call(schema.inputs || []).concat(Array.prototype.slice.call(schema.attributes || [])); + const parameters = Array.prototype.slice.call(schema.inputs || []); let next = false; let kwarg_only = false; - while (copyEvalArgs.length > 0) { + let position = 0; + while (position < evalArgs.length) { if (parameters.length <= 0) { next = !schema.name.startsWith('_caffe2::') && !schema.is_vararg; break; } - if (copyArgs.every((arg) => arg.type === '=' && arg.target && arg.target.type === 'id') && - parameters.every((parameter) => parameter.type !== 'Tensor' && parameter.type !== 'Tensor[]')) { - const map = new Map(parameters.map((parameter) => [parameter.name, parameter])); - while (copyArgs.length > 0) { - const argument = copyArgs.shift(); - const arg = copyEvalArgs.shift(); - const parameter = map.get(argument.target.value); - if (!parameter) { - next = true; - break; - } - if (parameter.kwarg_only) { - kwarg_only = true; - } - let type = parameter.type; - let optional = false; - if (parameter.type.endsWith('?')) { - type = parameter.type.substring(0, parameter.type.length - 1); - optional = true; - } - if (!this.isType(arg, type)) { - if (optional) { - continue; - } - next = true; - break; - } - } - continue; - } - if (next) { + if (parameters[0].kwarg_only) { break; } - const parameter = parameters.shift(); - if (parameter.kwarg_only) { - kwarg_only = true; - } - const [argument] = copyEvalArgs; - /* if (type === 'Tensor' || (type === 'Scalar' && pytorch.Utility.isTensor(argument))) { - if (Array.isArray(argument) || (!pytorch.Utility.isTensor(argument) && argument !== null && argument !== undefined)) { - if (optional) { - continue; - } - next = true; - } else { - copyArgs.shift(); - copyEvalArgs.shift(); - } - } else */ - let type = parameter.type; + const arg = parameters.shift(); + const value = evalArgs[position]; + let type = arg.type; let optional = false; - if (parameter.type.endsWith('?')) { - type = parameter.type.substring(0, parameter.type.length - 1); + if (arg.type.endsWith('?')) { + type = arg.type.substring(0, arg.type.length - 1); optional = true; } if (optional === true && (type === 'float32' || type === 'boolean' || type === 'int64' || type === 'complex' || type === 'ScalarType' || type === 'Device' || type === 'Layout') && - argument instanceof torch.Value && argument.type() instanceof torch.NoneType) { - copyArgs.shift(); - copyEvalArgs.shift(); - } else if (type === 'Tensor[]') { - const [argument] = copyEvalArgs; - if ((argument instanceof torch.Value && pytorch.Utility.toType(argument.type()) === 'Tensor[]') || - (Array.isArray(argument) && argument.every((item) => pytorch.Utility.isTensor(item) || item === null || (item instanceof torch.Value && item.type() instanceof torch.TensorType)))) { - copyArgs.shift(); - copyEvalArgs.shift(); - } else { - if (optional) { - continue; - } - next = true; + value instanceof torch.Value && value.type() instanceof torch.NoneType) { + position++; + } else if (!this.isType(value, type) && value !== null) { + if (optional) { + continue; } - /* } else if (type === 't[]') { - if (!Array.isArray(argument) && (argument instanceof torch.Value === false || argument.type() instanceof torch.ListType === false)) { - if (optional) { - continue; - } - next = true; - } else { - copyArgs.shift(); - copyEvalArgs.shift(); - }*/ + next = true; + break; + } else if (args[position].type === '=') { + next = true; + break; } else { - const [arg] = copyArgs; - if (!this.isType(argument, type) && argument !== null) { + position++; + } + } + if (next) { + continue; + } + if (args.every((arg, index) => index < position || (arg.type === '=' && arg.target && arg.target.type === 'id'))) { + const map = new Map(parameters.map((parameter) => [parameter.name, parameter])); + while (position < args.length) { + const value = evalArgs[position]; + const arg = map.get(args[position].target.value); + position++; + if (!arg) { + next = true; + break; + } + if (arg.kwarg_only) { + kwarg_only = true; + } + let type = arg.type; + let optional = false; + if (arg.type.endsWith('?')) { + type = arg.type.substring(0, arg.type.length - 1); + optional = true; + } + if (!this.isType(value, type)) { if (optional) { continue; } next = true; - } else if (arg.type === '=') { - next = true; - // throw new pytorch.Error('Expected named argument.'); - } else { - copyArgs.shift(); - copyEvalArgs.shift(); + break; } } - if (next) { - break; - } } if (next) { continue; } - if (!kwarg_only && parameters.some((parameter) => parameter.default === undefined)) { + if (position < evalArgs.length && !schema.is_vararg && !schema.name.startsWith('_caffe2::')) { continue; } - for (let i = 0; i < schema.outputs.length; i++) { - const parameter = schema.outputs[i]; - switch (parameter.type) { - case 'Scalar': - case 'Tensor': - case 'Tensor[]': - case 'float32': - case 'float32[]': - case 'int64': - case 'int64[]': - case 'Device': - case 'boolean': - case 'boolean[]': - case 't': - case 't[]': - case 'complex': - case 'complex[]': - case 'string': - case 'string[]': - case 'Dict(string, Tensor)': - case 'Dict(Tensor, t)': - case 'Dict(boolean, t)': - case 'Dict(complex, t)': - case 'Dict(float32, t)': - case 'Dict(int64, t)': - case 'Dict(string, t)': - case 'Dict(Tensor, tVal)': - case 'Dict(boolean, tVal)': - case 'Dict(complex, tVal)': - case 'Dict(float32, tVal)': - case 'Dict(int64, tVal)': - case 'Dict(string, tVal)': - case '(string, t)[]': - case 'Any': - break; - case '__torch__.torch.classes.xnnpack.LinearOpContext': - case '__torch__.torch.classes.xnnpack.Conv2dOpContext': - case '__torch__.torch.classes.xnnpack.TransposeConv2dOpContext': - case '__torch__.torch.classes.rnn.CellParamsBase': - case '__torch__.torch.classes.quantized.LinearPackedParamsBase': - case '__torch__.torch.classes.quantized.Conv2dPackedParamsBase': - case '__torch__.torch.classes.quantized.Conv3dPackedParamsBase': - break; - default: { - throw new pytorch.Error(`Unknown return type '${parameter.type}'.`); - } - } - } - if (next) { + if (!kwarg_only && parameters.some((parameter) => parameter.default === undefined)) { continue; } matches.push(schema);