Skip to content

Commit

Permalink
Update pytorch.js (#1061)
Browse files Browse the repository at this point in the history
  • Loading branch information
lutzroeder committed Dec 9, 2024
1 parent c8c9eb8 commit 84020e0
Show file tree
Hide file tree
Showing 2 changed files with 150 additions and 15 deletions.
115 changes: 104 additions & 11 deletions source/python.js
Original file line number Diff line number Diff line change
Expand Up @@ -6800,6 +6800,7 @@ python.Execution = class {
this._attributes = new Map();
this._methods = new Map();
this._staticmethods = new Map();
this._constants = new Map();
}
qualified_name() {
return this.annotation_str;
Expand All @@ -6822,7 +6823,7 @@ python.Execution = class {
findStaticMethod(name) {
return this._staticmethods.get(name);
}
addAttribute(name, type) {
addAttribute(name, type /*, is_parameter, is_buffer */) {
this._attributes.set(name, type);
}
findAttribute(name) {
Expand All @@ -6835,6 +6836,9 @@ python.Execution = class {
}
methods() {
}
addConstant(name, value) {
this._constants.set(name, value);
}
str() {
return this.qualified_name();
}
Expand Down Expand Up @@ -8601,20 +8605,105 @@ python.Execution = class {
throw new python.Error('TorchScript does not support class inheritance.');
}
}
importClass(qualified_name, class_def, is_module) {
if (qualified_name.prefix().startsWith('__torch__.torch.classes')) {
importClass(qualified_classname, class_def, is_module) {
if (qualified_classname.prefix().startsWith('__torch__.torch.classes')) {
return;
}
const class_type = new torch.ClassType(qualified_name.qualifiedName(), this._cu, is_module);
for (const entry of class_def.body) {
if (entry instanceof ast.AnnAssign) {
const target = this._cu.execution.identifier(entry.target);
const annotation = this._cu.execution.type(entry.annotation, null);
class_type.addAttribute(target, annotation);
const parameter_names = new Set();
const buffer_names = new Set();
const methods = [];
const attributes = [];
const constants = [];
const pre_hook_names = new Set();
const pre_hook_def_map = new Map();
const hook_names = new Set();
const hook_def_map = new Map();
const class_type = new torch.ClassType(qualified_classname.qualifiedName(), this._cu, is_module);
for (const stmt of class_def.body) {
if (stmt instanceof ast.Assign || stmt instanceof ast.AnnAssign) {
let target = null;
let annotation = null;
let value = null;
if (stmt instanceof ast.Assign) {
target = stmt.targets;
value = stmt.value;
} else {
target = stmt.target;
annotation = stmt.annotation;
value = stmt.value;
}
if (target instanceof ast.Name) {
const name = this._cu.execution.identifier(target);
switch (name) {
case '__annotations__': {
continue;
}
case '__parameters__': {
for (const elt of value.elts) {
parameter_names.add(elt.value);
}
break;
}
case '__buffers__': {
for (const elt of value.elts) {
buffer_names.add(elt.value);
}
break;
}
case '__forward_pre_hooks__': {
for (const elt of value.elts) {
pre_hook_names.add(elt.value);
}
break;
}
case '__forward_hooks__': {
for (const elt of value.elts) {
hook_names.add(elt.value);
}
break;
}
default: {
if (value) {
constants.push({ name, value, annotation });
} else {
attributes.push({ name, value, annotation });
}
break;
}
}
} else if (target instanceof ast.Subscript) {
// not implemented
continue;
} else {
throw new python.Error('Unexpected statement kind in module metadata.');
}
} else if (stmt instanceof ast.FunctionDef) {
const def = stmt;
const def_name = def.name;
if (pre_hook_names.has(def_name)) {
pre_hook_def_map.set(def_name, def);
} else if (hook_names.has(def_name)) {
hook_def_map.set(def_name, def);
} else {
methods.push(def);
}
} else {
throw new python.Error('Unexpected statement kind in class body.');
}
}
for (const assign of attributes) {
const name = assign.name;
const annotation = this._cu.execution.type(assign.annotation, null);
const is_parameter = parameter_names.has(name);
const is_buffer = buffer_names.has(name);
class_type.addAttribute(name, annotation, is_parameter, is_buffer);
}
for (const constant of constants) {
class_type.addConstant(constant.name, constant.value);
}
// debugger;
this._cu.register_type(class_type);
this._cu.define(qualified_classname, null, null, methods);
}
importNamedTuple(qualified_name, named_tuple_def) {
const field_names = [];
Expand Down Expand Up @@ -9111,10 +9200,14 @@ python.Execution = class {
if (!this.data.forward) {
throw new python.Error("Module 'forward' not implemented.");
}
const args = [this.data]; // self
this.traceAttr = false;
const args = [];
if (!this.traceAttr) {
args.push(this.data); // self
}
if (this.data.forward.__code__ && this.data.forward.__code__.args) {
for (const arg of this.data.forward.__code__.args) {
if (arg.name !== 'self') {
if (this.traceAttr || arg.name !== 'self') {
const value = execution.graph.addInput(arg.name);
value.setType(execution.type(arg.parameterType));
if (isTensor(value)) {
Expand Down
50 changes: 46 additions & 4 deletions source/pytorch.js
Original file line number Diff line number Diff line change
Expand Up @@ -1929,6 +1929,15 @@ pytorch.Execution = class extends python.Execution {
const target = this.target(func.value, context);
return this._graph.insertToList(target, typehint);
}
if (this.traceAttr) {
if (func instanceof ast.Name && func.id === 'getattr') {
const obj = this.expression(expr.args[0], context);
const field = this.expression(expr.args[1], context);
const n = this._graph.createGetAttr(obj, field);
this._graph.insertNode(n);
return n.output();
}
}
return super.expression(expr, context);
}
case 'Subscript': {
Expand Down Expand Up @@ -2258,6 +2267,20 @@ pytorch.Execution = class extends python.Execution {
}
}

emitSugaredExpr(tree, n_binders, type_hint) {
const ast = this.ast;
if (tree instanceof ast.Var) {
//
} else if (tree instanceof ast.Attribute) {
//
} else if (tree instanceof ast.Apply) {
//
} if (tree instanceof ast.Subscript) {
//
}
return this.emitSimpleExpr(tree, type_hint);
}

block(statements, context) {
const ast = this.ast;
const torch = this.torch;
Expand Down Expand Up @@ -2432,13 +2455,26 @@ pytorch.Execution = class extends python.Execution {
}
if (stmt instanceof ast.For) {
const range = stmt.location;
const node = this.create('prim::Loop', range, 0);
this._graph.insertNode(node);
const n = this._graph.insertNode(this.create('prim::Loop', range, 0));
const itrs = stmt.iter instanceof ast.Tuple ? stmt.iter.elts : [stmt.iter];
// const targets = stmt.target instanceof ast.Tuple ? stmt.target.elts : [stmt.target];
if (itrs.length !== 1) {
throw new pytorch.Error('List of iterables is not supported currently.');
}
/*
// const sv = this.expression(itrs[0], context);
const sv = this.emitSugaredExpr(itrs[0], 1);
const iterable = sv.iter(range, method);
if (iterable.shouldEmitUnrolled()) {
this.emitUnrolledLoop(loc, emit_body, iterable, targets);
} else {
this.emitLoopCommon(loc, emit_body, iterable, targets, {});
}
*/

/* const body_block = */ n.addBlock();
/* const condition_block = */ n.addBlock();

const loop = stmt;
if (loop.target instanceof ast.Name && loop.iter instanceof ast.Tuple === false) {
const range = this.expression(loop.iter, context);
Expand Down Expand Up @@ -2480,13 +2516,18 @@ pytorch.Execution = class extends python.Execution {
}

statement(stmt, context) {
const ast = this.ast;
const torch = this.torch;
if (stmt.__class__.__name__ === 'ClassDef') {
const name = `${context.get('__name__')}.${stmt.name}`;
this._resolver.resolveType(name);
}

if (!this.trace) {
return super.statement(stmt, context);
}

switch (stmt.__class__.__name__) {
case 'ClassDef': {
/*
super.statement(stmt, context);
const value = context.get(stmt.name);
const type = new torch.ClassType(`${value.__module__}.${value.__name__}`);
Expand All @@ -2498,6 +2539,7 @@ pytorch.Execution = class extends python.Execution {
}
}
value.__type__ = type;
*/
return undefined;
}
case 'If': {
Expand Down

0 comments on commit 84020e0

Please sign in to comment.