Skip to content

Commit

Permalink
Update pytorch.js (#1061)
Browse files Browse the repository at this point in the history
  • Loading branch information
lutzroeder committed Nov 10, 2024
1 parent f49812a commit a420f86
Show file tree
Hide file tree
Showing 5 changed files with 82 additions and 69 deletions.
19 changes: 15 additions & 4 deletions source/python.js
Original file line number Diff line number Diff line change
Expand Up @@ -7911,10 +7911,21 @@ python.Execution = class {
execution.builtins.inf = torch.inf;
execution.builtins.CONSTANTS = {};
execution._resolver = this._source_importer;
const known_types = ['__torch__.torch.classes._nnapi.Compilation'];
for (const name of known_types) {
const type = new torch.ClassType(name, this._compilation_unit, false);
type.addMethod(new torch.FunctionSchema('init(Tensor serialized_model_tensor, Tensor[] parameter_buffers) -> ()'));
const known_types = [
{ name: '__torch__.torch.classes._nnapi.Compilation', methods: ['init(Tensor serialized_model_tensor, Tensor[] parameter_buffers) -> ()'] },
{ name: '__torch__.torch.classes.quantized.Conv2dPackedParamsBase' },
{ name: '__torch__.torch.classes.quantized.Conv3dPackedParamsBase' },
{ name: '__torch__.torch.classes.quantized.LinearPackedParamsBase' },
{ name: '__torch__.torch.classes.rnn.CellParamsBase' },
{ name: '__torch__.torch.classes.xnnpack.Conv2dOpContext' },
{ name: '__torch__.torch.classes.xnnpack.LinearOpContext' },
{ name: '__torch__.torch.classes.xnnpack.TransposeConv2dOpContext' },
];
for (const known_type of known_types) {
const type = new torch.ClassType(known_type.name, this._compilation_unit, false);
for (const known_method of known_type.methods || []) {
type.addMethod(new torch.FunctionSchema(known_method));
}
this._compilation_unit.register_type(type);
}
if (this._reader.has_record('model.json')) {
Expand Down
27 changes: 20 additions & 7 deletions source/pytorch-metadata.json
Original file line number Diff line number Diff line change
Expand Up @@ -5388,25 +5388,32 @@
"name": "aten::tan_(Tensor(a!) self) -> Tensor(a!)"
},
{
"name": "aten::tanh(Tensor self) -> Tensor"
"name": "aten::tanh(Tensor self) -> Tensor",
"category": "Activation"
},
{
"name": "aten::tanh.Scalar(Scalar a) -> Scalar"
"name": "aten::tanh.Scalar(Scalar a) -> Scalar",
"category": "Activation"
},
{
"name": "aten::tanh.complex(complex a) -> complex"
"name": "aten::tanh.complex(complex a) -> complex",
"category": "Activation"
},
{
"name": "aten::tanh.float(float a) -> float"
"name": "aten::tanh.float(float a) -> float",
"category": "Activation"
},
{
"name": "aten::tanh.int(int a) -> float"
"name": "aten::tanh.int(int a) -> float",
"category": "Activation"
},
{
"name": "aten::tanh.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)"
"name": "aten::tanh.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)",
"category": "Activation"
},
{
"name": "aten::tanh_(Tensor(a!) self) -> Tensor(a!)"
"name": "aten::tanh_(Tensor(a!) self) -> Tensor(a!)",
"category": "Activation"
},
{
"name": "aten::tensor(t[] data, *, ScalarType? dtype=None, Device? device=None, bool requires_grad=False) -> Tensor"
Expand Down Expand Up @@ -5518,6 +5525,12 @@
{
"name": "aten::topk.values(Tensor self, SymInt k, int dim=-1, bool largest=True, bool sorted=True, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices)"
},
{
"name": "aten::trace(Tensor self) -> Tensor"
},
{
"name": "aten::trace.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)"
},
{
"name": "aten::transpose.Dimname(Tensor(a) self, str dim0, str dim1) -> Tensor(a)",
"category": "Transform"
Expand Down
88 changes: 32 additions & 56 deletions source/pytorch.js
Original file line number Diff line number Diff line change
Expand Up @@ -486,7 +486,7 @@ pytorch.Node = class {
const node = obj.map((obj) => new pytorch.Node(metadata, name, type.qualified_name(), obj, initializers, values));
argument = new pytorch.Argument(name, node, 'object[]');
} else {
const identifier = input.unique().toString();
const identifier = pytorch.Utility.unique(input);
const value = values.map(identifier);
argument = new pytorch.Argument(name, [value]);
}
Expand Down Expand Up @@ -2919,10 +2919,8 @@ pytorch.Execution = class extends python.Execution {
node.addInput(value);
}
}
const result = [];
for (let i = 0; i < schema.returns.length; i++) {
const arg = schema.returns[i];
const type = arg.real_type;
for (const arg of schema.returns) {
let type = arg.real_type;
switch (type.str()) {
case '__torch__.torch.classes.quantized.Conv2dPackedParamsBase':
case '__torch__.torch.classes.quantized.Conv3dPackedParamsBase':
Expand All @@ -2931,9 +2929,7 @@ pytorch.Execution = class extends python.Execution {
case '__torch__.torch.classes.xnnpack.Conv2dOpContext':
case '__torch__.torch.classes.xnnpack.LinearOpContext':
case '__torch__.torch.classes.xnnpack.TransposeConv2dOpContext': {
const value = this.invoke(type.qualified_name(), []);
this.variable(value, node);
result.push(value);
type = this._resolver.resolveType(type.qualified_name());
break;
}
case 'Tensor':
Expand All @@ -2950,69 +2946,49 @@ pytorch.Execution = class extends python.Execution {
case 'bool':
case 'bool[]':
case 'Device': {
const output = node.addOutput();
output.__origin__ = schema.name;
output.setType(type);
result.push(output);
break;
}
case 't': {
const value = this.variable(null, node);
value.__origin__ = schema.name;
const t = varTypes.map(type);
if (!t) {
type = varTypes.map(type);
if (!type) {
throw new pytorch.Error(`Unknown var type 't'.`);
}
value.setType(t);
result.push(value);
break;
}
case 't[]': {
const value = this.variable(null, node);
value.__origin__ = schema.name;
const t = varTypes.map(type.getElementType());
if (!t) {
type = varTypes.map(type.getElementType());
if (!type) {
throw new pytorch.Error();
}
value.setType(torch.ListType.get(t));
result.push(value);
type = torch.ListType.get(type);
break;
}
default: {
if (type instanceof torch.DictType) {
const value = node.addOutput();
value.__origin__ = schema.name;
const keyType = varTypes.map(type.getKeyType());
const valueType = varTypes.map(type.getValueType());
value.setType(torch.DictType.get(keyType, valueType));
result.push(value);
break;
}
if (type instanceof torch.TupleType && type.elements().length === 2) {
const value = node.addOutput();
value.__origin__ = schema.name;
const keyType = varTypes.map(type.elements()[0]);
const valueType = varTypes.map(type.elements()[1]);
value.setType(torch.ListType.get(torch.TupleType.get([keyType, valueType])));
result.push(value);
break;
type = torch.DictType.get(keyType, valueType);
} else if (type instanceof torch.TupleType && type.elements().length === 2) {
const elements = type.elements().map((type) => varTypes.map(type));
type = torch.ListType.get(torch.TupleType.get(elements));
} else if (type instanceof torch.ListType && type.getElementType() instanceof torch.TupleType) {
const elements = type.getElementType().elements().map((type) => varTypes.map(type));
type = torch.ListType.get(torch.TupleType.get(elements));
} else {
throw new pytorch.Error(`Unsupported return type '${type.str()}'.`);
}
const output = this.invoke('torch.Tensor', []);
output.resize_([]);
output.__origin__ = schema.name;
this.variable(output, node);
result.push(output);
break;
}
}
const output = node.addOutput();
output.__origin__ = schema.name;
output.setType(type);
}
for (const referencedParameter of referencedParameters) {
referencedParameter.__count__ = (referencedParameter.__count__ || 0) + 1;
}
if (result.length > 1) {
return result;
}
return result[0];
const outputs = node.outputs();
return outputs.length > 1 ? outputs : outputs[0];
}

isType(obj, type, N) {
Expand Down Expand Up @@ -3134,9 +3110,13 @@ pytorch.Execution = class extends python.Execution {
case 't2':
return true;
default: {
if (type instanceof torch.ClassType &&
obj && obj.__class__ && obj.__class__.__module__ && obj.__class__.__name__) {
return type.qualified_name() === `${obj.__class__.__module__}.${obj.__class__.__name__}`;
if (type instanceof torch.ClassType) {
if (obj instanceof torch.Value && obj.type() instanceof torch.ClassType) {
return type.qualified_name() === obj.type().qualified_name();
}
if (obj && obj.__class__ && obj.__class__.__module__ && obj.__class__.__name__) {
return type.qualified_name() === `${obj.__class__.__module__}.${obj.__class__.__name__}`;
}
}
if (type instanceof torch.TupleType) {
throw new pytorch.Error('Not implemented.');
Expand Down Expand Up @@ -3512,7 +3492,8 @@ pytorch.Utility = class {
return value.hasDebugName() ? `%${value.debugName().toString()}` : `%${value.unique().toString()}`;
}

static isObjectType(type) {
static isObject(obj) {
const type = obj && obj.__class__ && obj.__class__.__module__ && obj.__class__.__name__ ? `${obj.__class__.__module__}.${obj.__class__.__name__}` : null;
switch (type) {
case '__torch__.torch.classes.xnnpack.LinearOpContext':
case '__torch__.torch.classes.xnnpack.Conv2dOpContext':
Expand All @@ -3528,11 +3509,6 @@ pytorch.Utility = class {
}
}

static isObject(obj) {
const type = obj && obj.__class__ && obj.__class__.__module__ && obj.__class__.__name__ ? `${obj.__class__.__module__}.${obj.__class__.__name__}` : null;
return pytorch.Utility.isObjectType(type);
}

static isSubclass(value, name) {
if (value && value.__module__ && value.__name__) {
return name === `${value.__module__}.${value.__name__}`;
Expand Down
16 changes: 14 additions & 2 deletions source/view.js
Original file line number Diff line number Diff line change
Expand Up @@ -2915,9 +2915,18 @@ view.ArgumentView = class extends view.Control {
this._source = 'attribute';
}
if (argument.type === 'tensor' || argument.type === 'tensor?') {
value = [value === null ? value : { type: value.type, initializer: value }];
if (value === null || (value && value.constructor && value.constructor.name === 'Value')) {
value = [value];
} else {
value = [{ type: value.type, initializer: value }];
}
} else if (argument.type === 'tensor[]' || argument.type === 'tensor?[]') {
value = value.map((value) => value === null ? value : { type: value.type, initializer: value });
value = value.map((value) => {
if (value === null || (value && value.constructor && value.constructor.name === 'Value')) {
return value;
}
return { type: value.type, initializer: value };
});
}
this._source = typeof type === 'string' && !type.endsWith('*') ? 'attribute' : this._source;
if (this._source === 'attribute' && type !== 'tensor' && type !== 'tensor?' && type !== 'tensor[]' && type !== 'tensor?[]') {
Expand Down Expand Up @@ -3064,6 +3073,9 @@ view.ValueView = class extends view.Expander {
super(context);
this._value = value;
try {
if (value && value.constructor && value.constructor.name === 'Value' && source === 'attribute') {
source = '';
}
const type = this._value.type;
const initializer = this._value.initializer;
const quantization = this._value.quantization;
Expand Down
1 change: 1 addition & 0 deletions test/models.json
Original file line number Diff line number Diff line change
Expand Up @@ -5852,6 +5852,7 @@
"target": "netron_issue_677.pt",
"source": "https://github.com/lutzroeder/netron/files/5923252/netron_issue_677.pt.zip[netron_issue_677.pt]",
"format": "TorchScript v1.6",
"assert": "model.graphs[0].nodes.length == 5",
"link": "https://github.com/lutzroeder/netron/issues/677"
},
{
Expand Down

0 comments on commit a420f86

Please sign in to comment.