From a420f864e8e1ca47642cd5665fea6e232e349613 Mon Sep 17 00:00:00 2001 From: Lutz Roeder Date: Sun, 10 Nov 2024 10:19:38 -0800 Subject: [PATCH] Update pytorch.js (#1061) --- source/python.js | 19 ++++++-- source/pytorch-metadata.json | 27 ++++++++--- source/pytorch.js | 88 +++++++++++++----------------------- source/view.js | 16 ++++++- test/models.json | 1 + 5 files changed, 82 insertions(+), 69 deletions(-) diff --git a/source/python.js b/source/python.js index 1041b2d746..7ff0e1a9e0 100644 --- a/source/python.js +++ b/source/python.js @@ -7911,10 +7911,21 @@ python.Execution = class { execution.builtins.inf = torch.inf; execution.builtins.CONSTANTS = {}; execution._resolver = this._source_importer; - const known_types = ['__torch__.torch.classes._nnapi.Compilation']; - for (const name of known_types) { - const type = new torch.ClassType(name, this._compilation_unit, false); - type.addMethod(new torch.FunctionSchema('init(Tensor serialized_model_tensor, Tensor[] parameter_buffers) -> ()')); + const known_types = [ + { name: '__torch__.torch.classes._nnapi.Compilation', methods: ['init(Tensor serialized_model_tensor, Tensor[] parameter_buffers) -> ()'] }, + { name: '__torch__.torch.classes.quantized.Conv2dPackedParamsBase' }, + { name: '__torch__.torch.classes.quantized.Conv3dPackedParamsBase' }, + { name: '__torch__.torch.classes.quantized.LinearPackedParamsBase' }, + { name: '__torch__.torch.classes.rnn.CellParamsBase' }, + { name: '__torch__.torch.classes.xnnpack.Conv2dOpContext' }, + { name: '__torch__.torch.classes.xnnpack.LinearOpContext' }, + { name: '__torch__.torch.classes.xnnpack.TransposeConv2dOpContext' }, + ]; + for (const known_type of known_types) { + const type = new torch.ClassType(known_type.name, this._compilation_unit, false); + for (const known_method of known_type.methods || []) { + type.addMethod(new torch.FunctionSchema(known_method)); + } this._compilation_unit.register_type(type); } if (this._reader.has_record('model.json')) { diff --git a/source/pytorch-metadata.json b/source/pytorch-metadata.json index 86625d8e1a..5f868b2221 100755 --- a/source/pytorch-metadata.json +++ b/source/pytorch-metadata.json @@ -5388,25 +5388,32 @@ "name": "aten::tan_(Tensor(a!) self) -> Tensor(a!)" }, { - "name": "aten::tanh(Tensor self) -> Tensor" + "name": "aten::tanh(Tensor self) -> Tensor", + "category": "Activation" }, { - "name": "aten::tanh.Scalar(Scalar a) -> Scalar" + "name": "aten::tanh.Scalar(Scalar a) -> Scalar", + "category": "Activation" }, { - "name": "aten::tanh.complex(complex a) -> complex" + "name": "aten::tanh.complex(complex a) -> complex", + "category": "Activation" }, { - "name": "aten::tanh.float(float a) -> float" + "name": "aten::tanh.float(float a) -> float", + "category": "Activation" }, { - "name": "aten::tanh.int(int a) -> float" + "name": "aten::tanh.int(int a) -> float", + "category": "Activation" }, { - "name": "aten::tanh.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)" + "name": "aten::tanh.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", + "category": "Activation" }, { - "name": "aten::tanh_(Tensor(a!) self) -> Tensor(a!)" + "name": "aten::tanh_(Tensor(a!) self) -> Tensor(a!)", + "category": "Activation" }, { "name": "aten::tensor(t[] data, *, ScalarType? dtype=None, Device? device=None, bool requires_grad=False) -> Tensor" @@ -5518,6 +5525,12 @@ { "name": "aten::topk.values(Tensor self, SymInt k, int dim=-1, bool largest=True, bool sorted=True, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices)" }, + { + "name": "aten::trace(Tensor self) -> Tensor" + }, + { + "name": "aten::trace.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)" + }, { "name": "aten::transpose.Dimname(Tensor(a) self, str dim0, str dim1) -> Tensor(a)", "category": "Transform" diff --git a/source/pytorch.js b/source/pytorch.js index 3a2695425b..6bbbe2a7c6 100644 --- a/source/pytorch.js +++ b/source/pytorch.js @@ -486,7 +486,7 @@ pytorch.Node = class { const node = obj.map((obj) => new pytorch.Node(metadata, name, type.qualified_name(), obj, initializers, values)); argument = new pytorch.Argument(name, node, 'object[]'); } else { - const identifier = input.unique().toString(); + const identifier = pytorch.Utility.unique(input); const value = values.map(identifier); argument = new pytorch.Argument(name, [value]); } @@ -2919,10 +2919,8 @@ pytorch.Execution = class extends python.Execution { node.addInput(value); } } - const result = []; - for (let i = 0; i < schema.returns.length; i++) { - const arg = schema.returns[i]; - const type = arg.real_type; + for (const arg of schema.returns) { + let type = arg.real_type; switch (type.str()) { case '__torch__.torch.classes.quantized.Conv2dPackedParamsBase': case '__torch__.torch.classes.quantized.Conv3dPackedParamsBase': @@ -2931,9 +2929,7 @@ pytorch.Execution = class extends python.Execution { case '__torch__.torch.classes.xnnpack.Conv2dOpContext': case '__torch__.torch.classes.xnnpack.LinearOpContext': case '__torch__.torch.classes.xnnpack.TransposeConv2dOpContext': { - const value = this.invoke(type.qualified_name(), []); - this.variable(value, node); - result.push(value); + type = this._resolver.resolveType(type.qualified_name()); break; } case 'Tensor': @@ -2950,69 +2946,49 @@ pytorch.Execution = class extends python.Execution { case 'bool': case 'bool[]': case 'Device': { - const output = node.addOutput(); - output.__origin__ = schema.name; - output.setType(type); - result.push(output); break; } case 't': { - const value = this.variable(null, node); - value.__origin__ = schema.name; - const t = varTypes.map(type); - if (!t) { + type = varTypes.map(type); + if (!type) { throw new pytorch.Error(`Unknown var type 't'.`); } - value.setType(t); - result.push(value); break; } case 't[]': { - const value = this.variable(null, node); - value.__origin__ = schema.name; - const t = varTypes.map(type.getElementType()); - if (!t) { + type = varTypes.map(type.getElementType()); + if (!type) { throw new pytorch.Error(); } - value.setType(torch.ListType.get(t)); - result.push(value); + type = torch.ListType.get(type); break; } default: { if (type instanceof torch.DictType) { - const value = node.addOutput(); - value.__origin__ = schema.name; const keyType = varTypes.map(type.getKeyType()); const valueType = varTypes.map(type.getValueType()); - value.setType(torch.DictType.get(keyType, valueType)); - result.push(value); - break; - } - if (type instanceof torch.TupleType && type.elements().length === 2) { - const value = node.addOutput(); - value.__origin__ = schema.name; - const keyType = varTypes.map(type.elements()[0]); - const valueType = varTypes.map(type.elements()[1]); - value.setType(torch.ListType.get(torch.TupleType.get([keyType, valueType]))); - result.push(value); - break; + type = torch.DictType.get(keyType, valueType); + } else if (type instanceof torch.TupleType && type.elements().length === 2) { + const elements = type.elements().map((type) => varTypes.map(type)); + type = torch.ListType.get(torch.TupleType.get(elements)); + } else if (type instanceof torch.ListType && type.getElementType() instanceof torch.TupleType) { + const elements = type.getElementType().elements().map((type) => varTypes.map(type)); + type = torch.ListType.get(torch.TupleType.get(elements)); + } else { + throw new pytorch.Error(`Unsupported return type '${type.str()}'.`); } - const output = this.invoke('torch.Tensor', []); - output.resize_([]); - output.__origin__ = schema.name; - this.variable(output, node); - result.push(output); break; } } + const output = node.addOutput(); + output.__origin__ = schema.name; + output.setType(type); } for (const referencedParameter of referencedParameters) { referencedParameter.__count__ = (referencedParameter.__count__ || 0) + 1; } - if (result.length > 1) { - return result; - } - return result[0]; + const outputs = node.outputs(); + return outputs.length > 1 ? outputs : outputs[0]; } isType(obj, type, N) { @@ -3134,9 +3110,13 @@ pytorch.Execution = class extends python.Execution { case 't2': return true; default: { - if (type instanceof torch.ClassType && - obj && obj.__class__ && obj.__class__.__module__ && obj.__class__.__name__) { - return type.qualified_name() === `${obj.__class__.__module__}.${obj.__class__.__name__}`; + if (type instanceof torch.ClassType) { + if (obj instanceof torch.Value && obj.type() instanceof torch.ClassType) { + return type.qualified_name() === obj.type().qualified_name(); + } + if (obj && obj.__class__ && obj.__class__.__module__ && obj.__class__.__name__) { + return type.qualified_name() === `${obj.__class__.__module__}.${obj.__class__.__name__}`; + } } if (type instanceof torch.TupleType) { throw new pytorch.Error('Not implemented.'); @@ -3512,7 +3492,8 @@ pytorch.Utility = class { return value.hasDebugName() ? `%${value.debugName().toString()}` : `%${value.unique().toString()}`; } - static isObjectType(type) { + static isObject(obj) { + const type = obj && obj.__class__ && obj.__class__.__module__ && obj.__class__.__name__ ? `${obj.__class__.__module__}.${obj.__class__.__name__}` : null; switch (type) { case '__torch__.torch.classes.xnnpack.LinearOpContext': case '__torch__.torch.classes.xnnpack.Conv2dOpContext': @@ -3528,11 +3509,6 @@ pytorch.Utility = class { } } - static isObject(obj) { - const type = obj && obj.__class__ && obj.__class__.__module__ && obj.__class__.__name__ ? `${obj.__class__.__module__}.${obj.__class__.__name__}` : null; - return pytorch.Utility.isObjectType(type); - } - static isSubclass(value, name) { if (value && value.__module__ && value.__name__) { return name === `${value.__module__}.${value.__name__}`; diff --git a/source/view.js b/source/view.js index 11cb733e74..ea04980bd1 100644 --- a/source/view.js +++ b/source/view.js @@ -2915,9 +2915,18 @@ view.ArgumentView = class extends view.Control { this._source = 'attribute'; } if (argument.type === 'tensor' || argument.type === 'tensor?') { - value = [value === null ? value : { type: value.type, initializer: value }]; + if (value === null || (value && value.constructor && value.constructor.name === 'Value')) { + value = [value]; + } else { + value = [{ type: value.type, initializer: value }]; + } } else if (argument.type === 'tensor[]' || argument.type === 'tensor?[]') { - value = value.map((value) => value === null ? value : { type: value.type, initializer: value }); + value = value.map((value) => { + if (value === null || (value && value.constructor && value.constructor.name === 'Value')) { + return value; + } + return { type: value.type, initializer: value }; + }); } this._source = typeof type === 'string' && !type.endsWith('*') ? 'attribute' : this._source; if (this._source === 'attribute' && type !== 'tensor' && type !== 'tensor?' && type !== 'tensor[]' && type !== 'tensor?[]') { @@ -3064,6 +3073,9 @@ view.ValueView = class extends view.Expander { super(context); this._value = value; try { + if (value && value.constructor && value.constructor.name === 'Value' && source === 'attribute') { + source = ''; + } const type = this._value.type; const initializer = this._value.initializer; const quantization = this._value.quantization; diff --git a/test/models.json b/test/models.json index fbf9a60e3d..9f74a448f3 100644 --- a/test/models.json +++ b/test/models.json @@ -5852,6 +5852,7 @@ "target": "netron_issue_677.pt", "source": "https://github.com/lutzroeder/netron/files/5923252/netron_issue_677.pt.zip[netron_issue_677.pt]", "format": "TorchScript v1.6", + "assert": "model.graphs[0].nodes.length == 5", "link": "https://github.com/lutzroeder/netron/issues/677" }, {