From 4073049abd39d5722a71de6f81473477351bea7b Mon Sep 17 00:00:00 2001 From: Lutz Roeder Date: Sun, 8 Dec 2024 23:31:08 -0500 Subject: [PATCH] Update pytorch.js (#1061) --- source/python.js | 520 ++++++++++++++++++++++++++++++++++++++++------ source/pytorch.js | 63 +++++- 2 files changed, 518 insertions(+), 65 deletions(-) diff --git a/source/python.js b/source/python.js index 9c5493b7d8..fb6bc2aeb8 100644 --- a/source/python.js +++ b/source/python.js @@ -485,6 +485,26 @@ python.Execution = class { this.type_params = type_params; } }); + this.registerType('ast.arguments', class extends ast.AST { + constructor(posonlyargs, args, vararg, kwonlyargs, kw_defaults, kwarg, defaults) { + super(); + this.posonlyargs = posonlyargs; + this.args = args; + this.vararg = vararg; + this.kwonlyargs = kwonlyargs; + this.kw_defaults = kw_defaults; + this.kwarg = kwarg; + this.defaults = defaults; + } + }); + this.registerType('ast.arg', class extends ast.AST { + constructor(arg, annotation, type_comment) { + super(); + this.arg = arg; + this.annotation = annotation; + this.type_comment = type_comment; + } + }); this.registerType('ast.Import', class extends ast.stmt { constructor(names) { super(); @@ -1463,31 +1483,6 @@ python.Execution = class { } return null; } - _parameter(terminal) { - const node = this._node('parameter'); - if (this._tokenizer.eat('/')) { - node.name = '/'; - return node; - } - if (this._tokenizer.eat('**')) { - node.parameterType = '**'; - } - if (this._tokenizer.eat('*')) { - node.parameterType = '*'; - } - const name = this._name(); - if (name !== null) { - node.name = name.id; - if (terminal !== ':' && this._tokenizer.eat(':')) { - node.parameterType = this._type(); - } - if (this._tokenizer.eat('=')) { - node.initializer = this._expression(); - } - return node; - } - return null; - } _parameters(terminal) { const list = []; while (!this._tokenizer.eat(terminal)) { @@ -1495,7 +1490,29 @@ python.Execution = class { if (this._tokenizer.eat('(')) { list.push(this._parameters(')')); } else { - list.push(this._parameter(terminal)); + const node = this._node('parameter'); + if (this._tokenizer.eat('/')) { + node.name = '/'; + } else { + if (this._tokenizer.eat('**')) { + node.annotation = '**'; + } + if (this._tokenizer.eat('*')) { + node.annotation = '*'; + } + const name = this._name(); + if (name === null) { + throw new python.Error(`Expected parameter ${this._tokenizer.location()}`); + } + node.name = name.id; + if (terminal !== ':' && this._tokenizer.eat(':')) { + node.annotation = this._type(); + } + if (this._tokenizer.eat('=')) { + node.initializer = this._expression(); + } + } + list.push(node); } this._tokenizer.eat('\n'); if (!this._tokenizer.eat(',')) { @@ -4970,6 +4987,122 @@ python.Execution = class { } return null; }); + this.registerType('torch.jit.MatchedSchema', class { + constructor(inputs, return_types, return_field_names, schema_name) { + this.inputs = inputs; + this.return_types = return_types; + this.register_field_names = return_field_names; + this.schema_name = schema_name; + } + }); + this.registerType('torch.jit.Self', class { + }); + this.registerType('torch.jit.SimpleSelf', class extends torch.jit.Self { + constructor(classType) { + super(); + this._classType = classType; + } + getClassType() { + return this._classType; + } + }); + this.registerType('torch.jit.Function', class { + isGraphFunction() { + return false; + } + name() { + return this.qualname().name(); + } + }); + this.registerType('torch.jit.BuiltinOpFunction', class extends torch.jit.Function { + constructor(qualname, schema) { + super(); + this._name = qualname; + this._schema = schema; + } + qualname() { + return this._name; + } + getSchema() { + return this._schema; + } + }); + this.registerType('torch.jit.GraphFunction', class extends torch.jit.Function { + constructor(name, graph, function_creator, executor_execution_mode) { + super(); + this._name = name; + this._graph = graph; + this._executor_execution_mode = executor_execution_mode; + this._function_creator = function_creator; + } + isGraphFunction() { + return true; + } + qualname() { + return this._name; + } + graph() { + return this._graph; + } + placeholderCreator() { + throw new python.Error('Recursive method call.'); + } + ensure_defined() { + if (this._function_creator) { + const creator = this._function_creator; + this._function_creator = this.placeholderCreator; + creator(this); + this._function_creator = null; + } + this.check_single_output(); + } + check_single_output() { + // if (this.graph().outputs().length !== 1) { + // throw new python.Error('Graph must have a single output.'); + // } + } + getSchema() { + this._schema = this._schema || this.defaultSchemaFor(this); + return this._schema; + } + setSchema(schema) { + this._schema = schema; + } + num_inputs() { + return this.graph().inputs().length; + } + unshapedType(type) { + if (type.isSubtypeOf(torch.TensorType.get())) { + return torch.TensorType.get(); + } + throw new python.Error('Not implemented.'); + /* + at::ArrayRef contained = type->containedTypes(); + if (contained.empty()) { + return type; + } + return type->withContained(fmap(type->containedTypes(), unshapedType)); + */ + } + defaultSchemaFor(fn) { + const args = []; + const returns = []; + const g = fn.graph(); + const num_inputs = fn.num_inputs(); + for (let i = 0; i < num_inputs; i++) { + const v = g.inputs()[i]; + const name = v.hasDebugName() ? v.debugNameBase() : `argument_${i}`; + const argument = new torch.Argument(name, this.unshapedType(g.inputs()[i].type())); + args.push(argument); + } + const num_outputs = g.outputs().length; + for (let i = 0; i < num_outputs; i++) { + const argument = new torch.Argument('', this.unshapedType(g.outputs()[i].type())); + returns.push(argument); + } + return new torch.FunctionSchema(fn.name(), '', args, returns); + } + }); this.registerType('torch.ao.quantization.fake_quantize.FakeQuantize', class {}); this.registerType('torch.ao.quantization.fake_quantize.FusedMovingAvgObsFakeQuantize', class {}); this.registerType('torch.ao.quantization.observer._PartialWrapper', class {}); @@ -6800,6 +6933,7 @@ python.Execution = class { this._attributes = new Map(); this._methods = new Map(); this._staticmethods = new Map(); + this._constants = new Map(); } qualified_name() { return this.annotation_str; @@ -6811,18 +6945,25 @@ python.Execution = class { return this._is_module; } addMethod(func) { - this._methods.set(func.name, func); + this._methods.set(func.name(), func); } findMethod(name) { return this._methods.get(name); } + getMethod(name) { + const method = this.findMethod(name); + if (!method) { + throw new python.Error(`Method '${name}' not found on class '${this.str()}.`); + } + return method; + } addStaticMethod(func) { this._staticmethods.set(func.name, func); } findStaticMethod(name) { return this._staticmethods.get(name); } - addAttribute(name, type) { + addAttribute(name, type /*, is_parameter, is_buffer */) { this._attributes.set(name, type); } findAttribute(name) { @@ -6834,6 +6975,10 @@ python.Execution = class { hasConstant(/* name */) { } methods() { + throw new python.Error('Not implemented.'); + } + addConstant(name, value) { + this._constants.set(name, value); } str() { return this.qualified_name(); @@ -7336,7 +7481,7 @@ python.Execution = class { } } }); - this.registerType('torch._C.SchemaTypeParser', class { + this.registerType('torch.jit.SchemaTypeParser', class { constructor(L) { this.L = L; } @@ -7539,7 +7684,7 @@ python.Execution = class { return this.default_value !== undefined; } static parse(L, is_return, kwarg_only) { - const type_parser = new torch._C.SchemaTypeParser(L); + const type_parser = new torch.jit.SchemaTypeParser(L); let [fake_type, real_type, alias_info] = type_parser.parseFakeAndRealType(); L.whitespace(0); let N = null; @@ -7847,6 +7992,47 @@ python.Execution = class { return list.join(''); } }); + this.registerType('torch.jit.ScriptTypeParser', class { + constructor(resolver) { + this._resolver = resolver; + } + parseSchemaFromDef(def, skip_self) { + const name = def.name; + const args = this.parseArgsFromDecl(def, skip_self); + const returns = this.parseReturnFromDecl(def); + return new torch.FunctionSchema(name, '', args, returns, false, false); + } + parseArgsFromDecl(decl, skip_self) { + const retval = []; + const params = skip_self ? decl.args.slice(1) : decl.args.slice(); + for (const decl_arg of params) { + const N = null; + const default_value = null; + const type = decl_arg.annotation ? this.parseTypeFromExpr(decl_arg.annotation) : null; + const arg = new torch.Argument(decl_arg.name, type, type, N, default_value, /* decl_arg.kwarg_only() */ false, null); + retval.push(arg); + } + } + parseReturnFromDecl(decl) { + /* + if (!decl.return_type().present()) + return {}; + */ + // if (parseBroadcastList(decl.return_type().get())) + // throw ErrorReport(decl.return_type().range()) << "Broadcastable lists cannot appear as a return type"; + const parsed_type = this.parseTypeFromExpr(decl.returns); + return [new torch.Argument('', parsed_type, parsed_type, null, null, false)]; + } + parseTypeFromExpr(expr) { + if (expr instanceof ast.Name) { + const type = this._resolver.resolveType(expr.id); + if (type) { + return type; + } + } + return this._resolver._cu.execution.type(expr); + } + }); this.registerType('torch._ops.OpOverload', class extends torch._ops.OperatorBase { constructor(overloadpacket, op, op_dk, schema, tags) { super(); @@ -8080,6 +8266,9 @@ python.Execution = class { return_node() { return this._block.return_node(); } + block() { + return this._block; + } addInput(name) { return this._block.addInput(name); } @@ -8118,8 +8307,16 @@ python.Execution = class { } return n.output(); } + insertMethodCall(method_name, matched) { + const n = this.create('prim::CallMethod', matched.inputs); + this.insertNode(n); + n.s_('name', method_name); + n.output().setType(matched.return_types[0]); + return n; + } insertUncheckedCast(v, type) { - const n = this.insertNode(this.create('prim::unchecked_cast', [v])); + const n = this.create('prim::unchecked_cast', [v]); + this.insertNode(n); n.output().setType(type); return n.output(); } @@ -8182,6 +8379,12 @@ python.Execution = class { this._all_blocks.splice(index, 1); } } + set_op_version(version) { + this._op_version = version; + } + get_op_version() { + return this._op_version; + } }); this.registerType('torch.Block', class { constructor(graph) { @@ -8535,7 +8738,13 @@ python.Execution = class { } }); this.registerType('torch.jit.QualifiedName', class { - constructor(name) { + constructor(...args) { + let name = null; + if (args.length === 1 && typeof args[0] === 'string') { + [name] = args; + } else { + name = `${args[0].qualifiedName()}.${args[1]}`; + } const index = name.lastIndexOf('.'); this._qualifiedName = name; this._prefix = index === -1 ? '' : name.substring(0, index); @@ -8601,20 +8810,108 @@ python.Execution = class { throw new python.Error('TorchScript does not support class inheritance.'); } } - importClass(qualified_name, class_def, is_module) { - if (qualified_name.prefix().startsWith('__torch__.torch.classes')) { + importClass(qualified_classname, class_def, is_module) { + if (qualified_classname.prefix().startsWith('__torch__.torch.classes')) { return; } - const class_type = new torch.ClassType(qualified_name.qualifiedName(), this._cu, is_module); - for (const entry of class_def.body) { - if (entry instanceof ast.AnnAssign) { - const target = this._cu.execution.identifier(entry.target); - const annotation = this._cu.execution.type(entry.annotation, null); - class_type.addAttribute(target, annotation); + const parameter_names = new Set(); + const buffer_names = new Set(); + const methods = []; + const method_resolvers = []; + const attributes = []; + const constants = []; + const pre_hook_names = new Set(); + const pre_hook_def_map = new Map(); + const hook_names = new Set(); + const hook_def_map = new Map(); + const class_type = new torch.ClassType(qualified_classname.qualifiedName(), this._cu, is_module); + for (const stmt of class_def.body) { + if (stmt instanceof ast.Assign || stmt instanceof ast.AnnAssign) { + let target = null; + let annotation = null; + let value = null; + if (stmt instanceof ast.Assign) { + target = stmt.targets; + value = stmt.value; + } else { + target = stmt.target; + annotation = stmt.annotation; + value = stmt.value; + } + if (target instanceof ast.Name) { + const name = this._cu.execution.identifier(target); + switch (name) { + case '__annotations__': { + continue; + } + case '__parameters__': { + for (const elt of value.elts) { + parameter_names.add(elt.value); + } + break; + } + case '__buffers__': { + for (const elt of value.elts) { + buffer_names.add(elt.value); + } + break; + } + case '__forward_pre_hooks__': { + for (const elt of value.elts) { + pre_hook_names.add(elt.value); + } + break; + } + case '__forward_hooks__': { + for (const elt of value.elts) { + hook_names.add(elt.value); + } + break; + } + default: { + if (value) { + constants.push({ name, value, annotation }); + } else { + attributes.push({ name, value, annotation }); + } + break; + } + } + } else if (target instanceof ast.Subscript) { + // not implemented + continue; + } else { + throw new python.Error('Unexpected statement kind in module metadata.'); + } + } else if (stmt instanceof ast.FunctionDef) { + const def = stmt; + const def_name = def.name; + if (pre_hook_names.has(def_name)) { + pre_hook_def_map.set(def_name, def); + } else if (hook_names.has(def_name)) { + hook_def_map.set(def_name, def); + } else { + methods.push(def); + method_resolvers.push(this); + } + } else { + throw new python.Error('Unexpected statement kind in class body.'); } } + for (const assign of attributes) { + const name = assign.name; + const annotation = this._cu.execution.type(assign.annotation, null); + const is_parameter = parameter_names.has(name); + const is_buffer = buffer_names.has(name); + class_type.addAttribute(name, annotation, is_parameter, is_buffer); + } + for (const constant of constants) { + class_type.addConstant(constant.name, constant.value); + } // debugger; this._cu.register_type(class_type); + const self = new torch.jit.SimpleSelf(class_type); + this._cu.define(qualified_classname, [], [], methods, method_resolvers, self, false, this._version); } importNamedTuple(qualified_name, named_tuple_def) { const field_names = []; @@ -8685,7 +8982,12 @@ python.Execution = class { execution.builtins.CONSTANTS = {}; execution._resolver = this._source_importer; const known_types = [ - { name: '__torch__.torch.classes._nnapi.Compilation', methods: ['init(Tensor serialized_model_tensor, Tensor[] parameter_buffers) -> ()'] }, + { name: '__torch__.torch.classes._nnapi.Compilation', methods: [ + '__init__(__torch__.torch.classes._nnapi.Compilation self) -> NoneType', + 'init(__torch__.torch.classes._nnapi.Compilation self, Tensor serialized_model_tensor, Tensor[] parameter_buffers) -> NoneType', + 'init2(__torch__.torch.classes._nnapi.Compilation self, Tensor serialized_model_tensor, Tensor[] parameter_buffers, int compilation_preference, bool relax_f32_to_f16) -> NoneType', + 'run(__torch__.torch.classes._nnapi.Compilation self, Tensor[] inputs, Tensor[] outputs) -> NoneType' + ] }, { name: '__torch__.torch.classes.quantized.Conv2dPackedParamsBase' }, { name: '__torch__.torch.classes.quantized.Conv3dPackedParamsBase' }, { name: '__torch__.torch.classes.quantized.LinearPackedParamsBase' }, @@ -8695,9 +8997,13 @@ python.Execution = class { { name: '__torch__.torch.classes.xnnpack.TransposeConv2dOpContext' }, ]; for (const known_type of known_types) { + const prefix = new torch.jit.QualifiedName(known_type.name); const type = new torch.ClassType(known_type.name, this._compilation_unit, false); for (const known_method of known_type.methods || []) { - type.addMethod(new torch.FunctionSchema(known_method)); + const schema = new torch.FunctionSchema(known_method); + const name = new torch.jit.QualifiedName(prefix, schema.name); + const fn = new torch.jit.BuiltinOpFunction(name, schema); + type.addMethod(fn); } this._compilation_unit.register_type(type); } @@ -8962,13 +9268,38 @@ python.Execution = class { } return null; }; - const inlineCallTo = (/* to_replace, callee, use_graph */) => { + const inlineCallTo = (to_replace, callee, inline_optimized_graph) => { + const graph = inline_optimized_graph ? callee.optimized_graph() : callee.graph(); + return inlineCallTo(to_replace, callee, graph.get()); }; + const Inline = () => {}; + const GRAPH_UPDATE = () => {}; const inlineCalls = (block) => { for (const cur of block.nodes()) { switch (cur.kind()) { case 'prim::CallFunction': { - throw new python.Error(); + const graphFunction = tryToGraphFunction(cur); + if (graphFunction) { + const function_constant = cur.input(0).node(); + const fun_type = function_constant.output().type().expect(torch.FunctionType); + cur.removeInput(0); + GRAPH_UPDATE("Inlining function '", fun_type.function().name(), "' to ", cur); + let g = null; + const fallback = function_constant.hasAttribute('fallback'); + if (fallback && graphFunction.get_executor().isOptimized()) { + const exec_plans = graphFunction.get_executor().getDebugState().execution_plans; + if (!exec_plans.empty()) { + g = exec_plans.begin().second.graph; + Inline(g); + } + } + if (g === null) { + g = graphFunction.optimized_graph(); + } + // GRAPH_UPDATE("Function body: ", g); + inlineCallTo(cur, graphFunction, g.get()); + } + break; } case 'prim::CallMethod': { const graphFunction = tryToGraphFunction(cur); @@ -9055,12 +9386,12 @@ python.Execution = class { return new torch.ScriptObject(type); } _type() { - return this._type; + return this._type; // torch.ClassType } _get_method(name) { - for (const method of this._type.methods()) { - if (name === method.name) { - return method; + for (const fn of this._type.methods()) { + if (name === fn.name) { + return new torch.ScriptMethod(this /* _value() */, fn); } } return null; @@ -9111,12 +9442,16 @@ python.Execution = class { if (!this.data.forward) { throw new python.Error("Module 'forward' not implemented."); } - const args = [this.data]; // self + execution.traceAttr = false; + const args = []; + if (!execution.traceAttr) { + args.push(this.data); // self + } if (this.data.forward.__code__ && this.data.forward.__code__.args) { for (const arg of this.data.forward.__code__.args) { - if (arg.name !== 'self') { + if (execution.traceAttr || arg.name !== 'self') { const value = execution.graph.addInput(arg.name); - value.setType(execution.type(arg.parameterType)); + value.setType(execution.type(arg.annotation)); if (isTensor(value)) { value.__variable__ = arg.name; value.__origin__ = 'graph-input'; @@ -9183,6 +9518,21 @@ python.Execution = class { return this._items; } }); + this.registerType('torch.jit.to_ir', class { + constructor(def, _resolver, self, method) { + this.method = method; + this.graph = method.graph(); + this.resolver = _resolver; + this._typeParser = new torch.jit.ScriptTypeParser(this.resolver); + const schema = this.emitDef(def, self, this.graph.block()); + method.setSchema(schema); + } + emitDef(def, self /*, block */) { + const schema = this._typeParser.parseSchemaFromDef(def, self !== null); + // this.resolver._cu.execution + return schema; + } + }); this.registerType('torch.jit.CompilationUnit', class { constructor() { this._functions = new Map(); @@ -9192,16 +9542,66 @@ python.Execution = class { this._classes.set(namedType.annotation_str, namedType); } register_function(fn) { - this._functions.set(fn.name, fn); - } - define(prefix, properties, propResolvers, definitions /*, defResolvers, self, shouldMangle, operator_set_version */) { - for (const def of definitions) { - const name = def.name; - const qualified_name = prefix ? `${prefix}.${name}` : name; - const graph = new torch.Graph(); - const fn = new torch.ScriptFunction(qualified_name, graph, null); - this.register_function(fn); + this._functions.set(fn.name(), fn); + return fn; + } + define(prefix, ...args) { + if (Array.isArray(args[0])) { + const [/* properties */, /* propResolvers */, definitions, defResolvers, self, shouldMangle, operator_set_version] = args; + const function_table = new Map(); + const functions = []; + const record_function = (fn) => { + function_table.set(fn.name(), fn); + functions.push(fn); + this.register_function(fn); + }; + // properties + for (let i = 0; i < definitions.length; i++) { + const fn = this.define(prefix, definitions[i], defResolvers[i], self, function_table, shouldMangle, 'method', operator_set_version); + record_function(fn); + } + for (const [name, fn] of function_table) { + if (name === '__init__') { + fn.ensure_defined(); + } + } + for (const fn of functions) { + fn.ensure_defined(); + } + return functions; + } + const [def, resolver, self, function_table, shouldMangle, type, operator_set_version] = args; + let _resolver = resolver; + if (!self) { + _resolver = new torch._C.FunctionResolver(resolver.get(), function_table); + } + const creator = (method) => { + // let call_name = method.qualname().name(); + // if (self) { + // const atoms = method.qualname().atoms(); + // // TORCH_INTERNAL_ASSERT(atoms.size() >= 2); + // call_name = `${atoms.at(atoms.size() - 2)}.${atoms.at(atoms.size() - 1)}`; + // } + // this.call(call_name, def.range()); + return new torch.jit.to_ir(def, _resolver, self, method); + }; + const name = prefix ? new torch.jit.QualifiedName(prefix, def.name) : new torch.jit.QualifiedName(def.name); + const graph = new torch.Graph(); + graph.set_op_version(operator_set_version); + const fn = new torch.jit.GraphFunction(name, graph, creator); + if (shouldMangle && this.find_function(name)) { + // name = mangle(name); + } + if (self) { + if (type === 'hook') { + self.getClassType().addForwardHook(fn); + } else if (type === 'prehook') { + self.getClassType().addPreHook(fn); + } else { + self.getClassType().addMethod(fn); + } } + return fn; } get_type(name) { return this._classes.get(name.qualifiedName()); diff --git a/source/pytorch.js b/source/pytorch.js index 50e73c1ec2..d828eb9228 100644 --- a/source/pytorch.js +++ b/source/pytorch.js @@ -1929,6 +1929,15 @@ pytorch.Execution = class extends python.Execution { const target = this.target(func.value, context); return this._graph.insertToList(target, typehint); } + if (this.traceAttr) { + if (func instanceof ast.Name && func.id === 'getattr') { + const obj = this.expression(expr.args[0], context); + const field = this.expression(expr.args[1], context); + const n = this._graph.createGetAttr(obj, field); + this._graph.insertNode(n); + return n.output(); + } + } return super.expression(expr, context); } case 'Subscript': { @@ -2258,6 +2267,20 @@ pytorch.Execution = class extends python.Execution { } } + emitSugaredExpr(tree, n_binders, type_hint) { + const ast = this.ast; + if (tree instanceof ast.Var) { + // + } else if (tree instanceof ast.Attribute) { + // + } else if (tree instanceof ast.Apply) { + // + } if (tree instanceof ast.Subscript) { + // + } + return this.emitSimpleExpr(tree, type_hint); + } + block(statements, context) { const ast = this.ast; const torch = this.torch; @@ -2432,13 +2455,26 @@ pytorch.Execution = class extends python.Execution { } if (stmt instanceof ast.For) { const range = stmt.location; - const node = this.create('prim::Loop', range, 0); - this._graph.insertNode(node); + const n = this._graph.insertNode(this.create('prim::Loop', range, 0)); const itrs = stmt.iter instanceof ast.Tuple ? stmt.iter.elts : [stmt.iter]; // const targets = stmt.target instanceof ast.Tuple ? stmt.target.elts : [stmt.target]; if (itrs.length !== 1) { throw new pytorch.Error('List of iterables is not supported currently.'); } + /* + // const sv = this.expression(itrs[0], context); + const sv = this.emitSugaredExpr(itrs[0], 1); + const iterable = sv.iter(range, method); + if (iterable.shouldEmitUnrolled()) { + this.emitUnrolledLoop(loc, emit_body, iterable, targets); + } else { + this.emitLoopCommon(loc, emit_body, iterable, targets, {}); + } + */ + + /* const body_block = */ n.addBlock(); + /* const condition_block = */ n.addBlock(); + const loop = stmt; if (loop.target instanceof ast.Name && loop.iter instanceof ast.Tuple === false) { const range = this.expression(loop.iter, context); @@ -2480,14 +2516,19 @@ pytorch.Execution = class extends python.Execution { } statement(stmt, context) { - const ast = this.ast; - const torch = this.torch; + if (stmt.__class__.__name__ === 'ClassDef') { + const name = `${context.get('__name__')}.${stmt.name}`; + this._resolver.resolveType(name); + } + if (!this.trace) { return super.statement(stmt, context); } + switch (stmt.__class__.__name__) { case 'ClassDef': { super.statement(stmt, context); + /* const value = context.get(stmt.name); const type = new torch.ClassType(`${value.__module__}.${value.__name__}`); for (const entry of stmt.body) { @@ -2498,6 +2539,7 @@ pytorch.Execution = class extends python.Execution { } } value.__type__ = type; + */ return undefined; } case 'If': { @@ -2551,9 +2593,11 @@ pytorch.Execution = class extends python.Execution { case 'bool': return torch.BoolType.get(); case 'list': return torch.Type.get('AnyListType'); case 'tuple': return torch.Type.get('AnyTupleType'); + case 'Device': return torch.DeviceObjType.get(); case 'None': return torch.NoneType.get(); case 'NoneType': return torch.NoneType.get(); - default: throw new pytorch.Error(`Unsupported type expression '${expr.value}'.`); + case 'Any': return torch.AnyType.get(); + default: throw new pytorch.Error(`Unsupported type expression '${expr.id}'.`); } } if (expr instanceof ast.Constant) { @@ -2615,15 +2659,24 @@ pytorch.Execution = class extends python.Execution { if (!overload) { const moduleTarget = this.target(target, context); if (moduleTarget instanceof torch.Value && moduleTarget.type() instanceof torch.ClassType) { + const class_type = moduleTarget.type().expect(torch.ClassType); + const method_name = name; + const method = class_type.getMethod(method_name); + const return_type = method.getSchema().returns[0].real_type; const node = this._graph.create('prim::CallMethod'); this._graph.insertNode(node); node.s_('name', name); + const inputs = []; const evalArgs = args.map((expression) => this.expression(expression, context)); for (const arg of evalArgs) { const value = this.variable(arg); + inputs.push(value); node.addInput(value); } + node.output().setType(return_type); return node.output(); + // const matchedSchema = new torch.jit.MatchedSchema(inputs, return_types, return_field_names, schema_name) + // return this._graph.insertMethodCall(name, matchedSchema); } const prefix = this.identifier(target); if (prefix && prefix !== 'self' && !prefix.startsWith('self.') && prefix.indexOf('.') !== -1) {