Skip to content

Commit d66c5bf

Browse files
committed
Update python.js (#1061)
1 parent d774780 commit d66c5bf

File tree

2 files changed

+26
-8
lines changed

2 files changed

+26
-8
lines changed

source/python.js

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8428,7 +8428,7 @@ python.Execution = class {
84288428
insertNode(node) {
84298429
return node.insertBefore(this._insert_before);
84308430
}
8431-
insertConstant(val) {
8431+
insertConstant(val, loc) {
84328432
const n = this.create('prim::Constant');
84338433
this.insertNode(n);
84348434
let type = null;
@@ -8442,22 +8442,29 @@ python.Execution = class {
84428442
n.ss_('value', val);
84438443
type = torch.ListType.create(torch.StringType.get());
84448444
} else if (typeof val === 'boolean') {
8445-
// return value;
84468445
n.i_('value', val === true ? 1 : 0);
84478446
type = torch.BoolType.get();
84488447
} else if (Number.isInteger(val)) {
84498448
n.i_('value', val);
84508449
type = torch.IntType.get();
84518450
} else if (typeof val === 'number') {
8452-
// return value;
84538451
n.f_('value', val);
84548452
type = torch.FloatType.get();
8453+
} else if (val instanceof torch.Tensor) {
8454+
n.t_('value', val);
8455+
type = torch.TensorType.get();
8456+
} else if (val instanceof torch.ScriptObject) {
8457+
n.ival_('value', val);
8458+
type = val.type();
84558459
} else {
84568460
throw new python.Error(`Unsupported value type '${typeof value}'.`);
84578461
}
84588462
if (type) {
84598463
n.output().setType(type);
84608464
}
8465+
if (loc) {
8466+
n.setSourceRange(loc);
8467+
}
84618468
return n.output();
84628469
}
84638470
insertMethodCall(method_name, matched) {
@@ -8768,6 +8775,12 @@ python.Execution = class {
87688775
f(name) {
87698776
return this._values.get(name)[0];
87708777
}
8778+
t_(name, value) {
8779+
this._values.set(name, [value, 't']);
8780+
}
8781+
t(name) {
8782+
return this._values.get(name)[0];
8783+
}
87718784
tys_(name, value) {
87728785
this._values.set(name, [value, 'tys']);
87738786
}
@@ -8860,9 +8873,10 @@ python.Execution = class {
88608873
this._type = type;
88618874
}
88628875
set value(value) { // remove
8863-
if (value instanceof torch.Value === false) {
8864-
this._value = value;
8876+
if (value instanceof torch.Value) {
8877+
throw new python.Error('Value cannot be a value.');
88658878
}
8879+
this._value = value;
88668880
}
88678881
get value() { // remove
88688882
return this._value;

source/pytorch.js

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2914,9 +2914,13 @@ pytorch.Execution = class extends python.Execution {
29142914
}
29152915
throw new pytorch.Error();
29162916
}
2917-
const value = this.variable(v);
2918-
value.value = v;
2919-
node.addInput(value);
2917+
if (v instanceof torch.Value) {
2918+
node.addInput(v);
2919+
} else {
2920+
const value = this.variable(v);
2921+
value.value = v;
2922+
node.addInput(value);
2923+
}
29202924
}
29212925
}
29222926
for (const arg of schema.returns) {

0 commit comments

Comments
 (0)