Skip to content

Commit

Permalink
Update pytorch.js (#1061)
Browse files Browse the repository at this point in the history
  • Loading branch information
lutzroeder committed Oct 21, 2024
1 parent e97c8a1 commit 4878492
Show file tree
Hide file tree
Showing 2 changed files with 98 additions and 186 deletions.
16 changes: 10 additions & 6 deletions source/python.js
Original file line number Diff line number Diff line change
Expand Up @@ -4377,6 +4377,10 @@ python.Execution = class {
}
throw new python.Error(`Schema '${op_name}.${overload_name}' not found.`);
});
this.registerFunction('torch._C._jit_get_schemas_for_operator', (op_name) => {
const registry = torch._C._get_registry();
return registry.getAllOperatorsFor(op_name).map((op) => op.schema());
});
this.registerFunction('torch._C._jit_get_operation', (op_name) => {
const registry = torch._C._get_registry();
const sortedOps = registry.getAllOperatorsFor(op_name);
Expand Down Expand Up @@ -7036,7 +7040,7 @@ python.Execution = class {
});
this.registerType('torch.FunctionSchema', class {
constructor(name, overload_name, args, returns, is_vararg, is_varret) {
let index = name.indexOf('(');
const index = name.indexOf('(');
if (index === -1) {
this._name = name;
this._overload_name = overload_name;
Expand All @@ -7046,15 +7050,15 @@ python.Execution = class {
this._is_varret = is_varret;
} else {
const value = name.substring(0, index).trim();
this._buffer = name.substring(index, name.length);
index = value.indexOf('.');
if (index === -1) {
const dot = value.indexOf('.');
if (dot === -1) {
this._name = value;
this._overload_name = '';
} else {
this._name = value.substring(0, index);
this._overload_name = value.substring(index + 1, value.length);
this._name = value.substring(0, dot);
this._overload_name = value.substring(dot + 1, value.length);
}
this._buffer = name.substring(index, name.length);
}
}
static parse(schema) {
Expand Down
268 changes: 88 additions & 180 deletions source/pytorch.js
Original file line number Diff line number Diff line change
Expand Up @@ -3491,8 +3491,8 @@ pytorch.jit.Execution = class extends pytorch.Execution {
(obj instanceof torch.Value && obj.type() instanceof torch.TensorType) ||
(obj instanceof torch.Value && obj.type() instanceof torch.OptionalType && obj.type().getElementType() instanceof torch.TensorType));
case 'Tensor[]':
return Array.isArray(obj) && obj.length > 0 &&
obj.every((tensor) => pytorch.Utility.isTensor(tensor) || tensor === null || (tensor instanceof torch.Value && tensor.type() instanceof torch.TensorType));
return (Array.isArray(obj) && obj.length > 0 && obj.every((tensor) => pytorch.Utility.isTensor(tensor) || tensor === null || (tensor instanceof torch.Value && tensor.type() instanceof torch.TensorType))) ||
(obj instanceof torch.Value && obj.type() instanceof torch.ListType && obj.type().getElementType() instanceof torch.TensorType);
case 'Scalar':
return (obj !== null && (obj !== Object(obj) || obj instanceof Number)) ||
(pytorch.Utility.isTensor(obj) && Array.isArray(obj.size()) && obj.size().length === 0) ||
Expand Down Expand Up @@ -3627,52 +3627,52 @@ pytorch.jit.Execution = class extends pytorch.Execution {
const torch = this.torch;
const type = name ? `${moduleName}.${name}` : moduleName;
// https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/native_functions.yaml
let overloads = null;
let op_name = null;
if (type.startsWith('torch.')) {
overloads = this._types.get(`aten::${type.substring(6)}`);
} else if (type.startsWith('ops.prim.')) {
overloads = this._types.get(`prim::${type.substring(9)}`);
op_name = `aten::${type.substring(6)}`;
} else if (type.startsWith('ops.')) {
op_name = type.substring(4).replace('.', '::');
} else if (type === 'int') {
overloads = this._types.get(`aten::Int`);
op_name = 'aten::Int';
} else if (type === 'str') {
overloads = this._types.get(`aten::str`);
op_name = 'aten::str';
} else if (type === 'bool') {
overloads = this._types.get(`aten::Bool`);
op_name = 'aten::Bool';
} else if (type === 'float') {
overloads = this._types.get(`aten::Float`);
op_name = 'aten::Float';
} else if (type === 'complex') {
overloads = this._types.get(`aten::Complex`);
} else if (type.startsWith('ops.') && !type.startsWith('ops.prim.')) {
const path = type.split('.');
if (path.length === 3) {
overloads = this._types.get(`${path[1]}::${path[2]}`);
}
if (!overloads) {
const module = this.import(moduleName);
if (!module || !module[name]) {
const metadata = {};
metadata.name = type;
metadata.inputs = [];
metadata.outputs = [];
for (let i = 0; i < args.length; i++) {
const input = {};
let argument = args[i];
input.name = i.toString();
if (argument.type === '=' && argument.target && argument.target.type === 'id') {
input.name = this.expression(argument.target, context);
argument = argument.expression;
}
const obj = this.expression(argument, context);
input.type = pytorch.Utility.getType(obj);
metadata.inputs.push(input);
}
const count = context.target.length > 0 ? context.target[context.target.length - 1].length : 0;
for (let i = 0; i < count; i++) {
metadata.outputs.push({ name: '', type: '' });
op_name = 'aten::Complex';
}
this.native = true;
if (this.native && op_name) {
// const overloads = torch._C._jit_get_schemas_for_operator(op_name);
}
let overloads = this._types.get(op_name);
if (!overloads && type.startsWith('ops.') && !type.startsWith('ops.prim')) {
const module = this.import(moduleName);
if (!module || !module[name]) {
const metadata = {};
metadata.name = type;
metadata.inputs = [];
metadata.outputs = [];
for (let i = 0; i < args.length; i++) {
const input = {};
let argument = args[i];
input.name = i.toString();
if (argument.type === '=' && argument.target && argument.target.type === 'id') {
input.name = this.expression(argument.target, context);
argument = argument.expression;
}
this._metadata.add(type, metadata);
overloads = [metadata];
const obj = this.expression(argument, context);
input.type = pytorch.Utility.getType(obj);
metadata.inputs.push(input);
}
const count = context.target.length > 0 ? context.target[context.target.length - 1].length : 0;
for (let i = 0; i < count; i++) {
metadata.outputs.push({ name: '', type: '' });
}
this._metadata.add(type, metadata);
overloads = [metadata];
}
}
if (!overloads) {
Expand All @@ -3681,7 +3681,6 @@ pytorch.jit.Execution = class extends pytorch.Execution {
}
return null;
}
overloads = Array.isArray(overloads) ? overloads : [overloads];
const evalArgs = args.map((argument) => {
if (argument.type === '=' && argument.target && argument.target.type === 'id') {
argument = argument.expression;
Expand All @@ -3690,172 +3689,81 @@ pytorch.jit.Execution = class extends pytorch.Execution {
});
const matches = [];
for (const schema of overloads) {
const copyArgs = Array.prototype.slice.call(args);
const copyEvalArgs = Array.prototype.slice.call(evalArgs);
const parameters = Array.prototype.slice.call(schema.inputs || []).concat(Array.prototype.slice.call(schema.attributes || []));
const parameters = Array.prototype.slice.call(schema.inputs || []);
let next = false;
let kwarg_only = false;
while (copyEvalArgs.length > 0) {
let position = 0;
while (position < evalArgs.length) {
if (parameters.length <= 0) {
next = !schema.name.startsWith('_caffe2::') && !schema.is_vararg;
break;
}
if (copyArgs.every((arg) => arg.type === '=' && arg.target && arg.target.type === 'id') &&
parameters.every((parameter) => parameter.type !== 'Tensor' && parameter.type !== 'Tensor[]')) {
const map = new Map(parameters.map((parameter) => [parameter.name, parameter]));
while (copyArgs.length > 0) {
const argument = copyArgs.shift();
const arg = copyEvalArgs.shift();
const parameter = map.get(argument.target.value);
if (!parameter) {
next = true;
break;
}
if (parameter.kwarg_only) {
kwarg_only = true;
}
let type = parameter.type;
let optional = false;
if (parameter.type.endsWith('?')) {
type = parameter.type.substring(0, parameter.type.length - 1);
optional = true;
}
if (!this.isType(arg, type)) {
if (optional) {
continue;
}
next = true;
break;
}
}
continue;
}
if (next) {
if (parameters[0].kwarg_only) {
break;
}
const parameter = parameters.shift();
if (parameter.kwarg_only) {
kwarg_only = true;
}
const [argument] = copyEvalArgs;
/* if (type === 'Tensor' || (type === 'Scalar' && pytorch.Utility.isTensor(argument))) {
if (Array.isArray(argument) || (!pytorch.Utility.isTensor(argument) && argument !== null && argument !== undefined)) {
if (optional) {
continue;
}
next = true;
} else {
copyArgs.shift();
copyEvalArgs.shift();
}
} else */
let type = parameter.type;
const arg = parameters.shift();
const value = evalArgs[position];
let type = arg.type;
let optional = false;
if (parameter.type.endsWith('?')) {
type = parameter.type.substring(0, parameter.type.length - 1);
if (arg.type.endsWith('?')) {
type = arg.type.substring(0, arg.type.length - 1);
optional = true;
}
if (optional === true &&
(type === 'float32' || type === 'boolean' || type === 'int64' || type === 'complex' || type === 'ScalarType' || type === 'Device' || type === 'Layout') &&
argument instanceof torch.Value && argument.type() instanceof torch.NoneType) {
copyArgs.shift();
copyEvalArgs.shift();
} else if (type === 'Tensor[]') {
const [argument] = copyEvalArgs;
if ((argument instanceof torch.Value && pytorch.Utility.toType(argument.type()) === 'Tensor[]') ||
(Array.isArray(argument) && argument.every((item) => pytorch.Utility.isTensor(item) || item === null || (item instanceof torch.Value && item.type() instanceof torch.TensorType)))) {
copyArgs.shift();
copyEvalArgs.shift();
} else {
if (optional) {
continue;
}
next = true;
value instanceof torch.Value && value.type() instanceof torch.NoneType) {
position++;
} else if (!this.isType(value, type) && value !== null) {
if (optional) {
continue;
}
/* } else if (type === 't[]') {
if (!Array.isArray(argument) && (argument instanceof torch.Value === false || argument.type() instanceof torch.ListType === false)) {
if (optional) {
continue;
}
next = true;
} else {
copyArgs.shift();
copyEvalArgs.shift();
}*/
next = true;
break;
} else if (args[position].type === '=') {
next = true;
break;
} else {
const [arg] = copyArgs;
if (!this.isType(argument, type) && argument !== null) {
position++;
}
}
if (next) {
continue;
}
if (args.every((arg, index) => index < position || (arg.type === '=' && arg.target && arg.target.type === 'id'))) {
const map = new Map(parameters.map((parameter) => [parameter.name, parameter]));
while (position < args.length) {
const value = evalArgs[position];
const arg = map.get(args[position].target.value);
position++;
if (!arg) {
next = true;
break;
}
if (arg.kwarg_only) {
kwarg_only = true;
}
let type = arg.type;
let optional = false;
if (arg.type.endsWith('?')) {
type = arg.type.substring(0, arg.type.length - 1);
optional = true;
}
if (!this.isType(value, type)) {
if (optional) {
continue;
}
next = true;
} else if (arg.type === '=') {
next = true;
// throw new pytorch.Error('Expected named argument.');
} else {
copyArgs.shift();
copyEvalArgs.shift();
break;
}
}
if (next) {
break;
}
}
if (next) {
continue;
}
if (!kwarg_only && parameters.some((parameter) => parameter.default === undefined)) {
if (position < evalArgs.length && !schema.is_vararg && !schema.name.startsWith('_caffe2::')) {
continue;
}
for (let i = 0; i < schema.outputs.length; i++) {
const parameter = schema.outputs[i];
switch (parameter.type) {
case 'Scalar':
case 'Tensor':
case 'Tensor[]':
case 'float32':
case 'float32[]':
case 'int64':
case 'int64[]':
case 'Device':
case 'boolean':
case 'boolean[]':
case 't':
case 't[]':
case 'complex':
case 'complex[]':
case 'string':
case 'string[]':
case 'Dict(string, Tensor)':
case 'Dict(Tensor, t)':
case 'Dict(boolean, t)':
case 'Dict(complex, t)':
case 'Dict(float32, t)':
case 'Dict(int64, t)':
case 'Dict(string, t)':
case 'Dict(Tensor, tVal)':
case 'Dict(boolean, tVal)':
case 'Dict(complex, tVal)':
case 'Dict(float32, tVal)':
case 'Dict(int64, tVal)':
case 'Dict(string, tVal)':
case '(string, t)[]':
case 'Any':
break;
case '__torch__.torch.classes.xnnpack.LinearOpContext':
case '__torch__.torch.classes.xnnpack.Conv2dOpContext':
case '__torch__.torch.classes.xnnpack.TransposeConv2dOpContext':
case '__torch__.torch.classes.rnn.CellParamsBase':
case '__torch__.torch.classes.quantized.LinearPackedParamsBase':
case '__torch__.torch.classes.quantized.Conv2dPackedParamsBase':
case '__torch__.torch.classes.quantized.Conv3dPackedParamsBase':
break;
default: {
throw new pytorch.Error(`Unknown return type '${parameter.type}'.`);
}
}
}
if (next) {
if (!kwarg_only && parameters.some((parameter) => parameter.default === undefined)) {
continue;
}
matches.push(schema);
Expand Down

0 comments on commit 4878492

Please sign in to comment.