diff --git a/source/python.js b/source/python.js index e53e875dd3..92167dca37 100644 --- a/source/python.js +++ b/source/python.js @@ -4967,6 +4967,35 @@ python.Execution = class { } return this[name]; } + __delattr__(name) { + if (this._modules.has(name)) { + this._modules.delete(name); + } + } + children() { + return this._modules.values(); + } + named_children() { + return this._modules; + } + parameters() { + return this._parameters.values(); + } + named_parameters(recurse) { + if (recurse) { + throw new python.Error('Named parameters with recurse not implemented.'); + } + return this._parameters; + } + buffers() { + return this._buffers.values(); + } + named_buffers(recurse) { + if (recurse) { + throw new python.Error('Named parameters with recurse not implemented.'); + } + return this._buffers; + } }); torch.nn.Module = torch.nn.modules.module.Module; torch.nn.modules.Module = torch.nn.modules.module.Module; @@ -6210,14 +6239,14 @@ python.Execution = class { }); this.registerFunction('torch._utils._rebuild_tensor_v3'); this.registerFunction('torch._utils._rebuild_parameter', (data, requires_grad, backward_hooks) => { - const param = self.invoke('torch.nn.parameter.Parameter', [data, requires_grad]); + const param = new torch.nn.parameter.Parameter(data, requires_grad); param.backward_hooks = backward_hooks; return param; }); this.registerFunction('torch._utils._rebuild_parameter_v2', (data, requires_grad, backward_hooks, state) => { - const param = self.invoke('torch.nn.parameter.Parameter', [data, requires_grad]); + const param = new torch.nn.parameter.Parameter(data, requires_grad); param.backward_hooks = backward_hooks; - execution.invoke('torch._utils._set_obj_state', [param, state]); + torch._utils._set_obj_state(param, state); return param; }); this.registerFunction('torch._utils._rebuild_parameter_with_state', (data, requires_grad, backward_hooks, state) => { @@ -6225,16 +6254,16 @@ python.Execution = class { const [dict_state, slots_state] = Array.isArray(state) ? state : [state, null]; if (dict_state) { for (const [k, v] of Object.entries(dict_state)) { - self.invoke('builtins.setattr', [obj, k, v]); + builtins.setattr(obj, k, v); } } if (slots_state) { for (const [k, v] of Object.entries(slots_state)) { - self.invoke('builtins.setattr', [obj, k, v]); + builtins.setattr(obj, k, v); } } }; - const param = self.invoke('torch.nn.parameter.Parameter', [data, requires_grad]); + const param = new torch.nn.parameter.Parameter(data, requires_grad); param._backward_hooks = backward_hooks; _set_obj_state(param, state); return param; @@ -6255,12 +6284,12 @@ python.Execution = class { } if (dict_state) { for (const [name, value] of Object.entries(dict_state)) { - execution.invoke('builtins.setattr', [obj, name, value]); + builtins.setattr(obj, name, value); } } if (slots_state) { for (const [name, value] of Object.entries(slots_state)) { - execution.invoke('builtins.setattr', [obj, name, value]); + builtins.setattr(obj, name, value); } } return obj; @@ -6958,6 +6987,9 @@ python.Execution = class { } return this.equals(rhs); } + is_module() { + return false; + } expect(type) { if (this instanceof type === false) { throw new python.Error(`Expected '${type.kind()}' but got '${this.kind()}'.`); @@ -7004,6 +7036,12 @@ python.Execution = class { is_module() { return this._is_module; } + is_parameter(slot) { + return this._attributes[slot].is_parameter === true; + } + is_buffer(slot) { + return this._attributes[slot].is_buffer === true; + } addMethod(func) { this._methods.set(func.name(), func); } @@ -7023,6 +7061,9 @@ python.Execution = class { findStaticMethod(name) { return this._staticmethods.get(name); } + numAttributes() { + return this._attributes.length; + } addAttribute(name, type, is_parameter, is_buffer) { is_parameter = is_parameter || false; is_buffer = is_buffer || false; @@ -7057,10 +7098,16 @@ python.Execution = class { } return null; } - getAttribute(name) { - const slot = this.findAttributeSlot(name); + hasAttribute(name) { + return this._attributes.find((attr) => attr.name === name); + } + getAttribute(arg) { + const slot = Number.isInteger(arg) ? arg : this.findAttributeSlot(arg); return this._attributes[slot].type; } + getAttributeName(slot) { + return this._attributes[slot].name; + } hasConstant(/* name */) { } methods() { @@ -8131,7 +8178,10 @@ python.Execution = class { parseBroadcastList(/* expr */) { return null; } - + parseType(str) { + const expr = ast.parse(str); + return this.parseTypeFromExpr(expr.body[0]); + } }); this.registerType('torch._ops.OpOverload', class extends torch._ops.OperatorBase { constructor(overloadpacket, op, op_dk, schema, tags) { @@ -8874,8 +8924,9 @@ python.Execution = class { this._loaded_sources = new Set(); this._to_be_defined = new Map(); } - loadType(/* name */) { - // + loadType(name) { + const type_parser = new torch.jit.ScriptTypeParser(this); + return type_parser.parseType(name.qualifiedName()); } resolveType(name) { name = new torch.jit.QualifiedName(name); @@ -9120,12 +9171,32 @@ python.Execution = class { for (let i = 0; i < constants.length; i++) { execution.builtins.CONSTANTS[`c${i}`] = constants[i]; } - const module = this.readArchive('data'); - const name = `${module.__class__.__module__}.${module.__class__.__name__}`; - const type = torch.ClassType.create(name, null, true); - const result = new torch.ScriptModule(type); - result.data = module; - return result; + const obj = this.readArchive('data'); + const convertModule = (obj) => { + if (obj.__class__) { + const name = `${obj.__class__.__module__}.${obj.__class__.__name__}`; + const type = this._source_importer.loadType(new torch.jit.QualifiedName(name)); + const module = new torch.ScriptModule(type, this._compilation_unit); + for (let i = 0; i < type.numAttributes(); i++) { + const k = type.getAttributeName(i); + const t = type.getAttribute(i); + const v = obj[k]; + if (t.is_module()) { + module.__setattr__(k, convertModule(v)); + } else { + module.__setattr__(k, obj[k]); + } + } + for (const [key, value] of Object.entries(Object.getPrototypeOf(obj))) { + if (value && value.__class__ === builtins.method) { + module[key] = value; + } + } + return module; + } + throw new python.Error('Module class not found.'); + }; + return convertModule(obj); } LEGACY_deserialize() { const execution = this._compilation_unit.execution; @@ -9186,8 +9257,8 @@ python.Execution = class { for (const tensor of tensor_table) { this._constant_table.push(tensor); } - const temp = this.LEGACY_convertModule(module_def); - const data = obj.mainModule || {}; + return this.LEGACY_convertModule(module_def); + /* const data = obj.mainModule || {}; const queue = [data]; while (queue.length > 0) { const module = queue.shift(); @@ -9237,6 +9308,8 @@ python.Execution = class { const result = new torch.ScriptModule(temp.type()); result.data = data; return result; + return module; + */ } LEGACY_convertModule(module_def) { const atoms = new torch.jit.QualifiedName(module_def.name).atoms(); @@ -9245,13 +9318,14 @@ python.Execution = class { const sanitized = /^\d+$/.test(atom) ? `_${atom}` : atom; this._LEGACY_moduleStack.push(sanitized); } - const module = new torch.ScriptModule(new torch.jit.QualifiedName(this._LEGACY_moduleStack), this._compilation_unit); + const qn = new torch.jit.QualifiedName(this._LEGACY_moduleStack); + const module = new torch.ScriptModule(qn, this._compilation_unit); for (const sub_def of module_def.submodules || []) { const submodule = this.LEGACY_convertModule(sub_def); module.register_module(sub_def.name, submodule); } for (const param_def of module_def.parameters || []) { - const tensor = this._constant_table[Number(param_def.tensorId)]; + const tensor = this._constant_table[Number(param_def.tensor_id)]; if (param_def.isBuffer) { module.register_buffer(param_def.name, tensor); } else { @@ -9263,11 +9337,21 @@ python.Execution = class { if (module.hasattr(attr_def.name)) { continue; } + throw new python.Error('Not implemented.'); // IValue ivalue; // if (attr_def.id() >= 0) { // ivalue = LEGACY_pickled_ivalues_.at(attr_def.id()); // } - // module.register_attribute(attr_def.name(), typeParser.parseType(attr_def.type()), ivalue); + // module.register_attribute(attr_def.name, typeParser.parseType(attr_def.type), ivalue); + } + if (module_def.torchscript_arena) { + const key = module_def.torchscript_arena.key; + const file = key.substring('code/'.length); + const name = file.replace(/\.py$/, '').split('/').join('.'); + const code = execution.import(name); + if (code.forward.__class__ === execution.builtins.function) { + module.forward = code.forward; + } } /* std::shared_ptr gen_ranges = nullptr; @@ -9299,9 +9383,13 @@ python.Execution = class { return module; } readArchive(archive_name) { - const type_resolver = null; - const obj_loader = null; - return this.readArchiveAndTensors(archive_name, this._pickle_dir_prefix, this._tensor_dir_prefix, type_resolver, obj_loader, this._device, this._reader, null, this._storage_context); + const type_resolver = (qn) => { + const cls = this._source_importer.loadType(qn); + return cls; + }; + const ObjLoaderFunc = (/* type, ivalue */) => { + }; + return this.readArchiveAndTensors(archive_name, this._pickle_dir_prefix, this._tensor_dir_prefix, type_resolver, ObjLoaderFunc, this._device, this._reader, null, this._storage_context); } readArchiveAndTensors(archive_name, pickle_prefix, tensor_prefix, type_resolver, obj_loader, device, stream_reader, type_parser, storage_context) { const picklename = `${pickle_prefix + archive_name}.pkl`; @@ -9405,7 +9493,9 @@ python.Execution = class { const cu = new torch.jit.CompilationUnit(); cu.execution = execution; const cpp_module = torch._C.import_ir_module(cu, file, map_location, extra_files); - return new torch.jit._script.RecursiveScriptModule(cpp_module); + const module = torch.jit._script.wrap_cpp_module(cpp_module); + module.forward = cpp_module.forward; // remove + return module; }); this.registerFunction('torch._C.import_ir_module', function(cu, reader, ...args) { switch (arguments.length) { @@ -9495,7 +9585,7 @@ python.Execution = class { const cu = new torch.jit.CompilationUnit(); cu.execution = execution; const cpp_module = torch._C._import_ir_module_from_package(cu, importer.zip_reader, importer.storage_context, importer.last_map_location, script_module_id); - return new torch.jit._script.RecursiveScriptModule(cpp_module); + return torch.jit._script.wrap_cpp_module(cpp_module); }); this.registerFunction('torch.jit._script.wrap_cpp_module', (cpp_module) => { const init_fn = (script_module) => { @@ -9552,7 +9642,8 @@ python.Execution = class { }); this.registerType('torch.ScriptObject', class { constructor(type) { - this._type = type; + this._typ = type; + this._ivalue = {}; } static create(type) { if (type.is_module()) { @@ -9561,10 +9652,10 @@ python.Execution = class { return new torch.ScriptObject(type); } type() { - return this._type; + return this._typ; } _type() { - return this._type; // torch.ClassType + return this._typ; // torch.ClassType } _get_method(name) { for (const fn of this._type.methods()) { @@ -9579,13 +9670,16 @@ python.Execution = class { } __setattr__(name, value) { // if (this._type.hasContant(name)) - this[name] = value; + this._ivalue[name] = value; } __getattr__(name) { - return this[name]; + return this._ivalue[name]; } hasattr(name) { - return this._type.hasAttribute(name) || this._type.hasConstant(name); + return this._typ.hasAttribute(name) || this._typ.hasConstant(name); + } + getattr(name) { + return this.__getattr__(name); } _properties() { throw new python.Error(); @@ -9601,7 +9695,7 @@ python.Execution = class { } } get qualified_name() { - return this._type.qualified_name(); + return this.type().qualified_name(); } get code_with_constants() { const const_map = {}; @@ -9622,19 +9716,16 @@ python.Execution = class { return false; } }; - if (!this.data) { + if (!this.forward) { return null; } - if (!this.data.forward) { - throw new python.Error("Module 'forward' not implemented."); - } execution.traceAttr = false; const args = []; if (!execution.traceAttr) { - args.push(this.data); // self + args.push(this); // self } - if (this.data.forward.__code__ && this.data.forward.__code__.args) { - const params = this.data.forward.__code__.args.args; + if (this.forward.__code__ && this.forward.__code__.args) { + const params = this.forward.__code__.args.args; for (let i = 0; i < params.length; i++) { const arg = params[i]; if (execution.traceAttr || arg.arg !== 'self') { @@ -9653,7 +9744,7 @@ python.Execution = class { } } execution.purge = new Set(); - const result = this.data.forward.__call__(args); + const result = this.forward.__call__(args); const queue = Array.from(execution.purge); const visited = new Set(); while (queue.length > 0) { @@ -9715,15 +9806,15 @@ python.Execution = class { } register_module(name, module) { this.type().addOrCheckAttribute(name, module.type()); - // _ivalue()->setAttr(name, module._ivalue()); + this.__setattr__(name, module); // _ivalue()->setAttr(name, module._ivalue()); } - register_buffer(name /* , v */) { + register_buffer(name, v) { this.type().addOrCheckAttribute(name, torch.TensorType.get(), false, true); - // _ivalue()->setAttr(name, std::move(v)); + this.__setattr__(name, v); // _ivalue()->setAttr(name, std::move(v)); } register_parameter(name, v, is_buffer) { this.type().addOrCheckAttribute(name, torch.TensorType.get(), !is_buffer, is_buffer); - // _ivalue()->setAttr(name, std::move(v)); + this.__setattr__(name, v); // _ivalue()->setAttr(name, std::move(v)); } register_attribute(name, t, v, is_param, is_buffer) { this.type().addOrCheckAttribute(name, t, is_param, is_buffer); @@ -9731,11 +9822,59 @@ python.Execution = class { } }); this.registerType('torch.ModuleDict', class { - constructor(module) { - this._items = Object.entries(module).filter(([, value]) => value instanceof torch.ScriptModule); + constructor(mod) { + this._module = mod; + } + items() { + const result = new Map(); + const type = this._module.type(); + for (let i = 0; i < type.numAttributes(); i++) { + const k = type.getAttributeName(i); + const t = type.getAttribute(i); + if (t && t.is_module()) { + result.set(k, this._module.__getattr__(k)); + } + } + return result; + } + }); + this.registerType('torch.ParameterDict', class { + constructor(mod) { + this._module = mod; + } + items() { + const result = new Map(); + const type = this._module.type(); + for (let i = 0; i < type.numAttributes(); i++) { + if (type.is_parameter(i)) { + const k = type.getAttributeName(i); + const v = this._module.__getattr__(k); + if (v instanceof torch.Tensor) { + result.set(k, v); + } + } + } + return result; + } + }); + this.registerType('torch.BufferDict', class { + constructor(mod) { + this._module = mod; } items() { - return this._items; + const result = new Map(); + const type = this._module.type(); + for (let i = 0; i < type.numAttributes(); i++) { + if (type.is_buffer(i)) { + const t = type.getAttribute(i); + if (t.isSubtypeOf(torch.TensorType.get())) { + const k = type.getAttributeName(i); + const v = this._module.__getattr__(k); + result.set(k, v); + } + } + } + return result; } }); this.registerType('torch.jit.to_ir', class { @@ -9846,11 +9985,11 @@ python.Execution = class { torch.jit._script.RecursiveScriptModule._finalize_scriptmodule(script_module); return script_module; } - static _finalize_scriptmodule() { - this._initializing = false; - } - get data() { - return this._c.data; + static _finalize_scriptmodule(script_module) { + script_module._parameters = new torch.ParameterDict(script_module._c).items(); + script_module._buffers = new torch.BufferDict(script_module._c).items(); + // script_module._modules = OrderedModuleDict(script_module._c, script_module._modules) + script_module._initializing = false; } get graph() { // return this._c._get_method("forward").graph; @@ -9863,8 +10002,8 @@ python.Execution = class { __setattr__(name, value) { if (this._initializing) { super.__setattr__(name, value); - } else if (this.modules.has(name)) { - this.modules.set(name, value); + } else if (this._modules.has(name)) { + this._modules.set(name, value); } else if (this._c.hasattr(name)) { this._c.setattr(name, value); } else { @@ -9875,8 +10014,8 @@ python.Execution = class { if (this._initializing) { return super.__getattr__(name); } - if (this.modules.has(name)) { - return this.modules.get(name); + if (this._modules.has(name)) { + return this._modules.get(name); } if (this._c.hasattr(name)) { return this._c.getattr(name); @@ -11646,10 +11785,7 @@ python.Execution = class { this.registerType('torch.nn.parameter.Parameter', class extends torch.Tensor { constructor(data, requires_grad) { super(); - if (!data) { - data = self.invoke('torch.Tensor', [[]]); - } - this.data = data; + this.data = data || new torch.Tensor([]); this.requires_grad = requires_grad === undefined ? true : requires_grad; } }); @@ -12544,7 +12680,12 @@ python.Execution = class { if (path) { let target = null; for (let i = path.length - 1; i >= 0; i--) { - target = target ? target[path[i]] : context.get(path[i]); + const name = path[i]; + if (target) { + target = target.__getattr__ ? target.__getattr__(name) : target[name]; + } else { + target = context.get(name); + } if (!target) { break; } diff --git a/source/pytorch.js b/source/pytorch.js index fc475dbbfa..7b0349bbc0 100644 --- a/source/pytorch.js +++ b/source/pytorch.js @@ -89,8 +89,7 @@ pytorch.Graph = class { return values.get(name); }; const torch = execution ? execution.torch : null; - // type = module && module.__class__ && module.__class__.__module__ && module.__class__.__name__ ? `${module.__class__.__module__}.${module.__class__.__name__}` : null; - if (torch && (module instanceof torch.ScriptModule || module instanceof torch.jit._script.ScriptModule || module instanceof torch.jit._script.RecursiveScriptModule) && module.graph) { + if (torch && module instanceof torch.jit._script.RecursiveScriptModule && module.graph) { const initializers = new Map(); const graph = module.graph; const constants = module.code_with_constants[1].const_mapping; @@ -106,10 +105,36 @@ pytorch.Graph = class { } } } - const queue = [module.data]; + const queue = [module]; while (queue.length > 0) { const module = queue.shift(); - for (const [key, obj] of Object.entries(module)) { + const children = module.named_children(); + for (const [key, obj] of children) { + obj.__parent__ = module; + obj.__name__ = obj.__name__ || key; + queue.push(obj); + const type = obj._c._type(); + for (let i = 0; i < type.numAttributes(); i++) { + const k = type.getAttributeName(i); + const v = obj.__getattr__(k); + if (pytorch.Utility.isObject(v)) { + initializers.set(v, v); + } + } + } + for (const buffer of module.buffers()) { + buffer.__parent__ = module; + if (buffer.storage() && !buffer.__origin__ && (buffer.__count__ === undefined || buffer.__count__ === 1)) { + initializers.set(buffer, new pytorch.Tensor(buffer.name, buffer)); + } + } + for (const parameter of module.parameters()) { + parameter.__parent__ = module; + if (parameter.storage() && !parameter.__origin__ && (parameter.__count__ === undefined || parameter.__count__ === 1)) { + initializers.set(parameter, new pytorch.Tensor(parameter.name, parameter)); + } + } + for (const [key, obj] of children) { if (key !== '__module__' && key !== '__name__' && key !== '__class__' && key !== '__parent__') { if (!Array.isArray(obj) && obj === Object(obj)) { if (pytorch.Utility.isTensor(obj)) { @@ -185,31 +210,19 @@ pytorch.Graph = class { this.nodes.push(new pytorch.Node(execution, metadata, null, null, node, initializers, values)); } if (module) { - const queue = [module.data]; + const queue = [module]; while (queue.length > 0) { const module = queue.pop(); - if (module && !pytorch.Utility.isObject(module)) { - if (!module.__hide__ && pytorch.Graph._getParameters(module).size > 0) { - for (const [name, obj] of Object.entries(module)) { - if ((obj && obj.__hide__) || (obj !== null && !pytorch.Utility.isTensor(obj)) && typeof obj !== 'boolean' && typeof obj !== 'number' && typeof obj !== 'string') { - delete module[name]; - } + if (module) { + const modules = Array.from(module.children()); + queue.push(...modules.reverse()); + if (!module.__hide__ && module.named_parameters().size > 0) { + for (const [name] of module.named_children()) { + module.__delattr__(name); } const node = new pytorch.Node(execution, metadata, null, null, module, initializers, values); this.nodes.push(node); } - const modules = []; - if (module.__class__ && module.__class__.__module__ && module.__class__.__name__) { - for (const [key, value] of Object.entries(module)) { - if (!key.startsWith('__') && value && value.__class__ && value.__class__.__module__ && value.__class__.__name__ && !pytorch.Utility.isTensor(value)) { - if (value instanceof torch.Value) { - continue; - } - modules.push(value); - } - } - } - queue.push(...modules.reverse()); } } } @@ -331,18 +344,6 @@ pytorch.Graph = class { } } } - - static _getParameters(module) { - const parameters = new Map(); - if (module && module.__class__.__module__ && module.__class__.__name__) { - for (const [key, value] of Object.entries(module)) { - if (pytorch.Utility.isTensor(value)) { - parameters.set(key, value); - } - } - } - return parameters; - } }; pytorch.Argument = class { @@ -425,6 +426,9 @@ pytorch.Node = class { values = Object.values(value); } else if (pytorch.Utility.isTensor(value)) { values = [value]; + } else if (Array.isArray(value) && value.every((value) => pytorch.Utility.isTensor(value))) { + values = value; + } else if (input instanceof torch.Value && input.type() instanceof torch.ListType && input.type().getElementType() instanceof torch.TensorType) { if (input.node() && input.node().kind() === 'prim::ListConstruct' && input.uses().length === 1 && @@ -451,9 +455,15 @@ pytorch.Node = class { } } if (module) { - const parameters = pytorch.Graph._getParameters(module); - parameters.delete('num_batches_tracked'); - if (parameters.size === count && match) { + const tensors = new Map(); + for (const [name, value] of module.named_parameters()) { + tensors.set(name, value); + } + for (const [name, value] of module.named_buffers()) { + tensors.set(name, value); + } + tensors.delete('num_batches_tracked'); + if (tensors.size === count && match) { module.__hide__ = true; } else { module = null; @@ -703,8 +713,24 @@ pytorch.Node = class { throw new pytorch.Error(`Unsupported node operation '${obj.op}'.`); } } else { + if (torch && obj instanceof torch.jit._script.RecursiveScriptModule && obj._c && obj._c.qualified_name) { + type = obj._c._type(); + const target = { + _modules: obj._modules, + _parameters: obj._parameters, + _buffers: obj._buffers, + }; + for (let i = 0; i < type.numAttributes(); i++) { + if (!type.is_parameter(i) && !type.is_buffer(i) && !type.getAttribute(i).is_module()) { + const k = type.getAttributeName(i); + target[k] = obj.__getattr__(k); + } + } + type = obj._c.qualified_name; + obj = target; + } if (!type) { - if (torch && pytorch.Utility.isInstance(obj, 'torch.jit._script.RecursiveScriptModule') && obj._c && obj._c.qualified_name) { + if (torch && obj instanceof torch.jit._script.RecursiveScriptModule && obj._c && obj._c.qualified_name) { type = obj._c.qualified_name; } else if (pytorch.Utility.isInstance(obj, 'builtins.function')) { type = `${obj.__module__}.${obj.__name__}`; @@ -751,17 +777,13 @@ pytorch.Node = class { } else if (pytorch.Utility.isInstance(value, 'torch.Size') && Array.isArray(value) && value.length === 0) { continue; } - const parameters = new Map(); - if ((name === '_parameters' || name === '_buffers') && value instanceof Map && value.size > 0) { - for (const [name, obj] of Array.from(value)) { - parameters.set(name, obj); - } - } else if (Array.isArray(value) && value.every((tensor) => pytorch.Utility.isTensor(tensor))) { - parameters.set(name, value); - } else if (pytorch.Utility.isTensor(value)) { - parameters.set(name, value); + let parameters = null; + if ((name === '_parameters' || name === '_buffers') && value instanceof Map) { + parameters = value; + } else if (pytorch.Utility.isTensor(value) || (Array.isArray(value) && value.every((tensor) => pytorch.Utility.isTensor(tensor)))) { + parameters = new Map([[name, value]]); } - if (parameters.size > 0) { + if (parameters) { for (const [name, value] of parameters) { const list = Array.isArray(value) ? value.map((item) => pytorch.Utility.toTensor(item)) : [pytorch.Utility.toTensor(value)]; const visible = inputs.has(name) ? inputs.get(name).visible || true : true; @@ -786,7 +808,6 @@ pytorch.Node = class { } continue; } - const type = this.type.identifier; if (pytorch.Utility.isTensor(value)) { const tensor = new pytorch.Tensor('', value); const argument = new pytorch.Argument(name, tensor, 'tensor'); @@ -811,14 +832,14 @@ pytorch.Node = class { this.inputs.push(argument); } else if (name === '_modules' && pytorch.Utility.isInstance(value, 'collections.OrderedDict') && value instanceof Map && Array.from(value).every(([, value]) => value === null || value.__class__)) { - const values = Array.from(value).filter(([, value]) => !stack.has(value)).map(([name, obj]) => { + const list = Array.from(value).filter(([, value]) => !stack.has(value)).map(([name, obj]) => { stack.add(value); const type = obj === null ? 'builtins.NoneType' : `${obj.__class__.__module__}.${obj.__class__.__name__}`; - const node = new pytorch.Node(execution, metadata, this.name ? `${this.name}.${name}` : name, type, obj); + const node = new pytorch.Node(execution, metadata, this.name ? `${this.name}.${name}` : name, type, obj, initializers, values, stack); stack.delete(value); return node; }); - const argument = new pytorch.Argument(name, values, 'object[]'); + const argument = new pytorch.Argument(name, list, 'object[]'); this.inputs.push(argument); } else if (value && Array.isArray(value) && value.length > 0 && value.every((obj) => Array.isArray(obj) && obj.every((item) => typeof item === 'string' || typeof item === 'number'))) { const argument = new pytorch.Argument(name, value, 'attribute'); @@ -841,34 +862,30 @@ pytorch.Node = class { const argument = new pytorch.Argument(name, node, 'object', visible); this.inputs.push(argument); } else { - const createAttribute = (metadata, name, value) => { - let visible = true; - let type = 'attribute'; - metadata = name === 'training' ? { type: 'boolean', visible: false } : metadata; - if (metadata) { - if (metadata.type) { - type = metadata.type; - } - if (metadata.visible === false) { - visible = false; - } else if (metadata.default !== undefined) { - if (Array.isArray(value)) { - if (Array.isArray(metadata.default)) { - visible = value.length !== metadata.default || !value.every((item, index) => item === metadata.default[index]); - } else { - visible = !value.every((item) => item === metadata.default); - } + let schema = metadata.attribute(this.type.identifier, name); + schema = name === 'training' ? { type: 'boolean', visible: false } : schema; + let visible = true; + let obj = value; + const type = schema && schema.type ? schema.type : 'attribute'; + if (schema) { + if (schema.visible === false) { + visible = false; + } else if (schema.default !== undefined) { + if (Array.isArray(obj)) { + if (Array.isArray(schema.default)) { + visible = obj.length !== schema.default || !obj.every((item, index) => item === schema.default[index]); } else { - visible = value !== metadata.default; + visible = !obj.every((item) => item === schema.default); } + } else { + visible = obj !== schema.default; } } - if (Array.isArray(value) && value.length > 0 && value.every((obj) => obj && obj.__class__ && obj.__class__.__module__ && obj.__class__.__module__.startsWith('torch.nn'))) { - value = '?'; - } - return new pytorch.Argument(name, value, type, visible); - }; - const argument = createAttribute(metadata.attribute(type, name), name, value); + } + if (Array.isArray(obj) && obj.length > 0 && obj.every((obj) => obj && obj.__class__ && obj.__class__.__module__ && obj.__class__.__module__.startsWith('torch.nn'))) { + obj = '?'; + } + const argument = new pytorch.Argument(name, obj, type, visible); this.inputs.push(argument); } } @@ -890,6 +907,7 @@ pytorch.Tensor = class { constructor(name, tensor) { this.name = name || ''; + tensor = tensor.data ? tensor.data : tensor; const layout = tensor.layout ? tensor.layout.__str__() : null; const storage = tensor.storage(); const size = tensor.size() || []; @@ -1300,15 +1318,10 @@ pytorch.Container.Zip = class extends pytorch.Container { const version = reader.version(); if (torchscript) { this.execution.trace = false; - const module = torch.jit.load(reader); + this.module = torch.jit.load(reader); this.execution.trace = true; metadata.register(this.execution); - if (module.data && module.data.forward) { - this.module = module; - } else { - torchscript = false; - this.module = module.data; - } + torchscript = this.module.forward; } else { const records = reader.get_all_records().map((key) => [key, reader.get_record(key)]); const entries = new Map(records); @@ -1353,9 +1366,15 @@ pytorch.Container.ModelJson = class extends pytorch.Container { 'version', ...this._model.tensors.filter((tensor) => tensor && tensor.data && tensor.data.key).map((tensor) => tensor.data.key) ]; - if (this._model.mainModule.torchscriptArena && this._model.mainModule.torchscriptArena.key) { - keys.push(this._model.mainModule.torchscriptArena.key); - } + const walk = (module) => { + if (module.torchscriptArena && module.torchscriptArena.key) { + keys.push(module.torchscriptArena.key); + } + for (const submodule of module.submodules || []) { + walk(submodule); + } + }; + walk(this._model.mainModule); const values = await Promise.all(keys.map((name) => this._context.fetch(name).then((context) => context.stream).catch(() => null))); for (let i = 0; i < keys.length; i++) { if (values[i]) { @@ -1374,14 +1393,9 @@ pytorch.Container.ModelJson = class extends pytorch.Container { } this.format = reader.has_record('attributes.pkl') ? 'TorchScript v1.1' : 'TorchScript v1.0'; this.execution.trace = false; - const module = torch.jit.load(reader); + this.module = torch.jit.load(reader); this.execution.trace = true; metadata.register(this.execution); - if (module.data && module.data.forward) { - this.module = module; - } else { - this.module = module.data; - } delete this._context; delete this._model; delete this._entries; @@ -1767,7 +1781,12 @@ pytorch.Execution = class extends python.Execution { if (path) { let target = null; for (let i = path.length - 1; i >= 0; i--) { - target = target ? target[path[i]] : context.get(path[i]); + const name = path[i]; + if (target) { + target = target.__getattr__ ? target.__getattr__(name) : target[name]; + } else { + target = context.get(name); + } if (!target) { break; } @@ -2009,7 +2028,7 @@ pytorch.Execution = class extends python.Execution { this._graph.insertNode(node); return node.output(); } - return target[attr]; + return target.__getattr__ ? target.__getattr__(attr) : target[attr]; } case 'List': { const list = expr.elts.map((item) => this.expression(item, context)); @@ -2047,6 +2066,8 @@ pytorch.Execution = class extends python.Execution { let value = null; if (elt instanceof torch.Value) { value = elt; + } else if (pytorch.Utility.isTensor(elt)) { + value = this.variable(elt, null); } else if (elt === null || Number.isInteger(elt) || typeof elt === 'number' || typeof elt === 'boolean' || typeof elt === 'string') { value = this._graph.insertConstant(elt); } else { @@ -2143,7 +2164,7 @@ pytorch.Execution = class extends python.Execution { } case 'Attribute': { const target = this.target(expr.value, context); - return target[expr.attr]; + return target.__getattr__ ? target.__getattr__(expr.attr) : target[expr.attr]; } case 'Call': { const func = expr.func; @@ -2528,7 +2549,9 @@ pytorch.Execution = class extends python.Execution { statement(stmt, context) { if (stmt.__class__.__name__ === 'ClassDef') { const name = `${context.get('__name__')}.${stmt.name}`; - this._resolver.resolveType(name); + if (this._resolver) { + this._resolver.resolveType(name); + } } if (!this.trace) { @@ -3315,8 +3338,8 @@ pytorch.Container.Package = class extends pytorch.Container { this.entries = entries; } - async read() { - this.execution = new python.Execution(); + async read(metadata) { + this.execution = new pytorch.Execution(null, metadata); for (const event of this._events) { this.execution.on(event[0], event[1]); } @@ -3504,8 +3527,17 @@ pytorch.Utility = class { } static weights(obj) { - const type = obj && obj.__class__ && obj.__class__.__module__ && obj.__class__.__name__ ? `${obj.__class__.__module__}.${obj.__class__.__name__}` : null; - if (type && type !== 'builtins.dict' && type !== 'builtins.object' && type !== 'collections.OrderedDict' && type !== 'torch.nn.modules.module.Module' && type !== '__torch__.Module') { + let type = obj && obj.__class__ && obj.__class__.__module__ && obj.__class__.__name__ ? `${obj.__class__.__module__}.${obj.__class__.__name__}` : null; + if (type === 'torch.jit._script.RecursiveScriptModule') { + type = obj._c._type(); + const target = {}; + for (let i = 0; i < type.numAttributes(); i++) { + const k = type.getAttributeName(i); + target[k] = obj.__getattr__(k); + } + type = obj._c.qualified_name; + obj = target; + } else if (type && type !== 'builtins.dict' && type !== 'builtins.object' && type !== 'collections.OrderedDict' && type !== 'torch.nn.modules.module.Module' && type !== '__torch__.Module') { return null; } if (pytorch.Utility.isTensor(obj)) { diff --git a/test/models.json b/test/models.json index 597a143018..9901e2514c 100644 --- a/test/models.json +++ b/test/models.json @@ -5301,6 +5301,7 @@ "target": "alexnet_traced.pt.zip", "source": "https://github.com/lutzroeder/netron/files/6096602/alexnet_traced.pt.zip", "format": "TorchScript v1.6", + "assert": "model.graphs[0].nodes.length == 28", "link": "https://github.com/lutzroeder/netron/issues/281" }, { @@ -5425,6 +5426,7 @@ "target": "coco128-yolov8n-seg_output.torchscript.ptl", "source": "https://github.com/user-attachments/files/16091260/coco128-yolov8n-seg_output.torchscript.ptl.zip[coco128-yolov8n-seg_output.torchscript.ptl]", "format": "TorchScript v1.6", + "assert": "model.graphs[0].nodes[0].inputs[1].value.type.name == '__torch__.torch.classes.xnnpack.Conv2dOpContext'", "link": "https://github.com/lutzroeder/netron/issues/1067" }, { @@ -5492,6 +5494,7 @@ "target": "cruise_go_vehicle_model.pt", "source": "https://raw.githubusercontent.com/ApolloAuto/apollo/master/modules/prediction/data/cruise_go_vehicle_model.pt", "format": "TorchScript v1.0", + "assert": "model.graphs[0].nodes.length == 73", "link": "https://github.com/ApolloAuto/apollo" }, { @@ -5522,6 +5525,7 @@ "target": "deeplabv3_scripted.pt", "source": "https://github.com/lutzroeder/netron/files/5604999/deeplabv3_scripted.pt.zip[deeplabv3_scripted.pt]", "format": "TorchScript v1.6", + "assert": "model.graphs[0].nodes.length == 478", "link": "https://github.com/lutzroeder/netron/issues/630" }, { @@ -5747,6 +5751,7 @@ "target": "lane_scanning_vehicle_model.pt", "source": "https://raw.githubusercontent.com/ApolloAuto/apollo/master/modules/prediction/data/lane_scanning_vehicle_model.pt", "format": "TorchScript v1.0", + "assert": "model.graphs[0].nodes.length == 121", "link": "https://github.com/ApolloAuto/apollo" }, { @@ -6243,7 +6248,7 @@ "target": "resnet18.ot", "source": "https://github.com/lutzroeder/netron/files/7664092/resnet18.ot.zip[resnet18.ot]", "format": "TorchScript v1.0", - "assert": "model.graphs[0].nodes[1].inputs[0].value[0].name == 'conv1|weight'", + "assert": "model.graphs[0].nodes[0].inputs[0].value[0].name == 'conv1|weight'", "link": "https://github.com/lutzroeder/netron/issues/686" }, { @@ -6381,7 +6386,7 @@ "target": "segmentor.pt", "source": "https://github.com/lutzroeder/netron/files/7663953/segmentor.pt.zip[segmentor.pt]", "format": "PyTorch v1.6", - "assert": "model.graphs[0].nodes[0].inputs[0].value.inputs[0].value.type.name == '__torch__.___torch_mangle_1.Module'", + "assert": "model.graphs[0].nodes[0].inputs[0].value[0].inputs[0].value[0].type.name == '__torch__.___torch_mangle_1.Module'", "link": "https://github.com/lutzroeder/netron/issues/686" }, {