
Commit e02cf71

Update TorchScript test files (#1067)
1 parent b4e6833 commit e02cf71

4 files changed: +269 −36 lines


source/pytorch-metadata.json

Lines changed: 228 additions & 4 deletions
@@ -1,4 +1,14 @@
 [
+  {
+    "name": "__torch__.torch.classes.rnn.CellParamsBase",
+    "inputs": [
+      { "name": "type", "type": "string" },
+      { "name": "tensors", "type": "Tensor[]" },
+      { "name": "doubles", "type": "float64[]" },
+      { "name": "longs", "type": "int64[]" },
+      { "name": "packed_params", "type": "__torch__.torch.classes.quantized.LinearPackedParamsBase[]" }
+    ]
+  },
   {
     "name": "__torch__.torch.classes.xnnpack.Conv2dOpContext",
     "inputs": [
@@ -10400,17 +10410,94 @@
     ]
   },
   {
-    "name": "aten::quantized_lstm",
+    "name": "aten::quantized_gru.data",
+    "category": "Layer",
+    "inputs": [
+      { "name": "data", "type": "Tensor" },
+      { "name": "batch_sizes", "type": "Tensor" },
+      { "name": "hx", "type": "Tensor" },
+      { "name": "params", "type": "__torch__.torch.classes.rnn.CellParamsBase[]" },
+      { "name": "has_biases", "type": "boolean" },
+      { "name": "num_layers", "type": "int64" },
+      { "name": "dropout", "type": "float32" },
+      { "name": "train", "type": "boolean" },
+      { "name": "bidirectional", "type": "boolean" }
+    ],
+    "outputs": [
+      { "type": "Tensor" },
+      { "type": "Tensor" }
+    ]
+  },
+  {
+    "name": "aten::quantized_gru.data_legacy",
+    "category": "Layer",
+    "inputs": [
+      { "name": "data", "type": "Tensor" },
+      { "name": "batch_sizes", "type": "Tensor" },
+      { "name": "hx", "type": "Tensor" },
+      { "name": "params", "type": "Tensor[]" },
+      { "name": "has_biases", "type": "boolean" },
+      { "name": "num_layers", "type": "int64" },
+      { "name": "dropout", "type": "float32" },
+      { "name": "train", "type": "boolean" },
+      { "name": "bidirectional", "type": "boolean" }
+    ],
+    "outputs": [
+      { "type": "Tensor" },
+      { "type": "Tensor" }
+    ]
+  },
+  {
+    "name": "aten::quantized_gru.input",
+    "category": "Layer",
     "inputs": [
       { "name": "input", "type": "Tensor" },
-      { "name": "hx", "type": "Tensor[]" },
+      { "name": "hx", "type": "Tensor" },
+      { "name": "params", "type": "__torch__.torch.classes.rnn.CellParamsBase[]" },
+      { "name": "has_biases", "type": "boolean" },
+      { "name": "num_layers", "type": "int64" },
+      { "name": "dropout", "type": "float32" },
+      { "name": "train", "type": "boolean" },
+      { "name": "bidirectional", "type": "boolean" },
+      { "name": "batch_first", "type": "boolean" }
+    ],
+    "outputs": [
+      { "type": "Tensor" },
+      { "type": "Tensor" }
+    ]
+  },
+  {
+    "name": "aten::quantized_gru.input_legacy",
+    "category": "Layer",
+    "inputs": [
+      { "name": "input", "type": "Tensor" },
+      { "name": "hx", "type": "Tensor" },
       { "name": "params", "type": "Tensor[]" },
       { "name": "has_biases", "type": "boolean" },
       { "name": "num_layers", "type": "int64" },
       { "name": "dropout", "type": "float32" },
       { "name": "train", "type": "boolean" },
       { "name": "bidirectional", "type": "boolean" },
-      { "name": "batch_first", "type": "boolean" },
+      { "name": "batch_first", "type": "boolean" }
+    ],
+    "outputs": [
+      { "type": "Tensor" },
+      { "type": "Tensor" }
+    ]
+  },
+  {
+    "name": "aten::quantized_lstm.data",
+    "category": "Layer",
+    "inputs": [
+      { "name": "data", "type": "Tensor" },
+      { "name": "batch_sizes", "type": "Tensor" },
+      { "name": "hx", "type": "Tensor[]" },
+      { "name": "params", "type": "__torch__.torch.classes.rnn.CellParamsBase[]" },
+      { "name": "has_biases", "type": "boolean" },
+      { "name": "num_layers", "type": "int64" },
+      { "name": "dropout", "type": "float32" },
+      { "name": "train", "type": "boolean" },
+      { "name": "bidirectional", "type": "boolean" },
       { "name": "dtype", "type": "ScalarType", "optional": true, "default": null },
       { "name": "use_dynamic", "type": "boolean", "default": false }
     ],
@@ -10421,7 +10508,8 @@
     ]
   },
   {
-    "name": "aten::quantized_lstm.data",
+    "name": "aten::quantized_lstm.data_legacy",
+    "category": "Layer",
     "inputs": [
       { "name": "data", "type": "Tensor" },
       { "name": "batch_sizes", "type": "Tensor" },
@@ -10441,6 +10529,50 @@
       { "type": "Tensor" }
     ]
   },
+  {
+    "name": "aten::quantized_lstm.input",
+    "category": "Layer",
+    "inputs": [
+      { "name": "input", "type": "Tensor" },
+      { "name": "hx", "type": "Tensor[]" },
+      { "name": "params", "type": "__torch__.torch.classes.rnn.CellParamsBase[]" },
+      { "name": "has_biases", "type": "boolean" },
+      { "name": "num_layers", "type": "int64" },
+      { "name": "dropout", "type": "float32" },
+      { "name": "train", "type": "boolean" },
+      { "name": "bidirectional", "type": "boolean" },
+      { "name": "batch_first", "type": "boolean" },
+      { "name": "dtype", "type": "ScalarType", "optional": true, "default": null },
+      { "name": "use_dynamic", "type": "boolean", "default": false }
+    ],
+    "outputs": [
+      { "type": "Tensor" },
+      { "type": "Tensor" },
+      { "type": "Tensor" }
+    ]
+  },
+  {
+    "name": "aten::quantized_lstm.input_legacy",
+    "category": "Layer",
+    "inputs": [
+      { "name": "input", "type": "Tensor" },
+      { "name": "hx", "type": "Tensor[]" },
+      { "name": "params", "type": "Tensor[]" },
+      { "name": "has_biases", "type": "boolean" },
+      { "name": "num_layers", "type": "int64" },
+      { "name": "dropout", "type": "float32" },
+      { "name": "train", "type": "boolean" },
+      { "name": "bidirectional", "type": "boolean" },
+      { "name": "batch_first", "type": "boolean" },
+      { "name": "dtype", "type": "ScalarType", "optional": true, "default": null },
+      { "name": "use_dynamic", "type": "boolean", "default": false }
+    ],
+    "outputs": [
+      { "type": "Tensor" },
+      { "type": "Tensor" },
+      { "type": "Tensor" }
+    ]
+  },
   {
     "name": "aten::quantized_lstm_cell",
     "inputs": [
@@ -15742,6 +15874,41 @@
       { "name": "Y", "type": "Tensor" }
     ]
   },
+  {
+    "name": "quantized::make_quantized_cell_params",
+    "inputs": [
+      { "name": "w_ih", "type": "Tensor" },
+      { "name": "w_hh", "type": "Tensor" },
+      { "name": "b_ih", "type": "Tensor" },
+      { "name": "b_hh", "type": "Tensor" }
+    ],
+    "outputs": [
+      { "type": "__torch__.torch.classes.rnn.CellParamsBase" }
+    ]
+  },
+  {
+    "name": "quantized::make_quantized_cell_params_dynamic",
+    "inputs": [
+      { "name": "w_ih", "type": "__torch__.torch.classes.quantized.LinearPackedParamsBase" },
+      { "name": "w_hh", "type": "__torch__.torch.classes.quantized.LinearPackedParamsBase" },
+      { "name": "bias_ih", "type": "Tensor" },
+      { "name": "bias_hh", "type": "Tensor" },
+      { "name": "reduce_range", "type": "boolean", "default": false }
+    ],
+    "outputs": [
+      { "type": "__torch__.torch.classes.rnn.CellParamsBase" }
+    ]
+  },
+  {
+    "name": "quantized::make_quantized_cell_params_fp16",
+    "inputs": [
+      { "name": "w_ih", "type": "__torch__.torch.classes.quantized.LinearPackedParamsBase" },
+      { "name": "w_hh", "type": "__torch__.torch.classes.quantized.LinearPackedParamsBase" }
+    ],
+    "outputs": [
+      { "type": "__torch__.torch.classes.rnn.CellParamsBase" }
+    ]
+  },
   {
     "name": "quantized::mul",
     "inputs": [
@@ -15968,6 +16135,63 @@
       { "type": "Tensor" }
     ]
   },
+  {
+    "name": "quantized::quantized_gru_cell_dynamic",
+    "inputs": [
+      { "name": "input", "type": "Tensor" },
+      { "name": "hx", "type": "Tensor" },
+      { "name": "w_ih", "type": "__torch__.torch.classes.quantized.LinearPackedParamsBase" },
+      { "name": "w_hh", "type": "__torch__.torch.classes.quantized.LinearPackedParamsBase" },
+      { "name": "b_ih", "type": "Tensor" },
+      { "name": "b_hh", "type": "Tensor" }
+    ],
+    "outputs": [
+      { "type": "Tensor" }
+    ]
+  },
+  {
+    "name": "quantized::quantized_lstm_cell_dynamic",
+    "inputs": [
+      { "name": "input", "type": "Tensor" },
+      { "name": "hx", "type": "Tensor[]" },
+      { "name": "w_ih", "type": "__torch__.torch.classes.quantized.LinearPackedParamsBase" },
+      { "name": "w_hh", "type": "__torch__.torch.classes.quantized.LinearPackedParamsBase" },
+      { "name": "bias_ih", "type": "Tensor" },
+      { "name": "bias_hh", "type": "Tensor" }
+    ],
+    "outputs": [
+      { "type": "Tensor" },
+      { "type": "Tensor" }
+    ]
+  },
+  {
+    "name": "quantized::quantized_rnn_relu_cell_dynamic",
+    "inputs": [
+      { "name": "input", "type": "Tensor" },
+      { "name": "hx", "type": "Tensor" },
+      { "name": "w_ih", "type": "__torch__.torch.classes.quantized.LinearPackedParamsBase" },
+      { "name": "w_hh", "type": "__torch__.torch.classes.quantized.LinearPackedParamsBase" },
+      { "name": "b_ih", "type": "Tensor" },
+      { "name": "b_hh", "type": "Tensor" }
+    ],
+    "outputs": [
+      { "type": "Tensor" }
+    ]
+  },
+  {
+    "name": "quantized::quantized_rnn_tanh_cell_dynamic",
+    "inputs": [
+      { "name": "input", "type": "Tensor" },
+      { "name": "hx", "type": "Tensor" },
+      { "name": "w_ih", "type": "__torch__.torch.classes.quantized.LinearPackedParamsBase" },
+      { "name": "w_hh", "type": "__torch__.torch.classes.quantized.LinearPackedParamsBase" },
+      { "name": "b_ih", "type": "Tensor" },
+      { "name": "b_hh", "type": "Tensor" }
+    ],
+    "outputs": [
+      { "type": "Tensor" }
+    ]
+  },
   {
     "name": "quantized::relu6",
     "category": "Activation",

source/pytorch.js

Lines changed: 23 additions & 30 deletions
@@ -145,7 +145,7 @@ pytorch.Graph = class {
             return submodules;
         };
         const loadScriptModule = (module, initializers) => {
-            if (module) {
+            if (module && !pytorch.Utility.isObject(module)) {
                 if (pytorch.Graph._getParameters(module).size > 0 && !module.__hide__) {
                     const item = { module };
                     this.nodes.push(new pytorch.Node(metadata, '', item, initializers, values));
@@ -527,13 +527,21 @@ pytorch.Node = class {
             const input = inputs[i];
             const schema = this.type && this.type.inputs && i < this.type.inputs.length ? this.type.inputs[i] : null;
             const name = schema && schema.name ? schema.name : i.toString();
-            const type = schema && schema.type ? schema.type : null;
+            let type = schema && schema.type ? schema.type : null;
+            let array = false;
+            if (type && type.endsWith('[]')) {
+                array = true;
+                type = type.slice(0, -2);
+            }
             let argument = null;
             if (pytorch.Utility.isObjectType(type)) {
                 const obj = input.value;
-                if (initializers.has(obj)) {
+                if (!array && initializers.has(obj)) {
                     const node = new pytorch.Node(metadata, group, { name, type, obj }, initializers, values);
                     argument = new pytorch.Argument(name, node, 'object');
+                } else if (array && Array.isArray(obj) && obj.every((obj) => initializers.has(obj))) {
+                    const node = obj.map((obj) => new pytorch.Node(metadata, group, { name, type, obj }, initializers, values));
+                    argument = new pytorch.Argument(name, node, 'object[]');
                 } else {
                     const identifier = input.unique().toString();
                     const value = values.map(identifier);
@@ -1799,6 +1807,11 @@ pytorch.jit.Execution = class extends pytorch.Execution {
                 [this.weight, this.bias] = state;
             }
         });
+        this.registerType('__torch__.torch.classes.rnn.CellParamsBase', class {
+            __setstate__(state) {
+                [this.type, this.tensors, this.doubles, this.longs, this.packed_params] = state;
+            }
+        });
         this.registerType('__torch__.torch.classes.xnnpack.Conv2dOpContext', class {
             __setstate__(state) {
                 [this.weight, this.bias, this.stride, this.padding, this.dilation, this.groups, this.output_min, this.output_max] = state;
@@ -2137,33 +2150,9 @@ pytorch.jit.Execution = class extends pytorch.Execution {
                     } else {
                         copyArgs.shift();
                         copyEvalArgs.shift();
-                        switch (parameter.type) {
-                            case '__torch__.torch.classes.quantized.Conv2dPackedParamsBase':
-                            case '__torch__.torch.classes.quantized.Conv3dPackedParamsBase':
-                            case '__torch__.torch.classes.quantized.LinearPackedParamsBase':
-                            case '__torch__.torch.classes.xnnpack.Conv2dOpContext':
-                            case '__torch__.torch.classes.xnnpack.LinearOpContext':
-                            case '__torch__.torch.classes.xnnpack.TransposeConv2dOpContext': {
-                                const value = this.variable(argument);
-                                value.value = argument;
-                                node.addInput(value);
-                                /*
-                                for (const [, value] of Object.entries(argument)) {
-                                    if (pytorch.Utility.isTensor(value)) {
-                                        const tensor = value;
-                                        referencedParameters.push(tensor);
-                                    }
-                                }
-                                */
-                                break;
-                            }
-                            default: {
-                                const value = this.variable(argument);
-                                node.addInput(value);
-                                value.value = argument;
-                                break;
-                            }
-                        }
+                        const value = this.variable(argument);
+                        node.addInput(value);
+                        value.value = argument;
                     }
                 }
             }
@@ -2416,6 +2405,7 @@ pytorch.jit.Execution = class extends pytorch.Execution {
             case '__torch__.torch.classes.quantized.Conv2dPackedParamsBase':
             case '__torch__.torch.classes.quantized.Conv3dPackedParamsBase':
            case '__torch__.torch.classes.quantized.LinearPackedParamsBase':
+            case '__torch__.torch.classes.rnn.CellParamsBase':
             case '__torch__.torch.classes.xnnpack.Conv2dOpContext':
             case '__torch__.torch.classes.xnnpack.LinearOpContext':
             case '__torch__.torch.classes.xnnpack.TransposeConv2dOpContext': {
@@ -2607,6 +2597,7 @@ pytorch.jit.Execution = class extends pytorch.Execution {
             case '__torch__.torch.classes.xnnpack.LinearOpContext':
             case '__torch__.torch.classes.xnnpack.Conv2dOpContext':
             case '__torch__.torch.classes.xnnpack.TransposeConv2dOpContext':
+            case '__torch__.torch.classes.rnn.CellParamsBase':
             case '__torch__.torch.classes.quantized.LinearPackedParamsBase':
             case '__torch__.torch.classes.quantized.Conv2dPackedParamsBase':
             case '__torch__.torch.classes.quantized.Conv3dPackedParamsBase':
@@ -3390,6 +3381,8 @@ pytorch.Utility = class {
             case '__torch__.torch.classes.xnnpack.LinearOpContext':
             case '__torch__.torch.classes.xnnpack.Conv2dOpContext':
             case '__torch__.torch.classes.xnnpack.TransposeConv2dOpContext':
+            case '__torch__.torch.classes.rnn.CellParamsBase':
+            case '__torch__.torch.classes.rnn.CellParamsBase[]':
             case '__torch__.torch.classes.quantized.LinearPackedParamsBase':
             case '__torch__.torch.classes.quantized.Conv2dPackedParamsBase':
             case '__torch__.torch.classes.quantized.Conv3dPackedParamsBase':
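
The "[]" handling added to pytorch.Node above exists because the params argument of these ops is a TorchScript list of class objects rather than a single object, so each __torch__.torch.classes.rnn.CellParamsBase element can now be expanded into its own 'object[]' entry. A hedged way to check the list type from Python, reusing the illustrative file from the earlier sketch (attribute paths and the printed code vary across PyTorch versions):

import torch

# Load the illustrative file from the previous sketch and print the scripted
# code of the quantized LSTM submodule; on recent PyTorch builds the call into
# torch.quantized_lstm takes a List[__torch__.torch.classes.rnn.CellParamsBase]
# argument, which is the type the metadata above declares and pytorch.js now
# unwraps element by element.
m = torch.jit.load("quantized_lstm.pt")
print(m.lstm.code)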
