Skip to content

Commit

Permalink
Update pytorch.js (#1061)
Browse files Browse the repository at this point in the history
  • Loading branch information
lutzroeder committed Oct 31, 2024
1 parent a766f69 commit c2d1aa2
Show file tree
Hide file tree
Showing 3 changed files with 245 additions and 44 deletions.
44 changes: 34 additions & 10 deletions source/python.js
Original file line number Diff line number Diff line change
Expand Up @@ -6276,6 +6276,9 @@ python.Execution = class {
getElementType() {
return this._elem;
}
equals(rhs) {
return this.kind() === rhs.kind() && this.getElementType().equals(rhs.getElementType());
}
str() {
return `${this.getElementType().str()}[]`;
}
Expand Down Expand Up @@ -6393,6 +6396,9 @@ python.Execution = class {
torch.TensorType.value = torch.TensorType.value || new torch.TensorType();
return torch.TensorType.value;
}
equals(rhs) {
return this.kind() === rhs.kind();
}
str() {
return 'Tensor';
}
Expand Down Expand Up @@ -7317,7 +7323,7 @@ python.Execution = class {
});
this.registerType('torch.Graph', class {
constructor() {
this._unique = 1;
this._next_unique = 1;
this._all_nodes = [];
this._all_values = [];
this._all_blocks = [];
Expand Down Expand Up @@ -7357,9 +7363,12 @@ python.Execution = class {
}
this._insert_before = node;
}
get all_nodes() {
all_nodes() {
return this._all_nodes;
}
all_blocks() {
return this._all_blocks;
}
freeNode(n) {
const index = this._all_nodes.indexOf(n);
if (index !== -1) {
Expand All @@ -7382,14 +7391,18 @@ python.Execution = class {
});
this.registerType('torch.Block', class {
constructor(graph) {
this._unique = 1;
this._graph = graph;
this._input = graph.create('prim::Param');
this._output = graph.create('prim::Return');
this._input.next = this._output;
this._input.prev = this._output;
this._output.next = this._input;
this._output.prev = this._input;
this._graph.all_blocks().push(this);
this._output._owning_block = this;
// output_->topo_position_ = kUpperBound;
this._input._owning_block = this;
// input_->topo_position_ = kLowerBound;
}
param_node() {
return this._input;
Expand Down Expand Up @@ -7439,13 +7452,17 @@ python.Execution = class {
this._inputs = [];
this._outputs = [];
this._blocks = [];
this._graph.all_nodes.push(this);
this._graph.all_nodes().push(this);
this._prev = null;
this._next = null;
this._source_range = null;
}
owningGraph() {
return this._graph;
}
owningBlock() {
return this._owning_block;
}
kind() {
return this._kind;
}
Expand Down Expand Up @@ -7508,7 +7525,7 @@ python.Execution = class {
return this;
}
insertAfter(n) {
// this.owning_block_ = n->owningBlock();
this._owning_block = n.owningBlock();
const next = n.next;
n.next = this;
this.prev = n;
Expand Down Expand Up @@ -7607,11 +7624,17 @@ python.Execution = class {
kindOf(name) {
return this._values.get(name)[1];
}
setSourceRange(r) {
this._source_range = r;
}
sourceRange() {
return this._source_range;
}
});
this.registerType('torch.Value', class {
constructor(node) {
this._unique = node && node._unique ? node._unique++ : node._graph._unique++;
this._node = node && node._unique ? null : node;
this._unique = node && node._next_unique ? node._next_unique++ : node._graph._next_unique++; // remove always node
this._node = node && node._next_unique ? null : node;
this._uses = [];
}
unique() {
Expand Down Expand Up @@ -10714,7 +10737,7 @@ python.Execution = class {
}
case 'call': {
if (expression.target.type === '.') {
return this.call(expression.target.target, expression.target.member.value, expression.args, context);
return this.call(expression.target.target, expression.target.member.value, expression.args, context, expression.location);
}
return this.call(expression.target, null, expression.args, context);
}
Expand Down Expand Up @@ -10789,7 +10812,8 @@ python.Execution = class {
return undefined;
}

target(expression, context) {
target(expression, context, resolve) {
resolve = resolve === false ? false : true;
let current = expression;
let path = [];
for (;;) {
Expand All @@ -10812,7 +10836,7 @@ python.Execution = class {
break;
}
}
if (!target) {
if (!target && resolve) {
path.reverse();
const name = path.join('.');
const file = `${path.join('/')}.py`;
Expand Down
9 changes: 9 additions & 0 deletions source/pytorch-metadata.json
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,12 @@
{
"name": "aten::__interpolate.size_list_scale_list(Tensor input, int[]? size=None, float[]? scale_factor=None, str mode=\"nearest\", bool? align_corners=None, bool? recompute_scale_factor=None, bool antialias=False) -> Tensor"
},
{
"name": "aten::__ior__.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)"
},
{
"name": "aten::__ior__.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)"
},
{
"name": "aten::__irshift__.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)"
},
Expand Down Expand Up @@ -532,6 +538,9 @@
{
"name": "aten::_unsafe_view.out(Tensor self, SymInt[] size, *, Tensor(a!) out) -> Tensor(a!)"
},
{
"name": "aten::_unwrap_optional(t(a)? optional) -> t(a)"
},
{
"name": "aten::_upsample_bilinear2d_aa(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor"
},
Expand Down
Loading

0 comments on commit c2d1aa2

Please sign in to comment.