From c2d1aa2461bb3e24c75107069b1adae7776f26ad Mon Sep 17 00:00:00 2001 From: Lutz Roeder Date: Wed, 30 Oct 2024 20:02:19 -0700 Subject: [PATCH] Update pytorch.js (#1061) --- source/python.js | 44 +++++-- source/pytorch-metadata.json | 9 ++ source/pytorch.js | 236 ++++++++++++++++++++++++++++++----- 3 files changed, 245 insertions(+), 44 deletions(-) diff --git a/source/python.js b/source/python.js index 67cc04bf80..959cf45b83 100644 --- a/source/python.js +++ b/source/python.js @@ -6276,6 +6276,9 @@ python.Execution = class { getElementType() { return this._elem; } + equals(rhs) { + return this.kind() === rhs.kind() && this.getElementType().equals(rhs.getElementType()); + } str() { return `${this.getElementType().str()}[]`; } @@ -6393,6 +6396,9 @@ python.Execution = class { torch.TensorType.value = torch.TensorType.value || new torch.TensorType(); return torch.TensorType.value; } + equals(rhs) { + return this.kind() === rhs.kind(); + } str() { return 'Tensor'; } @@ -7317,7 +7323,7 @@ python.Execution = class { }); this.registerType('torch.Graph', class { constructor() { - this._unique = 1; + this._next_unique = 1; this._all_nodes = []; this._all_values = []; this._all_blocks = []; @@ -7357,9 +7363,12 @@ python.Execution = class { } this._insert_before = node; } - get all_nodes() { + all_nodes() { return this._all_nodes; } + all_blocks() { + return this._all_blocks; + } freeNode(n) { const index = this._all_nodes.indexOf(n); if (index !== -1) { @@ -7382,7 +7391,6 @@ python.Execution = class { }); this.registerType('torch.Block', class { constructor(graph) { - this._unique = 1; this._graph = graph; this._input = graph.create('prim::Param'); this._output = graph.create('prim::Return'); @@ -7390,6 +7398,11 @@ python.Execution = class { this._input.prev = this._output; this._output.next = this._input; this._output.prev = this._input; + this._graph.all_blocks().push(this); + this._output._owning_block = this; + // output_->topo_position_ = kUpperBound; + this._input._owning_block = this; + // input_->topo_position_ = kLowerBound; } param_node() { return this._input; @@ -7439,13 +7452,17 @@ python.Execution = class { this._inputs = []; this._outputs = []; this._blocks = []; - this._graph.all_nodes.push(this); + this._graph.all_nodes().push(this); this._prev = null; this._next = null; + this._source_range = null; } owningGraph() { return this._graph; } + owningBlock() { + return this._owning_block; + } kind() { return this._kind; } @@ -7508,7 +7525,7 @@ python.Execution = class { return this; } insertAfter(n) { - // this.owning_block_ = n->owningBlock(); + this._owning_block = n.owningBlock(); const next = n.next; n.next = this; this.prev = n; @@ -7607,11 +7624,17 @@ python.Execution = class { kindOf(name) { return this._values.get(name)[1]; } + setSourceRange(r) { + this._source_range = r; + } + sourceRange() { + return this._source_range; + } }); this.registerType('torch.Value', class { constructor(node) { - this._unique = node && node._unique ? node._unique++ : node._graph._unique++; - this._node = node && node._unique ? null : node; + this._unique = node && node._next_unique ? node._next_unique++ : node._graph._next_unique++; // remove always node + this._node = node && node._next_unique ? null : node; this._uses = []; } unique() { @@ -10714,7 +10737,7 @@ python.Execution = class { } case 'call': { if (expression.target.type === '.') { - return this.call(expression.target.target, expression.target.member.value, expression.args, context); + return this.call(expression.target.target, expression.target.member.value, expression.args, context, expression.location); } return this.call(expression.target, null, expression.args, context); } @@ -10789,7 +10812,8 @@ python.Execution = class { return undefined; } - target(expression, context) { + target(expression, context, resolve) { + resolve = resolve === false ? false : true; let current = expression; let path = []; for (;;) { @@ -10812,7 +10836,7 @@ python.Execution = class { break; } } - if (!target) { + if (!target && resolve) { path.reverse(); const name = path.join('.'); const file = `${path.join('/')}.py`; diff --git a/source/pytorch-metadata.json b/source/pytorch-metadata.json index 840e3fbb32..5b657bcea8 100755 --- a/source/pytorch-metadata.json +++ b/source/pytorch-metadata.json @@ -220,6 +220,12 @@ { "name": "aten::__interpolate.size_list_scale_list(Tensor input, int[]? size=None, float[]? scale_factor=None, str mode=\"nearest\", bool? align_corners=None, bool? recompute_scale_factor=None, bool antialias=False) -> Tensor" }, + { + "name": "aten::__ior__.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)" + }, + { + "name": "aten::__ior__.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)" + }, { "name": "aten::__irshift__.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)" }, @@ -532,6 +538,9 @@ { "name": "aten::_unsafe_view.out(Tensor self, SymInt[] size, *, Tensor(a!) out) -> Tensor(a!)" }, + { + "name": "aten::_unwrap_optional(t(a)? optional) -> t(a)" + }, { "name": "aten::_upsample_bilinear2d_aa(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor" }, diff --git a/source/pytorch.js b/source/pytorch.js index 459b6560a3..e6b4fbc107 100644 --- a/source/pytorch.js +++ b/source/pytorch.js @@ -362,6 +362,10 @@ pytorch.Node = class { let module = null; if (pytorch.Utility.isInstance(obj, 'torch.Node')) { const node = obj; + const sourceRange = node.sourceRange(); + if (sourceRange) { + this.metadata.push(new pytorch.Argument('source', sourceRange.replace(/^at\s/, '').replace(/\.$/, ''))); + } const kind = node.kind(); this.type = { identifier: kind, @@ -1732,7 +1736,8 @@ pytorch.Execution = class extends python.Execution { return type; } - target(expression, context) { + target(expression, context, resolve) { + resolve = resolve === false ? false : true; if (expression.type === 'id') { switch (expression.value) { case 'torch': @@ -1779,7 +1784,7 @@ pytorch.Execution = class extends python.Execution { break; } } - if (!target) { + if (!target && resolve) { path.reverse(); const name = path.join('.'); const file = `${path.join('/')}.py`; @@ -1789,7 +1794,7 @@ pytorch.Execution = class extends python.Execution { return this.resolve(name); } } - return super.target(expression, context); + return super.target(expression, context, resolve); } expression(expression, context) { @@ -2062,8 +2067,10 @@ pytorch.Execution = class extends python.Execution { types.push(value.type()); elements.push(item); } else if (Number.isInteger(item)) { - const value = new torch.Value(node); - value.value = item; + const value = this.constant(item); + node.addInput(value); + /// const value = new torch.Value(node); + // value.value = item; types.push(torch.IntType.get()); elements.push(item); } else if (typeof item === 'boolean') { @@ -2168,7 +2175,7 @@ pytorch.Execution = class extends python.Execution { break; } case 'list': { - return expression.value.map((expression) => this.static(expression, context)); + return expression.value.map((expression) => this.static(expression, context, state)); } case 'string': { return expression.value.substring(1, expression.value.length - 1); @@ -2184,11 +2191,18 @@ pytorch.Execution = class extends python.Execution { break; } case 'call': { + if (expression.target.type === 'id' && expression.target.value === 'annotate') { + return this.static(expression.args[1], context, state); + } const args = expression.args.map((expression) => this.static(expression, context, state)); if (args.every((arg) => arg !== undefined)) { - const target = this.target(expression.target, context); + const target = this.target(expression.target, context, false); if (typeof target === 'function') { - return target(...args); + if (target && target.__class__ === this._builtins.type) { + // debugger; + } else { + return target(...args); + } } } state.splice(0, state.length); @@ -2201,6 +2215,109 @@ pytorch.Execution = class extends python.Execution { return undefined; } + variables(value, scope) { + if (!scope.refs) { + scope.refs = new Set(); + } + switch (value.type) { + case '=': { + this.variables(value.target, scope); + this.variables(value.expression, scope); + break; + } + case '.': { + this.variables(value.target, scope); + this.variables(value.member, scope); + break; + } + case 'id': { + scope.refs.add(value.value); + break; + } + case 'import': + case 'string': + case 'number': { + break; + } + case 'list': { + for (const item of value.value) { + this.variables(item, scope.refs); + } + break; + } + case 'dict': { + for (const item of value.value) { + this.variables(item, scope.refs); + } + break; + } + case 'tuple': { + for (const item of value.value) { + this.variables(item, scope.refs); + } + break; + } + case 'pair': { + this.variables(value.key, scope.refs); + this.variables(value.value, scope.refs); + break; + } + case '[]': { + this.variables(value.target, scope.refs); + this.variables(value.arguments, scope.refs); + break; + } + case 'call': { + this.variables(value.target, scope.refs); + for (const arg of value.args) { + this.variables(arg, scope.refs); + } + break; + } + case 'if': { + this.variables(value.test, scope.refs); + this.variables(value.body, scope.refs); + this.variables(value.orelse, scope.refs); + break; + } + case 'for': { + for (const target of value.target) { + this.variables(target, scope.refs); + } + for (const iter of value.iter) { + this.variables(iter, scope.refs); + } + this.variables(value.body, scope.refs); + break; + } + case 'block': { + for (const statement of value.statements) { + this.variables(statement, scope.refs); + } + break; + } + case 'return': { + this.variables(value.expression, scope.refs); + break; + } + case 'while': { + this.variables(value.test, scope.refs); + this.variables(value.body, scope.refs); + break; + } + case 'var': { + this.variables(value.initializer, scope.refs); + break; + } + case 'pass': { + break; + } + default: { + throw new pytorch.Error(`Unsupported type '${value.type}'.`); + } + } + } + block(statements, context) { this.traceIf = false; if (!this.traceIf) { @@ -2209,30 +2326,14 @@ pytorch.Execution = class extends python.Execution { statements = Array.prototype.slice.call(statements); while (statements.length > 0) { if (statements.length > 1) { - const containsVariableReference = (queue, value) => { - while (queue.length > 0) { - const obj = queue.shift(); - if (obj && obj.type === 'id' && obj.value === value) { - return true; - } else if (Array.isArray(obj)) { - for (const item of obj) { - if (Array.isArray(item) || (Object(item) === item && item.type)) { - queue.push(item); - } + const containsVariableReference = (statements, value) => { + if (statements) { + for (const statement of statements) { + if (!statement.refs) { + this.variables(statement, statement); } - } else if (Object(obj) === obj) { - for (const [key, value] of Object.entries(obj)) { - if (key !== 'identifier') { - if (Array.isArray(value)) { - for (const item of value) { - if (Array.isArray(item) || (Object(item) === item && item.type)) { - queue.push(item); - } - } - } else if (Object(value) === value && value.type) { - queue.push(value); - } - } + if (statement.refs.has(value)) { + return true; } } } @@ -2245,13 +2346,16 @@ pytorch.Execution = class extends python.Execution { if (assign.type === '=' && condition.type === 'if' && assign.target.type === 'id' && condition.test.type === 'id' && assign.target.value === condition.test.value && - !containsVariableReference(statements.slice(2), condition.test.value)) { + !containsVariableReference(statements.slice(2), condition.test.value) && + (!statements[1].body || !containsVariableReference(statements[1].body.statements), condition.test.value) && + (!statements[1].orelse || !containsVariableReference(statements[1].orelse.statements, condition.test.value))) { statements.shift(); statements[0] = { type: 'if', test: assign.expression, body: condition.body, - orelse: condition.orelse + orelse: condition.orelse, + location: condition.location, }; } } @@ -2326,17 +2430,79 @@ pytorch.Execution = class extends python.Execution { if (this.traceIf) { const test = this.expression(statement.test, context); if (test instanceof torch.Value && test.type() instanceof torch.BoolType) { + const __variables = (statements) => { + const set = new Set(); + for (const statement of statements) { + if (statement.type === '=') { + if (statement.target.type === 'id') { + set.add(statement.target.value); + } else if (statement.target.type === 'tuple') { + for (const value of statement.target.value) { + if (value.type === 'id') { + set.add(value.value); + } else { + // debugger; + } + } + } else { + // debugger; + } + } + } + return set; + }; + const __type = (value) => { + if (!value) { + return null; + } + if (pytorch.Utility.isTensor(value)) { + return torch.TensorType.get(); + } + return value.type(); + }; + this.variables(statement, statement); const node = this._graph.create('prim::If'); + node.setSourceRange(statement.location); this.graph.insertNode(node); node.addInput(test); const prev = this._graph.insertPoint(); const true_block = node.addBlock(); this._graph.setInsertPoint(true_block); + let vars = __variables(statement.body.statements.concat(statement.orelse.statements)); + vars = new Map(Array.from(vars).map((name) => [name, {}])); this.block(statement.body.statements, context); + for (const [name, entry] of vars) { + entry.body = context.get(name); + } const false_block = node.addBlock(); this._graph.setInsertPoint(false_block); this.block(statement.orelse.statements, context); + for (const [name, entry] of vars) { + entry.orelse = context.get(name); + } this._graph.setInsertPoint(prev); + for (const [name, entry] of vars) { + const value = node.addOutput(); + context.set(name, value); + let type = null; + if (entry.body && !entry.orelse) { + type = __type(entry.body); + } else if (entry.orelse && !entry.body) { + type = __type(entry.orelse); + } else { + // compare + const type1 = __type(entry.body); + const type2 = __type(entry.orelse); + if (type1 === null && type2 === null) { + type = null; + } else if (type1.equals(type2)) { + type = type1; + } else { + // debugger; + } + } + value.setType(type); + } return undefined; } } else { @@ -2358,6 +2524,7 @@ pytorch.Execution = class extends python.Execution { } else if (test instanceof torch.Value && test.type() instanceof torch.BoolType) { const node = this._graph.create('prim::If'); this.graph.insertNode(node); + node.setSourceRange(statement.location); node.addInput(test); const prev = this._graph.insertPoint(); const true_block = node.addBlock(); @@ -2430,7 +2597,7 @@ pytorch.Execution = class extends python.Execution { throw new pytorch.Error(`Unsupported type expression '${expression.type}'.`); } - call(target, name, args, context) { + call(target, name, args, context, location) { if (!this.trace) { return super.call(target, name, args, context); } @@ -2484,6 +2651,7 @@ pytorch.Execution = class extends python.Execution { const [schema, evalArgs] = overload; const op = schema.overload_name ? `${schema.name}.${schema.overload_name}` : schema.name; const node = this._graph.create(op); + node.setSourceRange(location); this.graph.insertNode(node); const referencedParameters = []; const parameters = schema.arguments;