diff --git a/source/python.js b/source/python.js index e53e875dd3..eaaf38d233 100644 --- a/source/python.js +++ b/source/python.js @@ -4967,6 +4967,35 @@ python.Execution = class { } return this[name]; } + __delattr__(name) { + if (this._modules.has(name)) { + this._modules.delete(name); + } + } + children() { + return this._modules.values(); + } + named_children() { + return this._modules; + } + parameters() { + return this._parameters.values(); + } + named_parameters(recurse) { + if (recurse) { + throw new python.Error('Named parameters with recurse not implemented.'); + } + return this._parameters; + } + buffers() { + return this._buffers.values(); + } + named_buffers(recurse) { + if (recurse) { + throw new python.Error('Named parameters with recurse not implemented.'); + } + return this._buffers; + } }); torch.nn.Module = torch.nn.modules.module.Module; torch.nn.modules.Module = torch.nn.modules.module.Module; @@ -6958,6 +6987,9 @@ python.Execution = class { } return this.equals(rhs); } + is_module() { + return false; + } expect(type) { if (this instanceof type === false) { throw new python.Error(`Expected '${type.kind()}' but got '${this.kind()}'.`); @@ -7004,6 +7036,12 @@ python.Execution = class { is_module() { return this._is_module; } + is_parameter(slot) { + return this._attributes[slot].is_parameter === true; + } + is_buffer(slot) { + return this._attributes[slot].is_buffer === true; + } addMethod(func) { this._methods.set(func.name(), func); } @@ -7023,6 +7061,9 @@ python.Execution = class { findStaticMethod(name) { return this._staticmethods.get(name); } + numAttributes() { + return this._attributes.length; + } addAttribute(name, type, is_parameter, is_buffer) { is_parameter = is_parameter || false; is_buffer = is_buffer || false; @@ -7057,10 +7098,16 @@ python.Execution = class { } return null; } - getAttribute(name) { - const slot = this.findAttributeSlot(name); + hasAttribute(name) { + return this._attributes.find((attr) => attr.name === name); + } + getAttribute(arg) { + const slot = Number.isInteger(arg) ? arg : this.findAttributeSlot(arg); return this._attributes[slot].type; } + getAttributeName(slot) { + return this._attributes[slot].name; + } hasConstant(/* name */) { } methods() { @@ -8131,7 +8178,10 @@ python.Execution = class { parseBroadcastList(/* expr */) { return null; } - + parseType(str) { + const expr = ast.parse(str); + return this.parseTypeFromExpr(expr.body[0]); + } }); this.registerType('torch._ops.OpOverload', class extends torch._ops.OperatorBase { constructor(overloadpacket, op, op_dk, schema, tags) { @@ -8874,8 +8924,9 @@ python.Execution = class { this._loaded_sources = new Set(); this._to_be_defined = new Map(); } - loadType(/* name */) { - // + loadType(name) { + const type_parser = new torch.jit.ScriptTypeParser(this); + return type_parser.parseType(name.qualifiedName()); } resolveType(name) { name = new torch.jit.QualifiedName(name); @@ -9120,12 +9171,32 @@ python.Execution = class { for (let i = 0; i < constants.length; i++) { execution.builtins.CONSTANTS[`c${i}`] = constants[i]; } - const module = this.readArchive('data'); - const name = `${module.__class__.__module__}.${module.__class__.__name__}`; - const type = torch.ClassType.create(name, null, true); - const result = new torch.ScriptModule(type); - result.data = module; - return result; + const obj = this.readArchive('data'); + const convertModule = (obj) => { + if (obj.__class__) { + const name = `${obj.__class__.__module__}.${obj.__class__.__name__}`; + const type = this._source_importer.loadType(new torch.jit.QualifiedName(name)); + const module = new torch.ScriptModule(type, this._compilation_unit); + for (let i = 0; i < type.numAttributes(); i++) { + const k = type.getAttributeName(i); + const t = type.getAttribute(i); + const v = obj[k]; + if (t.is_module()) { + module.__setattr__(k, convertModule(v)); + } else { + module.__setattr__(k, obj[k]); + } + } + for (const [key, value] of Object.entries(Object.getPrototypeOf(obj))) { + if (value && value.__class__ === builtins.method) { + module[key] = value; + } + } + return module; + } + throw new python.Error('Module class not found.'); + }; + return convertModule(obj); } LEGACY_deserialize() { const execution = this._compilation_unit.execution; @@ -9186,8 +9257,8 @@ python.Execution = class { for (const tensor of tensor_table) { this._constant_table.push(tensor); } - const temp = this.LEGACY_convertModule(module_def); - const data = obj.mainModule || {}; + return this.LEGACY_convertModule(module_def); + /* const data = obj.mainModule || {}; const queue = [data]; while (queue.length > 0) { const module = queue.shift(); @@ -9237,6 +9308,8 @@ python.Execution = class { const result = new torch.ScriptModule(temp.type()); result.data = data; return result; + return module; + */ } LEGACY_convertModule(module_def) { const atoms = new torch.jit.QualifiedName(module_def.name).atoms(); @@ -9245,13 +9318,14 @@ python.Execution = class { const sanitized = /^\d+$/.test(atom) ? `_${atom}` : atom; this._LEGACY_moduleStack.push(sanitized); } - const module = new torch.ScriptModule(new torch.jit.QualifiedName(this._LEGACY_moduleStack), this._compilation_unit); + const qn = new torch.jit.QualifiedName(this._LEGACY_moduleStack); + const module = new torch.ScriptModule(qn, this._compilation_unit); for (const sub_def of module_def.submodules || []) { const submodule = this.LEGACY_convertModule(sub_def); module.register_module(sub_def.name, submodule); } for (const param_def of module_def.parameters || []) { - const tensor = this._constant_table[Number(param_def.tensorId)]; + const tensor = this._constant_table[Number(param_def.tensor_id)]; if (param_def.isBuffer) { module.register_buffer(param_def.name, tensor); } else { @@ -9263,11 +9337,21 @@ python.Execution = class { if (module.hasattr(attr_def.name)) { continue; } + throw new python.Error('Not implemented.'); // IValue ivalue; // if (attr_def.id() >= 0) { // ivalue = LEGACY_pickled_ivalues_.at(attr_def.id()); // } - // module.register_attribute(attr_def.name(), typeParser.parseType(attr_def.type()), ivalue); + // module.register_attribute(attr_def.name, typeParser.parseType(attr_def.type), ivalue); + } + if (module_def.torchscript_arena) { + const key = module_def.torchscript_arena.key; + const file = key.substring('code/'.length); + const name = file.replace(/\.py$/, '').split('/').join('.'); + const code = execution.import(name); + if (code.forward.__class__ === execution.builtins.function) { + module.forward = code.forward; + } } /* std::shared_ptr gen_ranges = nullptr; @@ -9299,9 +9383,13 @@ python.Execution = class { return module; } readArchive(archive_name) { - const type_resolver = null; - const obj_loader = null; - return this.readArchiveAndTensors(archive_name, this._pickle_dir_prefix, this._tensor_dir_prefix, type_resolver, obj_loader, this._device, this._reader, null, this._storage_context); + const type_resolver = (qn) => { + const cls = this._source_importer.loadType(qn); + return cls; + }; + const ObjLoaderFunc = (/* type, ivalue */) => { + }; + return this.readArchiveAndTensors(archive_name, this._pickle_dir_prefix, this._tensor_dir_prefix, type_resolver, ObjLoaderFunc, this._device, this._reader, null, this._storage_context); } readArchiveAndTensors(archive_name, pickle_prefix, tensor_prefix, type_resolver, obj_loader, device, stream_reader, type_parser, storage_context) { const picklename = `${pickle_prefix + archive_name}.pkl`; @@ -9405,7 +9493,9 @@ python.Execution = class { const cu = new torch.jit.CompilationUnit(); cu.execution = execution; const cpp_module = torch._C.import_ir_module(cu, file, map_location, extra_files); - return new torch.jit._script.RecursiveScriptModule(cpp_module); + const module = torch.jit._script.wrap_cpp_module(cpp_module); + module.forward = cpp_module.forward; // remove + return module; }); this.registerFunction('torch._C.import_ir_module', function(cu, reader, ...args) { switch (arguments.length) { @@ -9495,7 +9585,7 @@ python.Execution = class { const cu = new torch.jit.CompilationUnit(); cu.execution = execution; const cpp_module = torch._C._import_ir_module_from_package(cu, importer.zip_reader, importer.storage_context, importer.last_map_location, script_module_id); - return new torch.jit._script.RecursiveScriptModule(cpp_module); + return torch.jit._script.wrap_cpp_module(cpp_module); }); this.registerFunction('torch.jit._script.wrap_cpp_module', (cpp_module) => { const init_fn = (script_module) => { @@ -9553,6 +9643,7 @@ python.Execution = class { this.registerType('torch.ScriptObject', class { constructor(type) { this._type = type; + this._ivalue = {}; } static create(type) { if (type.is_module()) { @@ -9579,14 +9670,17 @@ python.Execution = class { } __setattr__(name, value) { // if (this._type.hasContant(name)) - this[name] = value; + this._ivalue[name] = value; } __getattr__(name) { - return this[name]; + return this._ivalue[name]; } hasattr(name) { return this._type.hasAttribute(name) || this._type.hasConstant(name); } + getattr(name) { + return this.__getattr__(name); + } _properties() { throw new python.Error(); } @@ -9622,19 +9716,16 @@ python.Execution = class { return false; } }; - if (!this.data) { + if (!this.forward) { return null; } - if (!this.data.forward) { - throw new python.Error("Module 'forward' not implemented."); - } execution.traceAttr = false; const args = []; if (!execution.traceAttr) { - args.push(this.data); // self + args.push(this); // self } - if (this.data.forward.__code__ && this.data.forward.__code__.args) { - const params = this.data.forward.__code__.args.args; + if (this.forward.__code__ && this.forward.__code__.args) { + const params = this.forward.__code__.args.args; for (let i = 0; i < params.length; i++) { const arg = params[i]; if (execution.traceAttr || arg.arg !== 'self') { @@ -9653,7 +9744,7 @@ python.Execution = class { } } execution.purge = new Set(); - const result = this.data.forward.__call__(args); + const result = this.forward.__call__(args); const queue = Array.from(execution.purge); const visited = new Set(); while (queue.length > 0) { @@ -9715,15 +9806,15 @@ python.Execution = class { } register_module(name, module) { this.type().addOrCheckAttribute(name, module.type()); - // _ivalue()->setAttr(name, module._ivalue()); + this.__setattr__(name, module); // _ivalue()->setAttr(name, module._ivalue()); } - register_buffer(name /* , v */) { + register_buffer(name, v) { this.type().addOrCheckAttribute(name, torch.TensorType.get(), false, true); - // _ivalue()->setAttr(name, std::move(v)); + this.__setattr__(name, v); // _ivalue()->setAttr(name, std::move(v)); } register_parameter(name, v, is_buffer) { this.type().addOrCheckAttribute(name, torch.TensorType.get(), !is_buffer, is_buffer); - // _ivalue()->setAttr(name, std::move(v)); + this.__setattr__(name, v); // _ivalue()->setAttr(name, std::move(v)); } register_attribute(name, t, v, is_param, is_buffer) { this.type().addOrCheckAttribute(name, t, is_param, is_buffer); @@ -9731,11 +9822,59 @@ python.Execution = class { } }); this.registerType('torch.ModuleDict', class { - constructor(module) { - this._items = Object.entries(module).filter(([, value]) => value instanceof torch.ScriptModule); + constructor(mod) { + this._module = mod; + } + items() { + const result = new Map(); + const type = this._module.type(); + for (let i = 0; i < type.numAttributes(); i++) { + const k = type.getAttributeName(i); + const t = type.getAttribute(i); + if (t && t.is_module()) { + result.set(k, this._module.__getattr__(k)); + } + } + return result; + } + }); + this.registerType('torch.ParameterDict', class { + constructor(mod) { + this._module = mod; + } + items() { + const result = new Map(); + const type = this._module.type(); + for (let i = 0; i < type.numAttributes(); i++) { + if (type.is_parameter(i)) { + const k = type.getAttributeName(i); + const v = this._module.__getattr__(k); + if (v instanceof torch.Tensor) { + result.set(k, v); + } + } + } + return result; + } + }); + this.registerType('torch.BufferDict', class { + constructor(mod) { + this._module = mod; } items() { - return this._items; + const result = new Map(); + const type = this._module.type(); + for (let i = 0; i < type.numAttributes(); i++) { + if (type.is_buffer(i)) { + const t = type.getAttribute(i); + if (t.isSubtypeOf(torch.TensorType.get())) { + const k = type.getAttributeName(i); + const v = this._module.__getattr__(k); + result.set(k, v); + } + } + } + return result; } }); this.registerType('torch.jit.to_ir', class { @@ -9846,11 +9985,11 @@ python.Execution = class { torch.jit._script.RecursiveScriptModule._finalize_scriptmodule(script_module); return script_module; } - static _finalize_scriptmodule() { - this._initializing = false; - } - get data() { - return this._c.data; + static _finalize_scriptmodule(script_module) { + script_module._parameters = new torch.ParameterDict(script_module._c).items(); + script_module._buffers = new torch.BufferDict(script_module._c).items(); + // script_module._modules = OrderedModuleDict(script_module._c, script_module._modules) + script_module._initializing = false; } get graph() { // return this._c._get_method("forward").graph; @@ -9863,8 +10002,8 @@ python.Execution = class { __setattr__(name, value) { if (this._initializing) { super.__setattr__(name, value); - } else if (this.modules.has(name)) { - this.modules.set(name, value); + } else if (this._modules.has(name)) { + this._modules.set(name, value); } else if (this._c.hasattr(name)) { this._c.setattr(name, value); } else { @@ -9875,8 +10014,8 @@ python.Execution = class { if (this._initializing) { return super.__getattr__(name); } - if (this.modules.has(name)) { - return this.modules.get(name); + if (this._modules.has(name)) { + return this._modules.get(name); } if (this._c.hasattr(name)) { return this._c.getattr(name); @@ -12544,7 +12683,12 @@ python.Execution = class { if (path) { let target = null; for (let i = path.length - 1; i >= 0; i--) { - target = target ? target[path[i]] : context.get(path[i]); + const name = path[i]; + if (target) { + target = target.__getattr__ ? target.__getattr__(name) : target[name]; + } else { + target = context.get(name); + } if (!target) { break; } diff --git a/source/pytorch.js b/source/pytorch.js index fc475dbbfa..3480aaa590 100644 --- a/source/pytorch.js +++ b/source/pytorch.js @@ -89,126 +89,135 @@ pytorch.Graph = class { return values.get(name); }; const torch = execution ? execution.torch : null; - // type = module && module.__class__ && module.__class__.__module__ && module.__class__.__name__ ? `${module.__class__.__module__}.${module.__class__.__name__}` : null; if (torch && (module instanceof torch.ScriptModule || module instanceof torch.jit._script.ScriptModule || module instanceof torch.jit._script.RecursiveScriptModule) && module.graph) { const initializers = new Map(); const graph = module.graph; - const constants = module.code_with_constants[1].const_mapping; - if (constants) { - for (const [key, value] of constants) { - const name = `CONSTANTS.${key}`; - if (pytorch.Utility.isTensor(value)) { - initializers.set(value, new pytorch.Tensor(name, value)); - } else if (pytorch.Utility.isObject(value)) { - initializers.set(value, value); - } else { - // throw new pytorch.Error('Unsupported constant.'); + if (graph) { + const constants = module.code_with_constants[1].const_mapping; + if (constants) { + for (const [key, value] of constants) { + const name = `CONSTANTS.${key}`; + if (pytorch.Utility.isTensor(value)) { + initializers.set(value, new pytorch.Tensor(name, value)); + } else if (pytorch.Utility.isObject(value)) { + initializers.set(value, value); + } else { + // throw new pytorch.Error('Unsupported constant.'); + } } } - } - const queue = [module.data]; - while (queue.length > 0) { - const module = queue.shift(); - for (const [key, obj] of Object.entries(module)) { - if (key !== '__module__' && key !== '__name__' && key !== '__class__' && key !== '__parent__') { - if (!Array.isArray(obj) && obj === Object(obj)) { - if (pytorch.Utility.isTensor(obj)) { - const parameter = obj; - parameter.__parent__ = module; - if (parameter.storage() && !parameter.__origin__) { - if (parameter.__count__ === undefined || parameter.__count__ === 1) { - initializers.set(parameter, new pytorch.Tensor(parameter.name, parameter)); + const queue = [module]; + while (queue.length > 0) { + const module = queue.shift(); + const children = module.named_children(); + for (const [key, obj] of children) { + obj.__parent__ = module; + obj.__name__ = obj.__name__ || key; + queue.push(obj); + } + for (const buffer of module.buffers()) { + buffer.__parent__ = module; + if (buffer.storage() && !buffer.__origin__ && (buffer.__count__ === undefined || buffer.__count__ === 1)) { + initializers.set(buffer, new pytorch.Tensor(buffer.name, buffer)); + } + } + for (const parameter of module.parameters()) { + parameter.__parent__ = module; + if (parameter.storage() && !parameter.__origin__ && (parameter.__count__ === undefined || parameter.__count__ === 1)) { + initializers.set(parameter, new pytorch.Tensor(parameter.name, parameter)); + } + } + for (const [key, obj] of children) { + if (key !== '__module__' && key !== '__name__' && key !== '__class__' && key !== '__parent__') { + if (!Array.isArray(obj) && obj === Object(obj)) { + if (pytorch.Utility.isTensor(obj)) { + const parameter = obj; + parameter.__parent__ = module; + if (parameter.storage() && !parameter.__origin__) { + if (parameter.__count__ === undefined || parameter.__count__ === 1) { + initializers.set(parameter, new pytorch.Tensor(parameter.name, parameter)); + } } + } else if (pytorch.Utility.isObject(obj)) { + if (obj.__count__ === undefined || obj.__count__ === 1) { + initializers.set(obj, obj); + } + queue.push(obj); + } else if (obj instanceof torch.Value || obj instanceof torch.Node) { + continue; + } else if (obj && obj.__class__) { + obj.__parent__ = module; + obj.__name__ = obj.__name__ || key; + queue.push(obj); } - } else if (pytorch.Utility.isObject(obj)) { - if (obj.__count__ === undefined || obj.__count__ === 1) { - initializers.set(obj, obj); - } - queue.push(obj); - } else if (obj instanceof torch.Value || obj instanceof torch.Node) { - continue; - } else if (obj && obj.__class__) { - obj.__parent__ = module; - obj.__name__ = obj.__name__ || key; - queue.push(obj); } } } } - } - for (const value of graph.inputs()) { - const identifier = pytorch.Utility.unique(value); - const name = value.debugName() || identifier; - this.inputs.push(new pytorch.Argument(name, [values.map(identifier)])); - } - for (const value of graph.outputs()) { - const identifier = pytorch.Utility.unique(value); - this.outputs.push(new pytorch.Argument(identifier, [values.map(identifier)])); - } - for (const node of graph.nodes()) { - if (node === graph.param_node() || - node === graph.return_node()) { - continue; + for (const value of graph.inputs()) { + const identifier = pytorch.Utility.unique(value); + const name = value.debugName() || identifier; + this.inputs.push(new pytorch.Argument(name, [values.map(identifier)])); } - if (node.kind() === 'prim::TupleConstruct' && - node.inputs().length === 0 && - node.outputs().length === 1 && - node.outputs().every((output) => output.uses().length === 0)) { - continue; + for (const value of graph.outputs()) { + const identifier = pytorch.Utility.unique(value); + this.outputs.push(new pytorch.Argument(identifier, [values.map(identifier)])); } - if (node.kind() === 'prim::ListConstruct') { - if (node.outputs().length === 1 && - node.outputs().every((output) => output.uses().length === 1) && - node.inputs().every((input) => pytorch.Utility.isTensor(input.value) || input instanceof torch.Value)) { + for (const node of graph.nodes()) { + if (node === graph.param_node() || + node === graph.return_node()) { continue; } - if (node.inputs().length === 0 && + if (node.kind() === 'prim::TupleConstruct' && + node.inputs().length === 0 && node.outputs().length === 1 && node.outputs().every((output) => output.uses().length === 0)) { continue; } - if (node.inputs().every((value) => value && (value.type() instanceof torch.IntType || value.type() instanceof torch.FloatType || value.type() instanceof torch.StringType || value.type() instanceof torch.ComplexType)) && - node.outputs().length === 1 && - node.outputs().every((output) => output.uses().length === 1)) { + if (node.kind() === 'prim::ListConstruct') { + if (node.outputs().length === 1 && + node.outputs().every((output) => output.uses().length === 1) && + node.inputs().every((input) => pytorch.Utility.isTensor(input.value) || input instanceof torch.Value)) { + continue; + } + if (node.inputs().length === 0 && + node.outputs().length === 1 && + node.outputs().every((output) => output.uses().length === 0)) { + continue; + } + if (node.inputs().every((value) => value && (value.type() instanceof torch.IntType || value.type() instanceof torch.FloatType || value.type() instanceof torch.StringType || value.type() instanceof torch.ComplexType)) && + node.outputs().length === 1 && + node.outputs().every((output) => output.uses().length === 1)) { + continue; + } + } + if (node.kind() === 'prim::ListUnpack' && + node.inputs().length === 1 && + node.inputs().every((input) => input.uses().length === 1) && + node.outputs().every((output) => pytorch.Utility.isTensor(output.value))) { continue; } + if (node.kind() === 'prim::Constant' && node.outputs().length <= 1 && node.outputs()[0].uses().length <= 1) { + continue; + } + this.nodes.push(new pytorch.Node(execution, metadata, null, null, node, initializers, values)); } - if (node.kind() === 'prim::ListUnpack' && - node.inputs().length === 1 && - node.inputs().every((input) => input.uses().length === 1) && - node.outputs().every((output) => pytorch.Utility.isTensor(output.value))) { - continue; - } - if (node.kind() === 'prim::Constant' && node.outputs().length <= 1 && node.outputs()[0].uses().length <= 1) { - continue; - } - this.nodes.push(new pytorch.Node(execution, metadata, null, null, node, initializers, values)); } if (module) { - const queue = [module.data]; + const queue = [module]; while (queue.length > 0) { const module = queue.pop(); - if (module && !pytorch.Utility.isObject(module)) { - if (!module.__hide__ && pytorch.Graph._getParameters(module).size > 0) { - for (const [name, obj] of Object.entries(module)) { - if ((obj && obj.__hide__) || (obj !== null && !pytorch.Utility.isTensor(obj)) && typeof obj !== 'boolean' && typeof obj !== 'number' && typeof obj !== 'string') { - delete module[name]; + if (module) { + if (!module.__hide__ && module.named_parameters().size > 0) { + for (const [name, obj] of module.named_children()) { + if (obj && obj.__hide__) { + module.__delattr__(name); } } const node = new pytorch.Node(execution, metadata, null, null, module, initializers, values); this.nodes.push(node); } - const modules = []; - if (module.__class__ && module.__class__.__module__ && module.__class__.__name__) { - for (const [key, value] of Object.entries(module)) { - if (!key.startsWith('__') && value && value.__class__ && value.__class__.__module__ && value.__class__.__name__ && !pytorch.Utility.isTensor(value)) { - if (value instanceof torch.Value) { - continue; - } - modules.push(value); - } - } - } + const modules = Array.from(module.children()); queue.push(...modules.reverse()); } } @@ -331,18 +340,6 @@ pytorch.Graph = class { } } } - - static _getParameters(module) { - const parameters = new Map(); - if (module && module.__class__.__module__ && module.__class__.__name__) { - for (const [key, value] of Object.entries(module)) { - if (pytorch.Utility.isTensor(value)) { - parameters.set(key, value); - } - } - } - return parameters; - } }; pytorch.Argument = class { @@ -451,7 +448,7 @@ pytorch.Node = class { } } if (module) { - const parameters = pytorch.Graph._getParameters(module); + const parameters = module.named_parameters(); parameters.delete('num_batches_tracked'); if (parameters.size === count && match) { module.__hide__ = true; @@ -703,8 +700,24 @@ pytorch.Node = class { throw new pytorch.Error(`Unsupported node operation '${obj.op}'.`); } } else { + if (torch && obj instanceof torch.jit._script.RecursiveScriptModule && obj._c && obj._c.qualified_name) { + type = obj._c._type; + const target = { + _modules: obj._modules, + _parameters: obj._parameters, + _buffers: obj._buffers, + }; + for (let i = 0; i < type.numAttributes(); i++) { + if (!type.is_parameter(i) && !type.is_buffer(i) && !type.getAttribute(i).is_module()) { + const k = type.getAttributeName(i); + target[k] = obj.__getattr__(k); + } + } + type = obj._c.qualified_name; + obj = target; + } if (!type) { - if (torch && pytorch.Utility.isInstance(obj, 'torch.jit._script.RecursiveScriptModule') && obj._c && obj._c.qualified_name) { + if (torch && obj instanceof torch.jit._script.RecursiveScriptModule && obj._c && obj._c.qualified_name) { type = obj._c.qualified_name; } else if (pytorch.Utility.isInstance(obj, 'builtins.function')) { type = `${obj.__module__}.${obj.__name__}`; @@ -811,14 +824,14 @@ pytorch.Node = class { this.inputs.push(argument); } else if (name === '_modules' && pytorch.Utility.isInstance(value, 'collections.OrderedDict') && value instanceof Map && Array.from(value).every(([, value]) => value === null || value.__class__)) { - const values = Array.from(value).filter(([, value]) => !stack.has(value)).map(([name, obj]) => { + const list = Array.from(value).filter(([, value]) => !stack.has(value)).map(([name, obj]) => { stack.add(value); const type = obj === null ? 'builtins.NoneType' : `${obj.__class__.__module__}.${obj.__class__.__name__}`; - const node = new pytorch.Node(execution, metadata, this.name ? `${this.name}.${name}` : name, type, obj); + const node = new pytorch.Node(execution, metadata, this.name ? `${this.name}.${name}` : name, type, obj, initializers, values, stack); stack.delete(value); return node; }); - const argument = new pytorch.Argument(name, values, 'object[]'); + const argument = new pytorch.Argument(name, list, 'object[]'); this.inputs.push(argument); } else if (value && Array.isArray(value) && value.length > 0 && value.every((obj) => Array.isArray(obj) && obj.every((item) => typeof item === 'string' || typeof item === 'number'))) { const argument = new pytorch.Argument(name, value, 'attribute'); @@ -1300,15 +1313,10 @@ pytorch.Container.Zip = class extends pytorch.Container { const version = reader.version(); if (torchscript) { this.execution.trace = false; - const module = torch.jit.load(reader); + this.module = torch.jit.load(reader); this.execution.trace = true; metadata.register(this.execution); - if (module.data && module.data.forward) { - this.module = module; - } else { - torchscript = false; - this.module = module.data; - } + torchscript = this.module.forward; } else { const records = reader.get_all_records().map((key) => [key, reader.get_record(key)]); const entries = new Map(records); @@ -1353,9 +1361,15 @@ pytorch.Container.ModelJson = class extends pytorch.Container { 'version', ...this._model.tensors.filter((tensor) => tensor && tensor.data && tensor.data.key).map((tensor) => tensor.data.key) ]; - if (this._model.mainModule.torchscriptArena && this._model.mainModule.torchscriptArena.key) { - keys.push(this._model.mainModule.torchscriptArena.key); - } + const walk = (module) => { + if (module.torchscriptArena && module.torchscriptArena.key) { + keys.push(module.torchscriptArena.key); + } + for (const submodule of module.submodules || []) { + walk(submodule); + } + }; + walk(this._model.mainModule); const values = await Promise.all(keys.map((name) => this._context.fetch(name).then((context) => context.stream).catch(() => null))); for (let i = 0; i < keys.length; i++) { if (values[i]) { @@ -1374,14 +1388,14 @@ pytorch.Container.ModelJson = class extends pytorch.Container { } this.format = reader.has_record('attributes.pkl') ? 'TorchScript v1.1' : 'TorchScript v1.0'; this.execution.trace = false; - const module = torch.jit.load(reader); + this.module = torch.jit.load(reader); this.execution.trace = true; metadata.register(this.execution); - if (module.data && module.data.forward) { + /* if (module.data && module.data.forward) { this.module = module; } else { this.module = module.data; - } + }*/ delete this._context; delete this._model; delete this._entries; @@ -1767,7 +1781,12 @@ pytorch.Execution = class extends python.Execution { if (path) { let target = null; for (let i = path.length - 1; i >= 0; i--) { - target = target ? target[path[i]] : context.get(path[i]); + const name = path[i]; + if (target) { + target = target.__getattr__ ? target.__getattr__(name) : target[name]; + } else { + target = context.get(name); + } if (!target) { break; } @@ -2009,7 +2028,7 @@ pytorch.Execution = class extends python.Execution { this._graph.insertNode(node); return node.output(); } - return target[attr]; + return target.__getattr__ ? target.__getattr__(attr) : target[attr]; } case 'List': { const list = expr.elts.map((item) => this.expression(item, context)); @@ -2047,6 +2066,8 @@ pytorch.Execution = class extends python.Execution { let value = null; if (elt instanceof torch.Value) { value = elt; + } else if (pytorch.Utility.isTensor(elt)) { + value = this.variable(elt, null); } else if (elt === null || Number.isInteger(elt) || typeof elt === 'number' || typeof elt === 'boolean' || typeof elt === 'string') { value = this._graph.insertConstant(elt); } else { @@ -2143,7 +2164,7 @@ pytorch.Execution = class extends python.Execution { } case 'Attribute': { const target = this.target(expr.value, context); - return target[expr.attr]; + return target.__getattr__ ? target.__getattr__(expr.attr) : target[expr.attr]; } case 'Call': { const func = expr.func; @@ -2528,7 +2549,9 @@ pytorch.Execution = class extends python.Execution { statement(stmt, context) { if (stmt.__class__.__name__ === 'ClassDef') { const name = `${context.get('__name__')}.${stmt.name}`; - this._resolver.resolveType(name); + if (this._resolver) { + this._resolver.resolveType(name); + } } if (!this.trace) { @@ -3315,8 +3338,8 @@ pytorch.Container.Package = class extends pytorch.Container { this.entries = entries; } - async read() { - this.execution = new python.Execution(); + async read(metadata) { + this.execution = new pytorch.Execution(null, metadata); for (const event of this._events) { this.execution.on(event[0], event[1]); } @@ -3504,8 +3527,17 @@ pytorch.Utility = class { } static weights(obj) { - const type = obj && obj.__class__ && obj.__class__.__module__ && obj.__class__.__name__ ? `${obj.__class__.__module__}.${obj.__class__.__name__}` : null; - if (type && type !== 'builtins.dict' && type !== 'builtins.object' && type !== 'collections.OrderedDict' && type !== 'torch.nn.modules.module.Module' && type !== '__torch__.Module') { + let type = obj && obj.__class__ && obj.__class__.__module__ && obj.__class__.__name__ ? `${obj.__class__.__module__}.${obj.__class__.__name__}` : null; + if (type === 'torch.jit._script.RecursiveScriptModule') { + type = obj._c._type; + const target = {}; + for (let i = 0; i < type.numAttributes(); i++) { + const k = type.getAttributeName(i); + target[k] = obj.__getattr__(k); + } + type = obj._c.qualified_name; + obj = target; + } else if (type && type !== 'builtins.dict' && type !== 'builtins.object' && type !== 'collections.OrderedDict' && type !== 'torch.nn.modules.module.Module' && type !== '__torch__.Module') { return null; } if (pytorch.Utility.isTensor(obj)) { diff --git a/test/models.json b/test/models.json index 597a143018..00bc9511fc 100644 --- a/test/models.json +++ b/test/models.json @@ -6243,7 +6243,7 @@ "target": "resnet18.ot", "source": "https://github.com/lutzroeder/netron/files/7664092/resnet18.ot.zip[resnet18.ot]", "format": "TorchScript v1.0", - "assert": "model.graphs[0].nodes[1].inputs[0].value[0].name == 'conv1|weight'", + "assert": "model.graphs[0].nodes[0].inputs[0].value[0].name == 'conv1|weight'", "link": "https://github.com/lutzroeder/netron/issues/686" }, { @@ -6381,7 +6381,7 @@ "target": "segmentor.pt", "source": "https://github.com/lutzroeder/netron/files/7663953/segmentor.pt.zip[segmentor.pt]", "format": "PyTorch v1.6", - "assert": "model.graphs[0].nodes[0].inputs[0].value.inputs[0].value.type.name == '__torch__.___torch_mangle_1.Module'", + "assert": "model.graphs[0].nodes[0].inputs[0].value[0].inputs[0].value[0].type.name == '__torch__.___torch_mangle_1.Module'", "link": "https://github.com/lutzroeder/netron/issues/686" }, {