Skip to content

Commit

Permalink
Update pytorch.js (#1061)
Browse files Browse the repository at this point in the history
  • Loading branch information
lutzroeder committed Dec 8, 2024
1 parent ec42697 commit ff05bb3
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 44 deletions.
1 change: 1 addition & 0 deletions source/python.js
Original file line number Diff line number Diff line change
Expand Up @@ -8432,6 +8432,7 @@ python.Execution = class {
}
setSourceRange(r) {
this._source_range = r;
return this;
}
sourceRange() {
return this._source_range;
Expand Down
94 changes: 50 additions & 44 deletions source/pytorch.js
Original file line number Diff line number Diff line change
Expand Up @@ -1686,7 +1686,7 @@ pytorch.Execution = class extends python.Execution {
} else if (obj instanceof torch.Value) {
value = obj;
} else {
value = new torch.Value(node ? node : this._graph);
value = new torch.Value(node ? node : this.graph);
}
if (pytorch.Utility.isTensor(obj)) {
value.value = obj;
Expand Down Expand Up @@ -1781,6 +1781,10 @@ pytorch.Execution = class extends python.Execution {
return super.target(expr, context);
}

create(kind, loc, n_outputs) {
return this._graph.create(kind, n_outputs).setSourceRange(loc);
}

expression(expr, context, typehint) {
if (!this.trace) {
return super.expression(expr, context);
Expand Down Expand Up @@ -1818,13 +1822,13 @@ pytorch.Execution = class extends python.Execution {
if (value.type() instanceof torch.TupleType) {
const node = this._graph.createTupleUnpack(value);
node.setSourceRange(expr.location);
this.graph.insertNode(node);
this._graph.insertNode(node);
outputs = node.outputs();
} else if (value.type() instanceof torch.ListType) {
const size = target.elts.length;
const node = this._graph.createListUnpack(value, size);
node.setSourceRange(expr.location);
this.graph.insertNode(node);
this._graph.insertNode(node);
outputs = node.outputs();
}
if (outputs === null) {
Expand Down Expand Up @@ -1897,7 +1901,7 @@ pytorch.Execution = class extends python.Execution {
const type = this.type(expr.args[0]);
const node = this._graph.createUninitialized(type);
node.setSourceRange(expr.location);
this.graph.insertNode(node);
this._graph.insertNode(node);
return node.output();
}
if (func instanceof ast.Name && func.id === 'unchecked_cast') {
Expand All @@ -1906,7 +1910,7 @@ pytorch.Execution = class extends python.Execution {
value = this.variable(value);
}
const type = this.type(expr.args[0]);
return this.graph.insertUncheckedCast(value, type);
return this._graph.insertUncheckedCast(value, type);
}
if (func instanceof ast.Name && func.id === 'isinstance') {
const value = this.expression(expr.args[0], context);
Expand All @@ -1918,12 +1922,12 @@ pytorch.Execution = class extends python.Execution {
}
const v = this.variable(value); // remove
const node = this._graph.createIsInstance(v, types);
this.graph.insertNode(node);
this._graph.insertNode(node);
return node.output();
}
if (func.attr === 'tolist' && expr.args.length === 0) {
const target = this.target(func.value, context);
return this.graph.insertToList(target, typehint);
return this._graph.insertToList(target, typehint);
}
return super.expression(expr, context);
}
Expand All @@ -1942,14 +1946,14 @@ pytorch.Execution = class extends python.Execution {
index = this._graph.insertConstant(index);
}
const node = this._graph.create('aten::__getitem__.t', [value, index]);
this.graph.insertNode(node);
this._graph.insertNode(node);
node.output().setType(type.getElementType());
return node.output();
}
if (type instanceof torch.DictType) {
let key = this.expression(elt, context);
const node = this._graph.create('aten::__getitem__.t', [value]);
this.graph.insertNode(node);
this._graph.insertNode(node);
if (type.getKeyType() instanceof torch.StringType && typeof key === 'string') {
const value = new torch.Value(node);
value.value = key;
Expand All @@ -1971,7 +1975,7 @@ pytorch.Execution = class extends python.Execution {
const output_type = type.elements()[index];
index = this._graph.insertConstant(index);
const node = this._graph.createTupleIndex(value, index, output_type);
this.graph.insertNode(node);
this._graph.insertNode(node);
return node.output();
}
}
Expand All @@ -1983,7 +1987,7 @@ pytorch.Execution = class extends python.Execution {
const attr = expr.attr;
if (target instanceof torch.Value && target.type() instanceof torch.ClassType) {
const node = this._graph.createGetAttr(target, attr);
this.graph.insertNode(node);
this._graph.insertNode(node);
return node.output();
}
return target[attr];
Expand Down Expand Up @@ -2012,7 +2016,7 @@ pytorch.Execution = class extends python.Execution {
}
const contained_type = typehint ? typehint.getElementType() : item_type;
const node = this._graph.createList(contained_type, values);
this.graph.insertNode(node);
this._graph.insertNode(node);
return node.output();
}
break;
Expand All @@ -2033,7 +2037,7 @@ pytorch.Execution = class extends python.Execution {
}
const node = this._graph.createTuple(values);
node.setSourceRange(expr.location);
this.graph.insertNode(node);
this._graph.insertNode(node);
return node.output();
}
case 'Dict': {
Expand All @@ -2058,7 +2062,7 @@ pytorch.Execution = class extends python.Execution {
const key_type = typehint ? typehint.getKeyType() : keyType;
const value_type = typehint ? typehint.getValueType() : valueType;
const node = this._graph.createDict(key_type, value_type, keys, values);
this.graph.insertNode(node);
this._graph.insertNode(node);
return node.output();
}
default: {
Expand Down Expand Up @@ -2367,9 +2371,8 @@ pytorch.Execution = class extends python.Execution {
return value.type();
};
this.variables(condition, condition);
const node = this._graph.create('prim::If', 0);
node.setSourceRange(stmt.location);
this.graph.insertNode(node);
const node = this.create('prim::If', stmt.location, 0);
this._graph.insertNode(node);
node.addInput(test);
const prev = this._graph.insertPoint();
const true_block = node.addBlock();
Expand Down Expand Up @@ -2428,9 +2431,14 @@ pytorch.Execution = class extends python.Execution {
throw new pytorch.Error("Unsupported condition.");
}
if (stmt instanceof ast.For) {
const node = this._graph.create('prim::Loop', 0);
node.setSourceRange(stmt.location);
this.graph.insertNode(node);
const range = stmt.location;
const node = this.create('prim::Loop', range, 0);
this._graph.insertNode(node);
const itrs = stmt.iter instanceof ast.Tuple ? stmt.iter.elts : [stmt.iter];
// const targets = stmt.target instanceof ast.Tuple ? stmt.target.elts : [stmt.target];
if (itrs.length !== 1) {
throw new pytorch.Error('List of iterables is not supported currently.');
}
const loop = stmt;
if (loop.target instanceof ast.Name && loop.iter instanceof ast.Tuple === false) {
const range = this.expression(loop.iter, context);
Expand All @@ -2449,9 +2457,8 @@ pytorch.Execution = class extends python.Execution {
}
}
if (stmt instanceof ast.While) {
const node = this._graph.create('prim::Loop', 0);
node.setSourceRange(stmt.location);
this.graph.insertNode(node);
const node = this._graph.create('prim::Loop', stmt.location, 0);
this._graph.insertNode(node);
const test = this.expression(stmt.test, context);
if (test) {
const value = this.block(stmt.body, context);
Expand Down Expand Up @@ -2577,9 +2584,9 @@ pytorch.Execution = class extends python.Execution {
if (identifier) {
const type = this._resolver.resolveType(identifier);
if (type) {
const node = this.graph.createObject(type);
const node = this._graph.createObject(type);
node.setSourceRange(location);
this.graph.insertNode(node);
this._graph.insertNode(node);
return node.output();
}
}
Expand All @@ -2590,9 +2597,9 @@ pytorch.Execution = class extends python.Execution {
if (args.length === 0) {
return obj;
}
const node = this.graph.create('prim::CallMethod', 0);
const node = this._graph.create('prim::CallMethod', 0);
node.setSourceRange(location);
this.graph.insertNode(node);
this._graph.insertNode(node);
node.s_('name', name);
node.addInput(obj);
const evalArgs = args.map((arg) => this.expression(arg, context));
Expand All @@ -2608,8 +2615,8 @@ pytorch.Execution = class extends python.Execution {
if (!overload) {
const moduleTarget = this.target(target, context);
if (moduleTarget instanceof torch.Value && moduleTarget.type() instanceof torch.ClassType) {
const node = this.graph.create('prim::CallMethod');
this.graph.insertNode(node);
const node = this._graph.create('prim::CallMethod');
this._graph.insertNode(node);
node.s_('name', name);
const evalArgs = args.map((expression) => this.expression(expression, context));
for (const arg of evalArgs) {
Expand All @@ -2627,12 +2634,12 @@ pytorch.Execution = class extends python.Execution {
const values = evalArgs.map((arg) => this.variable(arg));
const node = this._graph.createTuple(values, type);
node.setSourceRange(location);
this.graph.insertNode(node);
this._graph.insertNode(node);
return node.output();
}
if (type instanceof torch.ClassType) {
const node = this.graph.create('prim::CallMethod');
this.graph.insertNode(node);
const node = this._graph.create('prim::CallMethod');
this._graph.insertNode(node);
node.s_('name', name);
const evalArgs = args.map((expression) => this.expression(expression, context));
for (const arg of evalArgs) {
Expand All @@ -2646,9 +2653,8 @@ pytorch.Execution = class extends python.Execution {
}
const [schema, evalArgs] = overload;
const op = schema.overload_name ? `${schema.name}.${schema.overload_name}` : schema.name;
const node = this._graph.create(op, 0);
node.setSourceRange(location);
this.graph.insertNode(node);
const node = this.create(op, location, 0);
this._graph.insertNode(node);
const referencedParameters = [];
const parameters = schema.arguments;
const varTypes = new Map();
Expand All @@ -2666,7 +2672,7 @@ pytorch.Execution = class extends python.Execution {
let index = 0;
while (position < evalArgs.length) {
if (index >= parameters.length) {
if (schema.name.startsWith('_caffe2::') || schema.is_vararg) {
if (schema.is_vararg) {
break;
}
throw new pytorch.Error('Invalid parameter length.');
Expand Down Expand Up @@ -2694,25 +2700,25 @@ pytorch.Execution = class extends python.Execution {
} else if (type instanceof torch.ListType && type.getElementType() instanceof torch.TensorType) {
const v = evalArgs[position];
if ((v instanceof torch.Value && v.type() instanceof torch.ListType && v.type().getElementType() instanceof torch.TensorType) ||
(Array.isArray(v) && v.every((item) => pytorch.Utility.isTensor(item) || item === null || (item instanceof torch.Value && item.type() instanceof torch.TensorType)))) {
(v === null || Array.isArray(v) && v.every((item) => pytorch.Utility.isTensor(item) || item === null || (item instanceof torch.Value && item.type() instanceof torch.TensorType)))) {
position++;
if (v instanceof torch.Value) {
input = v;
match = true;
} else {
const list = this._graph.create('prim::ListConstruct');
this.graph.insertNode(node);
for (const arg of v) {
const values = [];
for (const arg of v || []) {
const tensor = arg;
if (tensor) {
tensor.__count__ = (tensor.__count__ || 0) + 1;
}
const value = this.variable(tensor);
value.setType(torch.TensorType.get());
list.addInput(value);
values.push(value);
}
list.output().setType(torch.ListType.create(torch.TensorType.get()));
input = list.output();
const node = this._graph.createList(torch.TensorType.get(), values);
this._graph.insertNode(node);
input = node.output();
match = true;
}
} else {
Expand Down Expand Up @@ -3146,7 +3152,7 @@ pytorch.Execution = class extends python.Execution {
let index = 0;
while (position < evalArgs.length) {
if (index >= parameters.length) {
next = !schema.name.startsWith('_caffe2::') && !schema.is_vararg;
next = !schema.is_vararg;
break;
}
const arg = parameters[index];
Expand Down

0 comments on commit ff05bb3

Please sign in to comment.