Skip to content

Commit

Permalink
Update backend test (#990)
Browse files Browse the repository at this point in the history
  • Loading branch information
lutzroeder committed Oct 20, 2024
1 parent eb4ed0e commit e97c8a1
Show file tree
Hide file tree
Showing 5 changed files with 466 additions and 390 deletions.
4 changes: 2 additions & 2 deletions source/python.js
Original file line number Diff line number Diff line change
Expand Up @@ -4364,8 +4364,8 @@ python.Execution = class {
}
});
this.registerFunction('torch._C._get_registry', () => {
torch._C._registry = torch._C._registry || new torch._C.OperatorRegistry();
return torch._C._registry;
this._operators = this._operators || new torch._C.OperatorRegistry();
return this._operators;
});
this.registerFunction('torch._C._get_schema', (op_name, overload_name) => {
const registry = torch._C._get_registry();
Expand Down
2 changes: 1 addition & 1 deletion source/pytorch-metadata.json
Original file line number Diff line number Diff line change
Expand Up @@ -20382,7 +20382,7 @@
"category": "Data"
},
{
"name": "torchaudio::sox_effects_apply_effects_tensor(Tensor tensor, int sample_rate, str[][] effects, bool channels_first=True) -> (Tensor, int64)",
"name": "torchaudio::sox_effects_apply_effects_tensor(Tensor tensor, int sample_rate, str[][] effects, bool channels_first=True) -> (Tensor, int)",
"inputs": [
{ "name": "tensor", "type": "Tensor" },
{ "name": "sample_rate", "type": "int64" },
Expand Down
60 changes: 37 additions & 23 deletions source/pytorch.js
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,9 @@ pytorch.Graph = class {
node.outputs().every((output) => pytorch.Utility.isTensor(output.value))) {
continue;
}
if (node.kind() === 'prim::Constant' && node.outputs().length === 1 && node.outputs()[0].uses().length === 1) {
continue;
}
this.nodes.push(new pytorch.Node(metadata, null, null, node, initializers, values));
}
if (module) {
Expand Down Expand Up @@ -398,9 +401,18 @@ pytorch.Node = class {
let module = null;
if (pytorch.Utility.isInstance(obj, 'torch.Node')) {
const node = obj;
// const schema = node.schema();
this.type = createType(metadata, node.kind());
for (const name of node.attributeNames()) {
const kind = node.kind();
this.type = {
identifier: kind,
name: kind.indexOf('::') === -1 ? kind : kind.split('::').pop().split('.')[0]
};
const schema = node.schema();
if (schema && schema.category) {
this.type.category = schema.category;
}
const inputs = node.inputs();
const outputs = node.outputs();
const getAttribute = (node, name) => {
const kind = node.kindOf(name);
let value = null;
let type = null;
Expand All @@ -412,12 +424,16 @@ pytorch.Node = class {
case 'ival': value = node.ival(name); break;
default: throw new pytorch.Error(`Unsupported attribute kind '${kind}'.`);
}
return [type, value];
};
for (const name of node.attributeNames()) {
const [type, value] = getAttribute(node, name);
const attribute = new pytorch.Argument(name, value, type);
this.attributes.push(attribute);
}
let match = true;
let count = 0;
for (const input of node.inputs()) {
for (const input of inputs) {
const value = input.value;
let values = [];
if (pytorch.Utility.isObject(value)) {
Expand Down Expand Up @@ -458,22 +474,21 @@ pytorch.Node = class {
module = null;
}
}
const inputs = node.inputs();
for (let i = 0; i < inputs.length; i++) {
const input = inputs[i];
const schema = this.type && this.type.inputs && i < this.type.inputs.length ? this.type.inputs[i] : null;
const name = schema && schema.name ? schema.name : i.toString();
let type = schema && schema.type ? schema.type : null;
const arg = schema && schema.arguments && i < schema.arguments.length ? schema.arguments[i] : null;
const name = arg && arg.name ? arg.name : i.toString();
let type = arg ? arg.real_type : null;
let array = false;
if (type && type.endsWith('[]')) {
if (pytorch.Utility.isInstance(type, 'torch.ListType')) {
array = true;
type = type.slice(0, -2);
type = type.getElementType();
}
let argument = null;
if (pytorch.Utility.isObjectType(type)) {
if (arg && pytorch.Utility.isInstance(arg.real_type, 'torch.ClassType')) {
const obj = input.value;
if (!array && initializers.has(obj)) {
const node = new pytorch.Node(metadata, name, type, obj, initializers, values);
const node = new pytorch.Node(metadata, name, type.qualified_name(), obj, initializers, values);
argument = new pytorch.Argument(name, node, 'object');
} else if (array && Array.isArray(obj) && obj.every((obj) => initializers.has(obj))) {
const node = obj.map((obj) => new pytorch.Node(metadata, name, type, obj, initializers, values));
Expand Down Expand Up @@ -507,6 +522,9 @@ pytorch.Node = class {
}
} else if (pytorch.Utility.isInstance(input.type(), 'torch.StringType') && typeof input.value === 'string') {
argument = new pytorch.Argument(name, input.value, 'string');
} else if (input.node() && input.uses().length === 1 && input.node().kind() === 'prim::Constant') {
const [type, value] = getAttribute(input.node(), 'value');
argument = new pytorch.Argument(name, value, type || 'attribute');
} else {
const identifier = input.unique().toString();
const value = values.map(identifier);
Expand Down Expand Up @@ -543,21 +561,19 @@ pytorch.Node = class {
}
return value;
});
argument = new pytorch.Argument(name, args, schema.type);
argument = new pytorch.Argument(name, args, pytorch.Utility.toType(type));
} else {
argument = createAttribute(schema, schema.name, input.value);
throw new pytorch.Error('Unsupported input value');
}
this.inputs.push(argument);
}
const outputs = node.outputs();
for (let i = 0; i < outputs.length; i++) {
const output = outputs[i];
const metadata = this.type && this.type.outputs && i < this.type.outputs.length ? this.type.outputs[i] : null;
let name = '';
if (metadata && metadata.name) {
name = metadata.name;
const ret = schema && schema.returns && i < schema.returns.length ? schema.returns[i] : null;
if (ret && ret.name) {
name = ret.name;
} else {
name = i === 0 ? 'output' : `output${i}`;
name = i === 0 && outputs.length === 1 ? 'output' : `${i}`;
}
let list = [output];
if (output.uses().length === 1 &&
Expand Down Expand Up @@ -2573,9 +2589,7 @@ pytorch.jit.Execution = class extends pytorch.Execution {
const node = this._graph.create('aten::__getitem__.t');
node.addInput(target);
if (Number.isInteger(index)) {
const value = this.invoke('torch.Value', [node]);
value.value = index;
index = value;
index = this.constant(index);
}
node.addInput(index);
const value = node.addOutput();
Expand Down
Loading

0 comments on commit e97c8a1

Please sign in to comment.