From 485efbdc0525e6f117b562acfac20313e494a217 Mon Sep 17 00:00:00 2001 From: Lutz Roeder Date: Wed, 13 Nov 2024 19:50:18 -0800 Subject: [PATCH] Update pytorch.js (#1061) --- source/python.js | 688 +++++++++++++++++++++++++++------------------- source/pytorch.js | 260 ++++++++---------- 2 files changed, 516 insertions(+), 432 deletions(-) diff --git a/source/python.js b/source/python.js index 2d01392503..03dc0a5cce 100644 --- a/source/python.js +++ b/source/python.js @@ -126,9 +126,15 @@ python.Execution = class { return this.__class__.__name__; } }); - this.registerType('ast.expr', class extends ast.AST { - }); - this.registerType('ast.stmt', class extends ast.AST { + this.registerType('ast.expr', class extends ast.AST {}); + this.registerType('ast.stmt', class extends ast.AST {}); + this.registerType('ast.excepthandler', class extends ast.AST {}); + this.registerType('ast.alias', class extends ast.AST { + constructor(name, asname) { + super(); + this.name = name; + this.asname = asname; + } }); this.registerType('ast.Name', class extends ast.expr { constructor(id, ctx) { @@ -139,6 +145,17 @@ python.Execution = class { } } }); + this.registerType('ast.Constant', class extends ast.expr { + constructor(value) { + super(); + this.value = value; + } + }); + this.registerType('ast.Ellipsis', class extends ast.Constant { + constructor() { + super(builtins.ellipsis); + } + }); this.registerType('ast.List', class extends ast.expr { constructor(elts, ctx) { super(); @@ -148,6 +165,12 @@ python.Execution = class { } } }); + this.registerType('ast.Set', class extends ast.expr { + constructor(elts) { + super(); + this.elts = elts; + } + }); this.registerType('ast.Tuple', class extends ast.expr { constructor(elts, ctx) { super(); @@ -157,6 +180,23 @@ python.Execution = class { } } }); + this.registerType('ast.Dict', class extends ast.expr { + constructor(keys, values) { + super(); + this.keys = keys; + this.values = values; + } + }); + this.registerType('ast.Subscript', class extends ast.expr { + constructor(value, slice, ctx) { + super(); + this.value = value; + this.slice = slice; + if (ctx) { + this.ctx = ctx; + } + } + }); this.registerType('ast.UnaryOp', class extends ast.expr { constructor(op, operand) { super(); @@ -182,6 +222,40 @@ python.Execution = class { } } }); + this.registerType('ast.Lambda', class extends ast.expr { + constructor(args, body) { + super(); + this.args = args; + this.body = body; + } + }); + this.registerType('ast.IfExp', class extends ast.expr { + constructor(test, body, orelse) { + super(); + this.test = test; + this.body = body; + this.orelse = orelse; + } + }); + this.registerType('ast.NamedExpr', class extends ast.expr { + constructor(target, value) { + super(); + this.target = target; + this.value = value; + } + }); + this.registerType('ast.Yield', class extends ast.expr { + constructor(value) { + super(); + this.value = value; + } + }); + this.registerType('ast.YieldFrom', class extends ast.expr { + constructor(value) { + super(); + this.value = value; + } + }); this.registerType('ast.Assign', class extends ast.stmt { constructor(targets, value, ctx) { super(); @@ -218,12 +292,43 @@ python.Execution = class { this.orelse = orelse; } }); + this.registerType('ast.While', class extends ast.stmt { + constructor(test, body, orelse /*, type_comment */) { + super(); + this.test = test; + this.body = body; + this.orelse = orelse; + } + }); + this.registerType('ast.Del', class extends ast.stmt { + constructor(targets) { + super(); + this.targets = targets; + } + }); this.registerType('ast.Return', class extends ast.stmt { constructor(value) { super(); this.value = value; } }); + this.registerType('ast.Try', class extends ast.stmt { + constructor(body, handlers, orelse, finalbody) { + super(); + this.body = body; + this.handlers = handlers; + this.orelse = orelse; + this.finalbody = finalbody; + } + }); + this.registerType('ast.ExceptHandler', class extends ast.excepthandler { + constructor(type, name, body) { + super(); + this.type_ = type; + this.name = name; + this.body = body; + } + }); this.registerType('ast.ClassDef', class extends ast.stmt { constructor(name, bases, keywords, body, decorator_list, type_params) { super(); @@ -275,6 +380,26 @@ python.Execution = class { this.cause = cause; } }); + this.registerType('ast.With', class extends ast.stmt { + constructor(items, body, type_comment) { + super(); + this.items = items; + this.body = body; + this.type_comment = type_comment; + } + }); + this.registerType('ast.Global', class extends ast.stmt { + constructor(names) { + super(); + this.names = names; + } + }); + this.registerType('ast.Nonlocal', class extends ast.stmt { + constructor(names) { + super(); + this.names = names; + } + }); this.registerType('ast.Continue', class extends ast.stmt {}); this.registerType('ast.Break', class extends ast.stmt {}); this.registerType('ast.Pass', class extends ast.stmt {}); @@ -364,23 +489,23 @@ python.Execution = class { _statement() { let node = null; let position = null; - position = this._eat_('id', 'break'); + position = this._eat('id', 'break'); if (position) { node = new ast.Break(); return this._complete(node, position); } - position = this._eat_('id', 'continue'); + position = this._eat('id', 'continue'); if (position) { node = new ast.Continue(); return this._complete(node, position); } - position = this._eat_('id', 'return'); + position = this._eat('id', 'return'); if (position) { const value = this._expression(-1, [], true); node = new ast.Return(value); return this._complete(node, position); } - position = this._eat_('id', 'raise'); + position = this._eat('id', 'raise'); if (position) { let exc = this._expression(-1, ['from']); let cause = null; @@ -406,59 +531,46 @@ python.Execution = class { node = new ast.Assert(test, msg); return this._complete(node, position); } - node = this._eat('id', 'exec'); - if (node) { - node.variable = this._expression(-1, ['in']); - if (this._tokenizer.eat('in')) { - do { - node.target = node.target || []; - node.target.push(this._expression(-1, ['in'], false)); - } - while (this._tokenizer.eat(',')); - } - return node; - } - node = this._eat('id', 'global'); - if (node) { - node.names = []; + position = this._eat('id', 'global'); + if (position) { + const names = []; do { const name = this._name(true); - node.names.push(name.id); + names.push(name.id); } while (this._tokenizer.eat(',')); - return node; + node = new ast.Global(names); + return this._complete(node, position); } - node = this._eat('id', 'nonlocal'); - if (node) { - node.names = []; + position = this._eat('id', 'nonlocal'); + if (position) { + const names = []; do { const name = this._name(true); - node.names.push(name.id); + names.push(name.id); } while (this._tokenizer.eat(',')); - return node; + node = new ast.Nonlocal(names); + return this._complete(node, position); } - node = this._eat('id', 'import'); - if (node) { - const location = node.location; + position = this._eat('id', 'import'); + if (position) { const names = []; do { - const alias = this._node('alias'); - alias.name = this._dottedName(); + const name = this._dottedName(); + let asname = null; if (this._tokenizer.eat('id', 'as')) { - const name = this._name(true); - alias.asname = name.id; + asname = this._name(true).id; } - names.push(alias); + const node = new ast.alias(name, asname); + names.push(node); } while (this._tokenizer.eat(',')); node = new ast.Import(names); - node.location = location; - return node; + return this._complete(node, position); } - node = this._eat('id', 'from'); - if (node) { - const location = node.location; + position = this._eat('id', 'from'); + if (position) { let level = 0; const dots = this._tokenizer.peek(); if (dots && Array.from(dots.type).every((c) => c === '.')) { @@ -470,53 +582,46 @@ python.Execution = class { const names = []; const close = this._tokenizer.eat('('); do { - const alias = this._node('alias'); - const name = this._name(true); - alias.name = name.id; + const name = this._name(true).id; + let asname = null; if (this._tokenizer.eat('id', 'as')) { - const name = this._name(true); - alias.asname = name.id; + asname = this._name(true).id; } - names.push(alias); + node = new ast.alias(name, asname); + names.push(node); } while (this._tokenizer.eat(',')); if (close) { this._tokenizer.expect(')'); } node = new ast.ImportFrom(module, names, level); - node.location = location; - return node; + return this._complete(node, position); } let decorator_list = this._decorator(); - node = this._eat('id', 'class'); - if (node) { - const location = node.location; + position = this._eat('id', 'class'); + if (position) { const name = this._name(true); if (decorator_list) { - node.decorator_list = Array.from(decorator_list); - decorator_list = null; + decorator_list = Array.from(decorator_list); } const bases = this._tokenizer.peek().type === '(' ? this._arguments() : []; this._tokenizer.expect(':'); const body = this._suite(); node = new ast.ClassDef(name.id, bases, null, body, decorator_list, null); - node.location = location; - return node; + return this._complete(node, position); } - const async = this._eat('id', 'async'); + const async = this._eat('id', 'async') !== null; if (async && !this._tokenizer.match('id', 'def') && !this._tokenizer.match('id', 'with') && !this._tokenizer.match('id', 'for')) { throw new python.Error(`Expected 'def', 'with' or 'for' ${this._tokenizer.location()}`); } - node = this._eat('id', 'def'); - if (node) { - const location = node.location; + position = this._eat('id', 'def'); + if (position) { const name = this._name(true); if (decorator_list) { - node.decorator_list = Array.from(decorator_list); - decorator_list = null; + decorator_list = Array.from(decorator_list); } this._tokenizer.expect('('); const args = this._parameters(')'); @@ -530,30 +635,23 @@ python.Execution = class { if (async) { node.async = async; } - node.location = location; - return node; + return this._complete(node, position); } if (decorator_list && decorator_list.length > 0) { throw new python.Error('Unexpected decorator.'); } - node = this._eat('id', 'del'); - if (node) { - node.expression = this._expression(-1, [], true); - return node; + position = this._eat('id', 'del'); + if (position) { + const targets = this._expression(-1, [], true); + node = new ast.Del(targets); + return this._complete(node, position); } - // node = this._eat('id', 'print'); - // if (node) { - // node.expression = this._expression(-1, [], true); - // return node; - // } - node = this._eat('id', 'if'); - if (node) { - const location = node.location; + position = this._eat('id', 'if'); + if (position) { const test = this._expression(); this._tokenizer.expect(':'); const body = this._suite(); node = new ast.If(test, body); - node.location = location; let current = node; this._tokenizer.eat('\n'); while (this._tokenizer.eat('id', 'elif')) { @@ -568,29 +666,28 @@ python.Execution = class { this._tokenizer.expect(':'); current.orelse = this._suite(); } - return node; + return this._complete(node, position); } - node = this._eat('id', 'while'); - if (node) { - node.test = this._expression(); + position = this._eat('id', 'while'); + if (position) { + const test = this._expression(); this._tokenizer.expect(':'); - node.body = this._suite(); + const body = this._suite(); + let orelse = null; if (this._tokenizer.eat('id', 'else')) { this._tokenizer.expect(':'); - node.orelse = this._suite(); + orelse = this._suite(); } - return node; + node = new ast.While(test, body, orelse); + return this._complete(node, position); } - node = this._eat('id', 'pass'); - if (node) { - const location = node.location; + position = this._eat('id', 'pass'); + if (position) { node = new ast.Pass(); - node.location = location; - return node; + return this._complete(node, position); } - node = this._eat('id', 'for'); - if (node) { - const location = node.location; + position = this._eat('id', 'for'); + if (position) { let target = this._expression(-1, ['in']); while (this._tokenizer.eat(',')) { if (target instanceof ast.Tuple === false) { @@ -622,15 +719,11 @@ python.Execution = class { orelse = this._suite(); } node = new ast.For(target, iter, body, orelse); - node.location = location; - return node; + return this._complete(node, position); } - node = this._eat('id', 'with'); - if (node) { - if (async) { - node.async = async; - } - node.item = []; + position = this._eat('id', 'with'); + if (position) { + const items = []; do { const item = this._node(); item.type = 'with_item'; @@ -638,22 +731,28 @@ python.Execution = class { if (this._tokenizer.eat('id', 'as')) { item.variable = this._expression(); } - node.item.push(item); + items.push(item); } while (this._tokenizer.eat(',')); this._tokenizer.expect(':'); - node.body = this._suite(); - return node; + const body = this._suite(); + node = new ast.With(items, body, null); + if (async) { + node.async = async; + } + return this._complete(node, position); } - node = this._eat('id', 'try'); - if (node) { + position = this._eat('id', 'try'); + if (position) { this._tokenizer.expect(':'); - node.body = this._suite(); - node.except = []; + const body = this._suite(); + const handlers = []; + let orelse = null; + let finalbody = null; while (this._tokenizer.match('id', 'except')) { - const except = this._node('except'); this._tokenizer.expect('id', 'except'); - except.clause = []; + const type = this._expression(); + /* except.clause.push(this._expression()); while (this._tokenizer.eat(',')) { if (this._tokenizer.match(':') || this._tokenizer.match('as')) { @@ -662,42 +761,43 @@ python.Execution = class { } except.clause.push(this._expression()); } + */ + let name = null; if (this._tokenizer.eat('id', 'as')) { - except.variable = this._expression(); + name = this._expression(); } this._tokenizer.expect(':'); - except.body = this._suite(); - node.except.push(except); + const body = this._suite(); + const except = new ast.ExceptHandler(type, name, body); + handlers.push(except); } if (this._tokenizer.match('id', 'else')) { - node.orelse = this._node('else'); this._tokenizer.expect('id', 'else'); this._tokenizer.expect(':'); - node.orelse.body = this._suite(); + orelse = this._suite(); } if (this._tokenizer.match('id', 'finally')) { - node.finally = this._node('finally'); this._tokenizer.expect('id', 'finally'); this._tokenizer.expect(':'); - node.finally.body = this._suite(); + finalbody = this._suite(); } + node = new ast.Try(body, handlers, orelse, finalbody); + node = this._complete(node, position); return node; } const expr = this._expression(-1, [], true); if (expr) { if (expr instanceof ast.Name && this._tokenizer.eat(':')) { - const location = this._tokenizer.location(); + const position = this._position(); const annotation = this._expression(-1, ['=']); let value = null; if (this._tokenizer.eat('=')) { value = this._expression(); } node = new ast.AnnAssign(expr, annotation, value); - node.location = location; - return node; + return this._complete(node, position); } switch (expr.type) { - case ':=': case '==': case '!=': case '+=': @@ -720,16 +820,18 @@ python.Execution = class { case '>': case '%': case '^=': - case '...': + case 'Ellipsis': + case 'NamedExpr': case 'Call': case 'Assert': case 'Raise': - case 'string': case 'Assign': case 'AnnAssign': case 'Attribute': - case '[]': - case 'yield': + case 'Yield': + case 'Subscript': + case 'Name': + case 'Constant': case '+': case '-': case '*': case '/': case '**': case '@': @@ -737,15 +839,13 @@ python.Execution = class { case '~': case '&': case '^': case '|': case 'not': - case 'Name': - case 'number': case 'in': case 'and': case 'or': case 'If': case 'For': case 'List': case 'Tuple': - case 'lambda': + case 'Lambda': case 'Await': return expr; default: @@ -808,9 +908,10 @@ python.Execution = class { } } if (this._tokenizer.eat(':=')) { - node.type = ':='; - node.target = stack.pop(); - node.expression = this._expression(-1, terminal, tuple === false ? false : true); + const target = stack.pop(); + const value = this._expression(-1, terminal, tuple === false ? false : true); + node = new ast.NamedExpr(target, value); + node = this._complete(node, position); stack.push(node); continue; } @@ -846,22 +947,25 @@ python.Execution = class { default: break; } - node = this._eat('id', 'if'); - if (node) { - node.body = stack.pop(); - node.test = this._expression(); + position = this._eat('id', 'if'); + if (position) { + const body = stack.pop(); + const test = this._expression(); this._tokenizer.expect('id', 'else'); - node.orelse = this._expression(); + const orelse = this._expression(); + node = new ast.IfExp(test, body, orelse); + node = this._complete(node, position); stack.push(node); continue; } while (this._tokenizer.match('id', 'for') || this._tokenizer.match('id', 'async')) { - const async = this._eat('id', 'async'); + const async = this._eat('id', 'async') !== null; if (async && !this._tokenizer.match('id', 'for')) { throw new python.Error(`Expected 'for' ${this._tokenizer.location()}`); } - node = this._eat('id', 'for'); - if (node) { + position = this._eat('id', 'for'); + if (position) { + node = this._node('for'); if (async) { node.async = async; } @@ -876,28 +980,33 @@ python.Execution = class { stack.push(node); } } - node = this._eat('id', 'lambda'); - if (node) { - node.args = this._parameters(':'); - node.body = this._expression(-1, terminal, false); + position = this._eat('id', 'lambda'); + if (position) { + const args = this._parameters(':'); + const body = this._expression(-1, terminal, false); + node = new ast.Lambda(args, body); + node = this._complete(node, position); stack.push(node); continue; } - node = this._eat('id', 'yield'); - if (node) { + position = this._eat('id', 'yield'); + if (position) { if (this._tokenizer.eat('id', 'from')) { - node.from = this._expression(-1, [], true); + const value = this._expression(-1, [], true); + node = new ast.YieldFrom(value); + stack.push(node); } else { - node.expression = []; + const value = []; do { - node.expression.push(this._expression(-1, [], false)); + value.push(this._expression(-1, [], false)); } while (this._tokenizer.eat(',')); + node = new ast.Yield(value); + stack.push(node); } - stack.push(node); continue; } - position = this._eat_('id', 'await'); + position = this._eat('id', 'await'); if (position) { const value = this._expression(minPrecedence, terminal, tuple); node = new ast.Await(value); @@ -905,13 +1014,12 @@ python.Execution = class { stack.push(node); continue; } - node = this._eat('.'); - if (node) { - const location = node.location; + position = this._eat('.'); + if (position) { const value = stack.pop(); const attr = this._name().id; node = new ast.Attribute(value, attr); - node.location = location; + node = this._complete(node, position); stack.push(node); continue; } @@ -940,46 +1048,112 @@ python.Execution = class { if (stack.length === 0) { stack.push(this._expressions()); } else { - node = this._node('[]'); - node.target = stack.pop(); - node.arguments = this._slice(); + const value = stack.pop(); + const elts = this._slice(); + node = new ast.Subscript(value, elts); stack.push(node); } continue; } if (this._tokenizer.peek().type === '{') { - stack.push(this._dictOrSetMaker()); + const elts = []; + const keys = []; + const values = []; + this._tokenizer.expect('{'); + let dict = true; + while (!this._tokenizer.eat('}')) { + const item = this._expression(-1, [], false); + if (item === null) { + throw new python.Error(`Expected expression ${this._tokenizer.location()}`); + } + if (!this._tokenizer.eat(':')) { + dict = false; + } + if (dict) { + const value = this._expression(-1, [], false); + if (value === null) { + throw new python.Error(`Expected expression ${this._tokenizer.location()}`); + } + keys.push(item); + values.push(value); + } else { + elts.push(item); + } + this._tokenizer.eat(','); + this._tokenizer.eat('\n'); + if (this._tokenizer.eat('}')) { + break; + } + } + if (keys.length !== values.length || (keys.length > 0 && elts.length > 0)) { + throw new python.Error(`Invalid set expression ${this._tokenizer.location()}`); + } + let node = null; + if (elts.length > 0) { + node = new ast.Set(elts); + } else { + node = new ast.Dict(keys, values); + } + stack.push(node); continue; } node = this._node(); const literal = this._literal(); if (literal) { - if (stack.length > 0 && literal.type === 'number' && - (literal.value.startsWith('-') || literal.value.startsWith('+'))) { + if (stack.length > 0 && literal.type === 'number' && (literal.value.startsWith('-') || literal.value.startsWith('+'))) { node.type = literal.value.substring(0, 1); literal.value = literal.value.substring(1); node.left = stack.pop(); node.right = literal; stack.push(node); - } else if (stack.length === 1 && literal.type === 'string' && stack[0].type === 'string') { - stack[0].value += literal.value; + } else if (stack.length === 1 && literal.type === 'string' && stack[0] instanceof ast.Constant && typeof stack[0].value === 'string') { + stack[0].value += literal.value.substring(1, literal.value.length - 1); } else { + let value = literal.value; if (literal.type === 'number') { - switch (literal.value) { - case 'inf': literal.value = Infinity; break; - case '-inf': literal.value = -Infinity; break; - default: break; + switch (value) { + case 'inf': value = Infinity; break; + case '-inf': value = -Infinity; break; + default: value = Number(value); break; } + } else if (literal.type === 'string') { + value = literal.value.substring(1, literal.value.length - 1); + } else { + throw new python.Error(`Invalid literal ${this._tokenizer.location()}`); } - stack.push(literal); + const node = new ast.Constant(value); + stack.push(node); } continue; } + position = this._eat('id', 'False'); + if (position) { + node = new ast.Constant(false); + node = this._complete(node, position); + stack.push(node); + continue; + } + position = this._eat('id', 'True'); + if (position) { + node = new ast.Constant(true); + node = this._complete(node, position); + stack.push(node); + continue; + } + position = this._eat('id', 'None'); + if (position) { + node = new ast.Constant(null); + node = this._complete(node, position); + stack.push(node); + continue; + } if (this._tokenizer.peek().keyword) { break; } - node = this._eat('...'); - if (node) { + position = this._eat('...'); + if (position) { + node = new ast.Ellipsis(); + node = this._complete(node, position); stack.push(node); continue; } @@ -1036,38 +1210,6 @@ python.Execution = class { } return list; } - _dictOrSetMaker() { - const list = []; - this._tokenizer.expect('{'); - let dict = true; - while (!this._tokenizer.eat('}')) { - const item = this._expression(-1, [], false); - if (item === null) { - throw new python.Error(`Expected expression ${this._tokenizer.location()}`); - } - if (!this._tokenizer.eat(':')) { - dict = false; - } - if (dict) { - const value = this._expression(-1, [], false); - if (value === null) { - throw new python.Error(`Expected expression ${this._tokenizer.location()}`); - } - list.push({ type: 'pair', key: item, value }); - } else { - list.push(item); - } - this._tokenizer.eat(','); - this._tokenizer.eat('\n'); - if (this._tokenizer.eat('}')) { - break; - } - } - if (dict) { - return { type: 'dict', value: list }; - } - return { type: 'set', value: list }; - } _expressions() { const elts = []; this._tokenizer.expect('['); @@ -1169,12 +1311,8 @@ python.Execution = class { const target = this._expression(-1, ['[', '=']); if (target) { if (this._tokenizer.peek().value === '[') { - const type = this._node(); - type.type = '[]'; - type.target = target; - type.arguments = this._expressions(); - // type.arguments = this._typeArguments(); - return type; + const elts = this._expressions(); + return new ast.Subscript(target, elts); } return target; } @@ -1252,9 +1390,9 @@ python.Execution = class { } _eat(type, value) { if (this._tokenizer.match(type, value)) { - const node = this._node(type === 'id' ? value : type); + const position = this._position(); this._tokenizer.expect(type, value); - return node; + return position; } return null; } @@ -1266,14 +1404,6 @@ python.Execution = class { node.end_col_offset = this._tokenizer.col_offset; return node; } - _eat_(type, value) { - if (this._tokenizer.match(type, value)) { - const position = this._position(); - this._tokenizer.expect(type, value); - return position; - } - return null; - } _position() { return { location: this._tokenizer.location(), @@ -10957,7 +11087,7 @@ python.Execution = class { } throw new python.Error("Unsupported 'for' statement."); } - case 'while': { + case 'While': { const test = this.expression(stmt.test, context); if (test) { const value = this.block(stmt.body.statements, context); @@ -10967,9 +11097,9 @@ python.Execution = class { } break; } - case 'with': { + case 'With': { const items = []; - for (const item of stmt.item) { + for (const item of stmt.items) { items.push(this.expression(item.expression, context)); } for (const item of items) { @@ -11021,7 +11151,7 @@ python.Execution = class { } break; } - case 'string': { + case 'Constant': { break; } default: { @@ -11043,15 +11173,16 @@ python.Execution = class { const value = this.expression(expr.value, context); context.set(target.id, value); return undefined; - } else if (target.type === '[]') { - if (target.target instanceof ast.Name && - target.arguments instanceof ast.List && - target.arguments.elts.length === 1) { - const index = this.expression(target.arguments.elts[0], context); - if (target.target.id === '__annotations__') { - context.set(target.target.id, context.get(target.target.id) || {}); - } - const obj = context.get(target.target.id); + } else if (target instanceof ast.Subscript) { + if (target.value instanceof ast.Name && + target.slice instanceof ast.List && + target.slice.elts.length === 1) { + const index = this.expression(target.slice.elts[0], context); + const id = target.value.id; + if (id === '__annotations__') { + context.set(id, context.get(id) || {}); + } + const obj = context.get(id); const value = this.expression(expr.value, context); if (obj instanceof Map) { obj.set(index, value); @@ -11087,19 +11218,16 @@ python.Execution = class { case 'List': { return expr.elts.map((expr) => this.expression(expr, context)); } - case 'string': { - return expr.value.substring(1, expr.value.length - 1); - } - case 'number': { - return Number(expr.value); + case 'Constant': { + return expr.value; } - case '[]': { - if (expr.target instanceof ast.Name && - expr.arguments instanceof ast.List && - expr.arguments.elts.length === 1) { - const id = expr.target.id; + case 'Subscript': { + if (expr.value instanceof ast.Name && + expr.slice instanceof ast.List && + expr.slice.elts.length === 1) { + const id = expr.value.id; if (context.get(id)) { - const index = this.expression(expr.arguments.elts[0], context); + const index = this.expression(expr.slice.elts[0], context); const target = context.get(id); if (target instanceof Map) { return target.get(index); @@ -11107,21 +11235,21 @@ python.Execution = class { return target[index < 0 ? target.length + index : index]; } } - const target = this.expression(expr.target, context); - if (target && expr.arguments instanceof ast.List && - (target.__class__ === typing._TupleType || - target.__class__ === typing._SpecialGenericAlias || - target.__class__ === typing._SpecialForm)) { - const type = { ...target }; - type.__args__ = expr.arguments.elts.map((arg) => this.expression(arg, context)); + const value = this.expression(expr.value, context); + if (value && expr.slice instanceof ast.List && + (value.__class__ === typing._TupleType || + value.__class__ === typing._SpecialGenericAlias || + value.__class__ === typing._SpecialForm)) { + const type = { ...value }; + type.__args__ = expr.slice.elts.map((arg) => this.expression(arg, context)); return type; } - if (expr.arguments instanceof ast.List && expr.arguments.elts.length === 1) { - const index = this.expression(expr.arguments.elts[0], context); - if (target instanceof Map) { - return target.get(index); + if (expr.slice instanceof ast.List && expr.slice.elts.length === 1) { + const index = this.expression(expr.slice.elts[0], context); + if (value instanceof Map) { + return value.get(index); } - return target[index < 0 ? target.length + index : index]; + return value[index < 0 ? value.length + index : index]; } break; } @@ -11138,45 +11266,37 @@ python.Execution = class { } case 'Name': { const id = expr.id; - switch (id) { - case 'self': return self; - case 'None': return null; - case 'True': return true; - case 'False': return false; - default: { - const type = (value) => { - return value && - (value.__class__ === builtins.type || - value.__class__ === typing._TupleType || - value.__class__ === typing._SpecialGenericAlias || - value.__class__ === typing._SpecialForm); - }; - const builtin = builtins[id]; - if (type(builtin)) { - return builtin; - } - const value = context.get(id); - if (value === undefined) { - const value = typing[id]; - if (type(value)) { - return value; - } - } + if (id === 'self') { + return self; + } + const type = (value) => { + return value && + (value.__class__ === builtins.type || + value.__class__ === typing._TupleType || + value.__class__ === typing._SpecialGenericAlias || + value.__class__ === typing._SpecialForm); + }; + const builtin = builtins[id]; + if (type(builtin)) { + return builtin; + } + const value = context.get(id); + if (value === undefined) { + const value = typing[id]; + if (type(value)) { return value; } } + return value; } case 'Tuple': { return expr.elts.map((expr) => this.expression(expr, context)); } - case 'dict': { + case 'Dict': { const dict = {}; - for (const pair of expr.value) { - if (pair.type !== 'pair') { - throw new python.Error(`Unsupported dict item type '${pair.type}'.`); - } - const key = this.expression(pair.key, context); - const value = this.expression(pair.value, context); + for (let i = 0; i < expr.keys.length; i++) { + const key = this.expression(expr.keys[i], context); + const value = this.expression(expr.values[i], context); dict[key] = value; } return dict; diff --git a/source/pytorch.js b/source/pytorch.js index 2a3bd6083c..cddf863b99 100644 --- a/source/pytorch.js +++ b/source/pytorch.js @@ -1820,13 +1820,11 @@ pytorch.Execution = class extends python.Execution { const ast = this.ast; const torch = this.torch; switch (expr.type) { - case 'Name': { - switch (expr.id) { - case 'True': return this.constant(true); - case 'False': return this.constant(false); - default: break; + case 'Constant': { + if (expr.value === true || expr.value === false) { + return this.constant(expr.value); } - return super.expression(expr, context); + break; } case 'Assign': { const target = expr.targets; @@ -1952,31 +1950,14 @@ pytorch.Execution = class extends python.Execution { value.setType(torch.BoolType.get()); return value; } - /* - if (expression.target.type === '.') { - const target = this.target(expression.target.target, context); // this.expression(expression.target.target, context); - if (target instanceof torch.Value && target.type() instanceof torch.ClassType) { - const node = this._graph.create('prim::CallMethod'); - this.graph.insertNode(node); - const name = this.variable(expression.target.member.value, node); - node.addInput(name); - const args = expression.args.map((expression) => this.expression(expression, context)); - for (const arg of args) { - const value = this.variable(arg, node); - node.addInput(value); - } - return node.addOutput(); - } - } - */ return super.expression(expr, context); } - case '[]': { - if (expr.arguments instanceof ast.List && expr.arguments.elts.length === 1) { - const target = this.expression(expr.target, context); - const [elt] = expr.arguments.elts; - if (target instanceof torch.Value) { - let type = target.type(); + case 'Subscript': { + if (expr.slice instanceof ast.List && expr.slice.elts.length === 1) { + const value = this.expression(expr.value, context); + const [elt] = expr.slice.elts; + if (value instanceof torch.Value) { + let type = value.type(); if (type instanceof torch.OptionalType) { type = type.getElementType(); } @@ -1984,20 +1965,20 @@ pytorch.Execution = class extends python.Execution { let index = this.expression(elt, context); const node = this._graph.create('aten::__getitem__.t'); this.graph.insertNode(node); - node.addInput(target); + node.addInput(value); if (Number.isInteger(index)) { index = this.constant(index); } node.addInput(index); - const value = node.addOutput(); - value.setType(type.getElementType()); - return value; + const output = node.addOutput(); + output.setType(type.getElementType()); + return output; } if (type instanceof torch.DictType) { let key = this.expression(elt, context); const node = this._graph.create('aten::__getitem__.t'); this.graph.insertNode(node); - node.addInput(target); + node.addInput(value); if (type.getKeyType() instanceof torch.StringType && typeof key === 'string') { const value = new torch.Value(node); value.value = key; @@ -2008,22 +1989,26 @@ pytorch.Execution = class extends python.Execution { throw new pytorch.Error(`Unsupported dictionary key type.`); } node.addInput(key); - const value = node.addOutput(); - value.setType(type.getValueType()); - return value; + const output = node.addOutput(); + output.setType(type.getValueType()); + return output; } if (type instanceof torch.TupleType) { - let index = this.expression(elt, context); + const index = this.expression(elt, context); const node = this._graph.create('prim::TupleIndex'); this.graph.insertNode(node); - const value = node.addOutput(); - value.setType(type.elements()[index]); - node.addInput(target); - if (Number.isInteger(index)) { - index = this.constant(index); + node.addInput(value); + if (index instanceof torch.Value) { + node.addInput(index); + } else if (Number.isInteger(index)) { + const value = this.constant(index); + node.addInput(value); + } else { + throw new pytorch.Error(`Unsupported tuple index type.`); } - node.addInput(index); - return value; + const output = node.addOutput(); + output.setType(type.elements()[index]); + return output; } } } @@ -2087,55 +2072,39 @@ pytorch.Execution = class extends python.Execution { if (item instanceof torch.Value) { node.addInput(item); types.push(item.type()); - elements.push(item); } else if (pytorch.Utility.isTensor(item)) { const value = this.variable(item, node); node.addInput(value); types.push(value.type()); - elements.push(item); - } else if (Number.isInteger(item) || - typeof item === 'number' || - typeof item === 'boolean' || - typeof item === 'string') { + } else if (item === null || Number.isInteger(item) || typeof item === 'number' || typeof item === 'boolean' || typeof item === 'string') { const value = this.constant(item); node.addInput(value); types.push(value.type()); - elements.push(item); - } else if (item === null) { - const value = new torch.Value(node); - value.value = item; - node.addInput(value); - types.push(torch.NoneType.get()); - elements.push(item); } else { const value = new torch.Value(node); value.value = item; node.addInput(value); types.push(torch.Type.get()); - elements.push(item); } + elements.push(item); } const value = node.addOutput(); - value.value = elements; value.setType(torch.TupleType.get(types)); return value; } - case 'dict': { + case 'Dict': { const node = this._graph.create('prim::DictConstruct'); this.graph.insertNode(node); let keyType = null; let valueType = null; - for (const pair of expr.value) { - if (pair.type !== 'pair') { - throw new pytorch.Error(`Unsupported dict item type '${pair.type}'.`); - } - const key = this.expression(pair.key, context); + for (let i = 0; i < expr.keys.length; i++) { + const key = this.expression(expr.keys[i], context); const keyValue = this.variable(key, null); keyType = keyValue.type(); - const value = this.expression(pair.value, context); + node.addInput(keyValue); + const value = this.expression(expr.values[i], context); const valueValue = this.variable(value, null); valueType = valueValue.type(); - node.addInput(keyValue); node.addInput(valueValue); } const output = node.addOutput(); @@ -2157,60 +2126,49 @@ pytorch.Execution = class extends python.Execution { const torch = this.torch; switch (expr.type) { case 'Name': { - switch (expr.id) { - case 'None': return null; - case 'True': return true; - case 'False': return false; - default: { - const value = context.get(expr.id); - if (typeof value === 'number' || typeof value === 'boolean' || typeof value === 'string') { - return value; - } - if (value instanceof torch.Tensor && value.storage() && value.storage().size() !== undefined) { - return value; - } - if (value instanceof Map) { - return value; + const value = context.get(expr.id); + if (typeof value === 'number' || typeof value === 'boolean' || typeof value === 'string') { + return value; + } + if (value instanceof torch.Tensor && value.storage() && value.storage().size() !== undefined) { + return value; + } + if (value instanceof Map) { + return value; + } + if (value instanceof torch.Value) { + const node = value.node(); + if (node.kind() === 'prim::Constant') { + state.push(node); + return pytorch.Utility.constant(node, 'value'); + } else if (node.kind() === 'prim::ListConstruct' && node.inputs().every((value) => value instanceof torch.Value && value.node().kind() === 'prim::Constant')) { + state.push(node); + for (const value of node.inputs()) { + state.push(value.node()); } - if (value instanceof torch.Value) { - const node = value.node(); - if (node.kind() === 'prim::Constant') { - state.push(node); - return pytorch.Utility.constant(node, 'value'); - } else if (node.kind() === 'prim::ListConstruct' && node.inputs().every((value) => value instanceof torch.Value && value.node().kind() === 'prim::Constant')) { + return node.inputs().map((value) => pytorch.Utility.constant(value.node(), 'value')); + } else if (node.kind() === 'prim::TupleUnpack') { + const index = node.outputs().indexOf(value); + const input = node.inputs()[0].node(); + if (input.kind() === 'prim::TupleConstruct') { + const value = input.inputs()[index]; + const constant = value.node(); + if (constant.kind() === 'prim::Constant') { state.push(node); - for (const value of node.inputs()) { - state.push(value.node()); - } - return node.inputs().map((value) => pytorch.Utility.constant(value.node(), 'value')); - } else if (node.kind() === 'prim::TupleUnpack') { - const index = node.outputs().indexOf(value); - const input = node.inputs()[0].node(); - if (input.kind() === 'prim::TupleConstruct') { - const value = input.inputs()[index]; - const constant = value.node(); - if (constant.kind() === 'prim::Constant') { - state.push(node); - state.push(constant); - return pytorch.Utility.constant(constant, 'value'); - } - } + state.push(constant); + return pytorch.Utility.constant(constant, 'value'); } - state.splice(0, state.length); } - break; } + state.splice(0, state.length); } break; } case 'List': { return expr.elts.map((expr) => this.static(expr, context, state)); } - case 'string': { - return expr.value.substring(1, expr.value.length - 1); - } - case 'number': { - return Number(expr.value); + case 'Constant': { + return expr.value; } case 'Attribute': { const target = this.target(expr.value, context); @@ -2267,8 +2225,7 @@ pytorch.Execution = class extends python.Execution { break; } case 'Import': - case 'string': - case 'number': { + case 'Constant': { break; } case 'List': { @@ -2277,9 +2234,10 @@ pytorch.Execution = class extends python.Execution { } break; } - case 'dict': { - for (const item of value.value) { - this.variables(item, scope); + case 'Dict': { + for (let i = 0; i < value.keys.length; i++) { + this.variables(value.keys[i], scope); + this.variables(value.values[i], scope); } break; } @@ -2294,9 +2252,9 @@ pytorch.Execution = class extends python.Execution { this.variables(value.value, scope); break; } - case '[]': { - this.variables(value.target, scope); - this.variables(value.arguments, scope); + case 'Subscript': { + this.variables(value.value, scope); + this.variables(value.slice, scope); break; } case 'Call': { @@ -2328,7 +2286,7 @@ pytorch.Execution = class extends python.Execution { this.variables(value.value, scope); break; } - case 'while': { + case 'While': { this.variables(value.test, scope); this.variables(value.body, scope); break; @@ -2399,7 +2357,6 @@ pytorch.Execution = class extends python.Execution { } else if (test === false) { statements.splice(i, 1, ...condition.orelse.statements); } - const count = new Map(); for (const node of state) { if (count.has(node)) { @@ -2413,7 +2370,6 @@ pytorch.Execution = class extends python.Execution { node.destroy(); } } - if (test === true || test === false) { continue; } @@ -2550,7 +2506,7 @@ pytorch.Execution = class extends python.Execution { continue; } } - if (stmt.type === 'while') { + if (stmt instanceof ast.While) { const node = this._graph.create('prim::Loop'); node.setSourceRange(stmt.location); this.graph.insertNode(node); @@ -2608,30 +2564,31 @@ pytorch.Execution = class extends python.Execution { type(expr) { const ast = this.ast; const torch = this.torch; - if (expr.type === '[]' && expr.target instanceof ast.Name) { - switch (expr.target.id) { + if (expr instanceof ast.Subscript && expr.value instanceof ast.Name) { + const elts = expr.slice.elts; + switch (expr.value.id) { case 'List': { - const elementType = this.type(expr.arguments.elts[0]); - return torch.ListType.get(elementType); + const type = this.type(elts[0]); + return torch.ListType.get(type); } case 'Optional': { - const elementType = this.type(expr.arguments.elts[0]); - return torch.OptionalType.get(elementType); + const type = this.type(elts[0]); + return torch.OptionalType.get(type); } case 'Tuple': { - const elements = expr.arguments.elts.map((expr) => this.type(expr)); - return torch.TupleType.get(elements); + const types = elts.map((expr) => this.type(expr)); + return torch.TupleType.get(types); } case 'Dict': { - const key = this.type(expr.arguments.elts[0]); - const value = this.type(expr.arguments.elts[1]); + const key = this.type(elts[0]); + const value = this.type(elts[1]); return torch.DictType.get(key, value); } case 'Final': { - return this.type(expr.arguments.elts[0]); + return this.type(elts[0]); } default: { - throw new pytorch.Error(`Unsupported type element expression '${expr.target.id}'.`); + throw new pytorch.Error(`Unsupported type element expression '${expr.value.id}'.`); } } } @@ -2648,6 +2605,12 @@ pytorch.Execution = class extends python.Execution { default: throw new pytorch.Error(`Unsupported type expression '${expr.value}'.`); } } + if (expr instanceof ast.Constant) { + if (expr.value === null) { + return torch.NoneType.get(); + } + throw new pytorch.Error(`Unsupported type expression '${expr.value}'.`); + } if (expr instanceof ast.Attribute) { const identifier = this.identifier(expr); const type = this._resolver.resolveType(identifier); @@ -3051,7 +3014,7 @@ pytorch.Execution = class extends python.Execution { case 'SymInt[1]': return this.isType(obj, torch.IntType.get()) || this.isType(obj, torch.ListType.get(torch.IntType.get())); case 'float': { - return obj !== null && (typeof obj === 'number' || obj instanceof Number) || (obj instanceof torch.Value && obj.type() instanceof torch.FloatType); + return obj !== null && (typeof obj === 'number' || obj instanceof Number) || (obj instanceof torch.Value && (obj.type() instanceof torch.FloatType || obj.type() instanceof torch.IntType)); } case 'float[]': { if (Array.isArray(obj) && obj.every((item) => (typeof item === 'number' || item instanceof Number) && !isNaN(item))) { @@ -3447,21 +3410,22 @@ pytorch.Utility = class { switch (type.kind()) { case 'OptionalType': return `${pytorch.Utility.toType(type.getElementType())}?`; case 'ListType': return `${pytorch.Utility.toType(type.getElementType())}[]`; - case 'BoolType': return `boolean`; - case 'IntType': return `int64`; - case 'FloatType': return `float32`; - case 'StringType': return `string`; - case 'ComplexType': return `complex`; - case 'NumberType': return `scalar`; - case 'TensorType': return `tensor`; + case 'BoolType': return 'boolean'; + case 'IntType': return 'int64'; + case 'FloatType': return 'float32'; + case 'StringType': return 'string'; + case 'ComplexType': return 'complex'; + case 'NumberType': return 'scalar'; + case 'TensorType': return 'tensor'; case 'TupleType': return `tuple<${type.elements().map((type) => pytorch.Utility.toType(type)).join(', ')}>`; case 'DictType': return `map<${pytorch.Utility.toType(type.getKeyType())}, ${pytorch.Utility.toType(type.getValueType())}>`; - case 'DeviceObjType': return `device`; - case 'SymIntType': return `SymInt`; - case 'ScalarTypeType': return `ScalarType`; - case 'MemoryFormat': return `MemoryFormat`; - case 'Layout': return `Layout`; + case 'DeviceObjType': return 'device'; + case 'SymIntType': return 'SymInt'; + case 'ScalarTypeType': return 'ScalarType'; + case 'MemoryFormat': return 'MemoryFormat'; + case 'Layout': return 'Layout'; case 'VarType': return type.annotation_str; + case 'NoneType': return 'None'; default: throw new pytorch.Error(`Unsupported type '${type.kind()}'.`); } }