Skip to content

Commit

Permalink
Update python.js (#1061)
Browse files Browse the repository at this point in the history
  • Loading branch information
lutzroeder committed Dec 19, 2024
1 parent 6a7bd77 commit 2ed1782
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 2 deletions.
5 changes: 5 additions & 0 deletions source/python.js
Original file line number Diff line number Diff line change
Expand Up @@ -9187,6 +9187,10 @@ python.Execution = class {
if (t.is_module()) {
module.__setattr__(k, convertModule(v));
} else {
if (t instanceof torch.TensorType && v && v.__class__ && v instanceof torch.Tensor === false && v.__class__.__module__ === '__torch__.torch.classes.quantized') {
const name = `${v.__class__.__module__}.${v.__class__.__name__}`;
type._attributes[i].type = this._source_importer.resolveType(name);
}
module.__setattr__(k, obj[k]);
}
}
Expand Down Expand Up @@ -12066,6 +12070,7 @@ python.Execution = class {
this.registerType('fastai.vision.models.unet.DynamicUnet', class {});
this.registerType('fastai.vision.models.unet.ResizeToOrig', class {});
this.registerType('fastai.vision.models.unet.UnetBlock', class {});
this.registerType('fastai.vision.models.xresnet.XResNet', class {});
this.registerFunction('fastai.vision.transform._crop_pad');
}

Expand Down
15 changes: 13 additions & 2 deletions source/pytorch.js
Original file line number Diff line number Diff line change
Expand Up @@ -1740,6 +1740,7 @@ pytorch.Execution = class extends python.Execution {

target(expr, context) {
const ast = this.ast;
const torch = this.torch;
if (expr instanceof ast.Name) {
switch (expr.id) {
case 'torch':
Expand Down Expand Up @@ -1783,7 +1784,13 @@ pytorch.Execution = class extends python.Execution {
for (let i = path.length - 1; i >= 0; i--) {
const name = path[i];
if (target) {
target = target.__getattr__ ? target.__getattr__(name) : target[name];
if (target instanceof torch.Value && target.type() instanceof torch.ClassType) {
const node = this._graph.createGetAttr(target, name);
this._graph.insertNode(node);
target = node.output();
} else {
target = target.__getattr__ ? target.__getattr__(name) : target[name];
}
} else {
target = context.get(name);
}
Expand All @@ -1800,6 +1807,9 @@ pytorch.Execution = class extends python.Execution {
}
return this.resolve(name);
}
if (target instanceof torch.Value) {
return target;
}
}
return super.target(expr, context);
}
Expand Down Expand Up @@ -3063,7 +3073,8 @@ pytorch.Execution = class extends python.Execution {
return (Array.isArray(obj) && obj.every((item) => item === null || typeof item === 'string')) ||
(obj instanceof torch.Value && obj.type() instanceof torch.ListType && obj.type().getElementType() instanceof torch.StringType);
case 'str[][]':
return Array.isArray(obj) && obj.every((item) => Array.isArray(item) && item.every((item) => typeof item === 'string'));
return Array.isArray(obj) && obj.every((item) => Array.isArray(item) && item.every((item) => typeof item === 'string')) ||
(obj instanceof torch.Value && obj.type() instanceof torch.ListType && obj.type().getElementType() instanceof torch.ListType && obj.type().getElementType().getElementType() instanceof torch.StringType);
case 'Layout':
case 'ScalarType':
case 'MemoryFormat':
Expand Down

0 comments on commit 2ed1782

Please sign in to comment.