@@ -125,6 +125,7 @@ python.Execution = class {
125
125
this.register('argparse');
126
126
this._enum = this.register('enum');
127
127
this.register('collections');
128
+ const copy = this.register('copy');
128
129
this.register('copy_reg');
129
130
const ast = this.register('ast');
130
131
this.ast = ast;
@@ -4330,6 +4331,7 @@ python.Execution = class {
4330
4331
this.registerFunction('collections.defaultdict', (/* default_factory */) => {
4331
4332
return {};
4332
4333
});
4334
+ this.registerFunction('copy.deepcopy');
4333
4335
this.registerFunction('copy_reg._reconstructor', (cls, base, state) => {
4334
4336
// copyreg._reconstructor in Python 3
4335
4337
if (base === '__builtin__.object' || base === builtins.object) {
@@ -6980,13 +6982,16 @@ python.Execution = class {
6980
6982
});
6981
6983
this.registerType('torch.ClassType', class extends torch.Type {
6982
6984
constructor(qualified_name, cu, is_module) {
6983
- super('ClassType', qualified_name);
6985
+ super('ClassType', typeof qualified_name === 'string' ? qualified_name : qualified_name.qualifiedName() );
6984
6986
this._is_module = is_module;
6985
6987
this._attributes = new Map();
6986
6988
this._methods = new Map();
6987
6989
this._staticmethods = new Map();
6988
6990
this._constants = new Map();
6989
6991
}
6992
+ static create(qualifiedName, cu, is_module /*, doc_string, unresolved_class_attributes */) {
6993
+ return new torch.ClassType(qualifiedName, cu, is_module);
6994
+ }
6990
6995
// Returns the fully-qualified class name (e.g. '__torch__.foo.Bar'),
// which the torch.Type base stores as the annotation string.
qualified_name() {
    return this.annotation_str;
}
@@ -7655,7 +7660,7 @@ python.Execution = class {
7655
7660
while (L.eat('.')) {
7656
7661
name = `${name}.${L.expect('id')}`;
7657
7662
}
7658
- real_value = new torch.ClassType(name); // getCustomClass
7663
+ real_value = torch.ClassType.create (name); // getCustomClass
7659
7664
fake_value = real_value;
7660
7665
} else {
7661
7666
real_value = this.parseBaseType();
@@ -8805,6 +8810,8 @@ python.Execution = class {
8805
8810
let name = null;
8806
8811
if (args.length === 1 && typeof args[0] === 'string') {
8807
8812
[name] = args;
8813
+ } else if (args.length === 1 && Array.isArray(args[0]) && args[0].every((arg) => typeof arg === 'string')) {
8814
+ name = args[0].join('.');
8808
8815
} else {
8809
8816
name = `${args[0].qualifiedName()}.${args[1]}`;
8810
8817
}
@@ -8822,6 +8829,9 @@ python.Execution = class {
8822
8829
// Returns the final atom of the qualified name,
// e.g. 'baz' for 'foo.bar.baz'.
name() {
    return this._name; // "baz"
}
8832
// Returns the dot-separated components of the qualified name,
// e.g. 'foo.bar.baz' -> ['foo', 'bar', 'baz'].
// NOTE(review): an empty qualified name yields [''] rather than [] —
// confirm callers (e.g. LEGACY_convertModule) tolerate this.
atoms() {
    return this._qualifiedName.split('.');
}
8825
8835
});
8826
8836
this.registerType('torch.jit.SourceImporter', class {
8827
8837
constructor(cu, constant_table, source_loader, version) {
@@ -8887,7 +8897,7 @@ python.Execution = class {
8887
8897
const pre_hook_def_map = new Map();
8888
8898
const hook_names = new Set();
8889
8899
const hook_def_map = new Map();
8890
- const class_type = new torch.ClassType(qualified_classname.qualifiedName(), this._cu, is_module);
8900
+ const class_type = torch.ClassType.create (qualified_classname.qualifiedName(), this._cu, is_module);
8891
8901
for (const stmt of class_def.body) {
8892
8902
if (stmt instanceof ast.Assign || stmt instanceof ast.AnnAssign) {
8893
8903
let target = null;
@@ -8940,8 +8950,9 @@ python.Execution = class {
8940
8950
break;
8941
8951
}
8942
8952
}
8943
- } else if (target instanceof ast.Subscript) {
8944
- // not implemented
8953
+ } else if (target instanceof ast.Subscript && target.value instanceof ast.Name && target.value.id === '__annotations__') {
8954
+ const name = target.slice.elts[0].value;
8955
+ attributes.push({ name, value, annotation: stmt.value });
8945
8956
continue;
8946
8957
} else {
8947
8958
throw new python.Error('Unexpected statement kind in module metadata.');
@@ -9020,11 +9031,11 @@ python.Execution = class {
9020
9031
this._code_prefix = !pickle_dir_prefix && !tensor_dir_prefix ? 'code/' : '.data/ts_code/code/';
9021
9032
this._pickle_dir_prefix = pickle_dir_prefix || '';
9022
9033
this._tensor_dir_prefix = tensor_dir_prefix || '';
9034
+ this._constant_table = [];
9023
9035
const SourceLoader = (qualifier) => {
9024
9036
return this.findSourceInArchiveFromQualifier(this._reader, this._code_prefix, qualifier);
9025
9037
};
9026
- this._source_importer = new torch.jit.SourceImporter(
9027
- this._compilation_unit, this._constants_table, SourceLoader, reader.version());
9038
+ this._source_importer = new torch.jit.SourceImporter(this._compilation_unit, this._constant_table, SourceLoader, reader.version());
9028
9039
}
9029
9040
deserialize() {
9030
9041
const execution = this._compilation_unit.execution;
@@ -9061,7 +9072,7 @@ python.Execution = class {
9061
9072
];
9062
9073
for (const known_type of known_types) {
9063
9074
const prefix = new torch.jit.QualifiedName(known_type.name);
9064
- const type = new torch.ClassType(known_type.name, this._compilation_unit, false);
9075
+ const type = torch.ClassType.create (known_type.name, this._compilation_unit, false);
9065
9076
for (const known_method of known_type.methods || []) {
9066
9077
const schema = new torch.FunctionSchema(known_method);
9067
9078
const name = new torch.jit.QualifiedName(prefix, schema.name);
@@ -9078,7 +9089,8 @@ python.Execution = class {
9078
9089
execution.builtins.CONSTANTS[`c${i}`] = constants[i];
9079
9090
}
9080
9091
const module = this.readArchive('data');
9081
- const type = new torch.ClassType(`${module.__class__.__module__}.${module.__class__.__name__}`, null, true);
9092
+ const name = `${module.__class__.__module__}.${module.__class__.__name__}`;
9093
+ const type = torch.ClassType.create(name, null, true);
9082
9094
const result = new torch.ScriptModule(type);
9083
9095
result.data = module;
9084
9096
return result;
@@ -9101,7 +9113,7 @@ python.Execution = class {
9101
9113
['INT32', 'Int'],
9102
9114
['INT64', 'Long']
9103
9115
]);
9104
- const constants = (model.tensors || []).map((constant) => {
9116
+ const tensor_table = (model.tensors || []).map((constant) => {
9105
9117
const key = constant.data.key;
9106
9118
if (!tensorTypeMap.has(constant.dataType)) {
9107
9119
throw new python.Error(`Unsupported tensor data type '${constant.dataType}'.`);
@@ -9126,8 +9138,8 @@ python.Execution = class {
9126
9138
return tensor;
9127
9139
});
9128
9140
execution.builtins.CONSTANTS = {};
9129
- for (let i = 0; i < constants .length; i++) {
9130
- execution.builtins.CONSTANTS[`c${i}`] = constants [i];
9141
+ for (let i = 0; i < tensor_table .length; i++) {
9142
+ execution.builtins.CONSTANTS[`c${i}`] = tensor_table [i];
9131
9143
}
9132
9144
const attributes = [];
9133
9145
if (this._reader.has_record('attributes.pkl')) {
@@ -9137,6 +9149,14 @@ python.Execution = class {
9137
9149
const obj = unpickler.load();
9138
9150
attributes.push(...obj);
9139
9151
}
9152
+
9153
+ this._LEGACY_moduleStack = ['__torch__'];
9154
+ // const module_def = model.mainModule;
9155
+ for (const tensor of tensor_table) {
9156
+ this._constant_table.push(tensor);
9157
+ }
9158
+ // this.LEGACY_convertModule(module_def);
9159
+
9140
9160
while (queue.length > 0) {
9141
9161
const module = queue.shift();
9142
9162
module.__class__ = module.__class__ || { __module__: 'torch.nn.modules.module', __name__: 'Module' };
@@ -9161,7 +9181,7 @@ python.Execution = class {
9161
9181
delete module.arguments;
9162
9182
}
9163
9183
for (const parameter of parameters) {
9164
- const tensor = constants [parameter.tensorId];
9184
+ const tensor = tensor_table [parameter.tensorId];
9165
9185
module[parameter.name] = tensor;
9166
9186
parameter.__class__ = parameter.__class__ || { __module__: 'torch', __name__: 'Tensor' };
9167
9187
}
@@ -9182,11 +9202,71 @@ python.Execution = class {
9182
9202
data.forward = module.forward;
9183
9203
}
9184
9204
}
9185
- const class_type = new torch.ClassType(data.name);
9205
+ const class_type = torch.ClassType.create (data.name);
9186
9206
const result = new torch.ScriptModule(class_type);
9187
9207
result.data = data;
9188
9208
return result;
9189
9209
}
9210
// Port (in progress) of torch::jit::ScriptModuleDeserializer::LEGACY_convertModule:
// converts one ModuleDef from a legacy (pre-pickle) TorchScript archive into a
// torch.ScriptModule, recursing over submodules. Only the module / parameter /
// buffer skeleton is implemented so far; attribute and method import remain
// commented out, with the original C++ kept below as a reference for the port.
LEGACY_convertModule(module_def) {
    // Push this module's name atoms onto the qualified-name stack so nested
    // modules get names like '__torch__.sub.child'.
    const atoms = new torch.jit.QualifiedName(module_def.name).atoms();
    const numPushed = atoms.length;
    for (const atom of atoms) {
        // Purely numeric atoms are not valid identifiers; prefix with '_'.
        const sanitized = /^\d+$/.test(atom) ? `_${atom}` : atom;
        this._LEGACY_moduleStack.push(sanitized);
    }
    const module = new torch.ScriptModule(new torch.jit.QualifiedName(this._LEGACY_moduleStack), this._compilation_unit);
    // Depth-first conversion of submodules.
    for (const sub_def of module_def.submodules || []) {
        const submodule = this.LEGACY_convertModule(sub_def);
        module.register_module(sub_def.name, submodule);
    }
    // Parameters and buffers reference tensors by index into the constant table
    // built from model.tensors during deserialization.
    for (const param_def of module_def.parameters || []) {
        const tensor = this._constant_table[Number(param_def.tensorId)];
        if (param_def.isBuffer) {
            module.register_buffer(param_def.name, tensor);
        } else {
            module.register_parameter(param_def.name, tensor, false);
        }
    }
    // const typeParser = new torch.jit.ScriptTypeParser(this._source_importer);
    // Attribute import not yet implemented — skip anything already registered.
    for (const attr_def of module_def.attributes || []) {
        if (module.hasattr(attr_def.name)) {
            continue;
        }
        // IValue ivalue;
        // if (attr_def.id() >= 0) {
        //     ivalue = LEGACY_pickled_ivalues_.at(attr_def.id());
        // }
        // module.register_attribute(attr_def.name(), typeParser.parseType(attr_def.type()), ivalue);
    }
    /*
    std::shared_ptr<SourceRangeUnpickler> gen_ranges = nullptr;
    if (module_def.has_torchscript_debug_arena()) {
        auto [data, size] = reader_->getRecord(module_def.torchscript_debug_arena().key());
        gen_ranges = std::make_shared<ConcreteSourceRangeUnpickler>(std::move(data), size);
    }
    if (module_def.has_torchscript_arena()) {
        auto [data, size] =
            reader_->getRecord(module_def.torchscript_arena().key());
        std::string data_str(static_cast<const char*>(data.get()), size);
        auto src = std::make_shared<Source>(std::string(static_cast<const char*>(data.get()), size), module_def.torchscript_arena().key(), 1, std::move(gen_ranges));
        source_importer_.LEGACY_import_methods(module, src);
    }
    if (module_def.has_get_state_attribute_id()) {
        LEGACY_moduleSetState(module, LEGACY_pickled_ivalues_.at(module_def.get_state_attribute_id()));
    }
    const ClassTypePtr& module_type = module._ivalue()->type();
    for (size_t i = 0, N = module_type->numAttributes(); i < N; ++i) {
        const IValue& v = module._ivalue()->getSlot(i);
        if (module_type->getAttribute(i)->kind() != TypeKind::OptionalType) {
            TORCH_CHECK(!v.isNone(), "The field '", module_type->getAttributeName(i), "' was left unitialized after __setstate__, but expected a ", "value of type '", v.type()->repr_str(), "'");
        }
    }
    */
    // Pop exactly the atoms we pushed so the stack is balanced for siblings.
    for (let i = 0; i < numPushed; i++) {
        this._LEGACY_moduleStack.pop();
    }
    return module;
}
9190
9270
readArchive(archive_name) {
9191
9271
const type_resolver = null;
9192
9272
const obj_loader = null;
@@ -9478,6 +9558,14 @@ python.Execution = class {
9478
9558
}
9479
9559
});
9480
9560
this.registerType('torch.ScriptModule', class extends torch.ScriptObject {
9561
+ constructor(...args) {
9562
+ if (args[0] instanceof torch.jit.QualifiedName && args[1] instanceof torch.jit.CompilationUnit) {
9563
+ const [class_name, cu, shouldMangle] = args;
9564
+ super(...torch.ScriptModule.create_module_object(class_name, cu, shouldMangle));
9565
+ } else {
9566
+ super(...args);
9567
+ }
9568
+ }
9481
9569
// Fully-qualified name of this module's underlying ClassType.
get qualified_name() {
    return this._type.qualified_name();
}
@@ -9579,6 +9667,24 @@ python.Execution = class {
9579
9667
}
9580
9668
return this._graph;
9581
9669
}
9670
+ static create_module_object(class_name, cu, shouldMangle) {
9671
+ shouldMangle = shouldMangle || false;
9672
+ if (!class_name.prefix()) {
9673
+ class_name = new torch.jit.QualifiedName('__torch__', class_name.name());
9674
+ }
9675
+ if (shouldMangle && cu.get_class(class_name)) {
9676
+ class_name = cu.mangle(class_name);
9677
+ }
9678
+ const cls = torch.ClassType.create(class_name, cu, true);
9679
+ cu.register_type(cls);
9680
+ return [cls, cu];
9681
+ }
9682
// Legacy module-import hooks: accepted for API compatibility with
// torch::jit::Module (called from LEGACY_convertModule) but currently no-ops
// in this implementation.
register_module(/* name, module */) {
}
register_buffer(/* name, buffer */) {
}
register_parameter(/* name, parameter, is_buffer */) {
}
9582
9688
});
9583
9689
this.registerType('torch.ModuleDict', class {
9584
9690
constructor(module) {
@@ -9659,6 +9765,7 @@ python.Execution = class {
9659
9765
const graph = new torch.Graph();
9660
9766
graph.set_op_version(operator_set_version);
9661
9767
const fn = new torch.jit.GraphFunction(name, graph, creator);
9768
+ fn.__ast__ = def;
9662
9769
if (shouldMangle && this.find_function(name)) {
9663
9770
// name = mangle(name);
9664
9771
}
@@ -9845,7 +9952,7 @@ python.Execution = class {
9845
9952
cls = this._cu.get_class(new torch.jit.QualifiedName(name));
9846
9953
if (!cls) {
9847
9954
const torch = this._torch;
9848
- cls = new torch.ClassType(name, this._cu, true);
9955
+ cls = torch.ClassType.create (name, this._cu, true);
9849
9956
this._cu.register_type(cls);
9850
9957
}
9851
9958
} else {
@@ -9862,7 +9969,20 @@ python.Execution = class {
9862
9969
return cls;
9863
9970
}
9864
9971
});
9865
this.registerType('torch.export.UnflattenedModule', class extends torch.nn.modules.module.Module {
    // Port of torch.export.unflatten.UnflattenedModule.__init__: rebuilds a
    // per-module hierarchy from a flattened ExportedProgram.
    constructor(export_module, flat_args_adapter) {
        super();
        // NOTE(review): copy.deepcopy is registered without an implementation
        // earlier in this file — confirm it provides deep-copy semantics.
        const export_graph = copy.deepcopy(export_module.graph);
        // Fix: was 'self.graph_signature' — 'self' is a Python-ism and is
        // undefined in this scope; the instance property must be set on 'this'.
        this.graph_signature = copy.deepcopy(export_module.graph_signature);
        // Fix: torch.fx.Graph is a registered class and must be constructed
        // with 'new'; calling it as a function throws.
        this.graph = new torch.fx.Graph();
        this.graph.owning_module = this;
        this.module_call_graph = copy.deepcopy(export_module.module_call_graph);
        this.flat_args_adapter = flat_args_adapter;
        // Flag consumed by forward(): whether flat args have been adapted yet.
        this.adapted = false;
        // this._run_with_interpreter = RUN_WITH_INTERPRETER
        this._inplace_buffer_mutations(export_graph, this.graph_signature);
    }
});
9866
9986
this.registerType('torch.export.graph_signature.ExportGraphSignature', class {
9867
9987
constructor(input_specs, output_specs) {
9868
9988
this.input_specs = input_specs;
@@ -10017,7 +10137,10 @@ python.Execution = class {
10017
10137
});
10018
10138
this.registerType('torch.export.exported_program.ModuleCallEntry', class {});
10019
10139
this.registerType('torch.export.exported_program.ModuleCallSignature', class {});
10020
// torch.export.unflatten(module): strip effect tokens from the exported
// program, then wrap it in an UnflattenedModule that restores the original
// module hierarchy.
this.registerFunction('torch.export.unflatten', (module, flat_args_adapter) => {
    const program = torch.export._remove_effect_tokens(module);
    return new torch.export.UnflattenedModule(program, flat_args_adapter);
});
10021
10144
// Wraps a root module and an fx graph in a GraphModule for export.
this.registerFunction('torch._export.exported_program._create_graph_module_for_export', (root, graph) => new torch.fx.graph_module.GraphModule(root, graph));
0 commit comments