Skip to content

Commit 982ff60

Browse files
committed
Update python.js (#1061)
1 parent 082c7e0 commit 982ff60

File tree

2 files changed

+157
-29
lines changed

2 files changed

+157
-29
lines changed

source/python.js

Lines changed: 140 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,7 @@ python.Execution = class {
125125
this.register('argparse');
126126
this._enum = this.register('enum');
127127
this.register('collections');
128+
const copy = this.register('copy');
128129
this.register('copy_reg');
129130
const ast = this.register('ast');
130131
this.ast = ast;
@@ -4330,6 +4331,7 @@ python.Execution = class {
43304331
this.registerFunction('collections.defaultdict', (/* default_factory */) => {
43314332
return {};
43324333
});
4334+
this.registerFunction('copy.deepcopy');
43334335
this.registerFunction('copy_reg._reconstructor', (cls, base, state) => {
43344336
// copyreg._reconstructor in Python 3
43354337
if (base === '__builtin__.object' || base === builtins.object) {
@@ -6980,13 +6982,16 @@ python.Execution = class {
69806982
});
69816983
this.registerType('torch.ClassType', class extends torch.Type {
69826984
constructor(qualified_name, cu, is_module) {
6983-
super('ClassType', qualified_name);
6985+
super('ClassType', typeof qualified_name === 'string' ? qualified_name : qualified_name.qualifiedName());
69846986
this._is_module = is_module;
69856987
this._attributes = new Map();
69866988
this._methods = new Map();
69876989
this._staticmethods = new Map();
69886990
this._constants = new Map();
69896991
}
6992+
static create(qualifiedName, cu, is_module /*, doc_string, unresolved_class_attributes */) {
6993+
return new torch.ClassType(qualifiedName, cu, is_module);
6994+
}
69906995
qualified_name() {
69916996
return this.annotation_str;
69926997
}
@@ -7655,7 +7660,7 @@ python.Execution = class {
76557660
while (L.eat('.')) {
76567661
name = `${name}.${L.expect('id')}`;
76577662
}
7658-
real_value = new torch.ClassType(name); // getCustomClass
7663+
real_value = torch.ClassType.create(name); // getCustomClass
76597664
fake_value = real_value;
76607665
} else {
76617666
real_value = this.parseBaseType();
@@ -8805,6 +8810,8 @@ python.Execution = class {
88058810
let name = null;
88068811
if (args.length === 1 && typeof args[0] === 'string') {
88078812
[name] = args;
8813+
} else if (args.length === 1 && Array.isArray(args[0]) && args[0].every((arg) => typeof arg === 'string')) {
8814+
name = args[0].join('.');
88088815
} else {
88098816
name = `${args[0].qualifiedName()}.${args[1]}`;
88108817
}
@@ -8822,6 +8829,9 @@ python.Execution = class {
88228829
name() {
88238830
return this._name; // "baz"
88248831
}
8832+
atoms() {
8833+
return this._qualifiedName.split('.');
8834+
}
88258835
});
88268836
this.registerType('torch.jit.SourceImporter', class {
88278837
constructor(cu, constant_table, source_loader, version) {
@@ -8887,7 +8897,7 @@ python.Execution = class {
88878897
const pre_hook_def_map = new Map();
88888898
const hook_names = new Set();
88898899
const hook_def_map = new Map();
8890-
const class_type = new torch.ClassType(qualified_classname.qualifiedName(), this._cu, is_module);
8900+
const class_type = torch.ClassType.create(qualified_classname.qualifiedName(), this._cu, is_module);
88918901
for (const stmt of class_def.body) {
88928902
if (stmt instanceof ast.Assign || stmt instanceof ast.AnnAssign) {
88938903
let target = null;
@@ -8940,8 +8950,9 @@ python.Execution = class {
89408950
break;
89418951
}
89428952
}
8943-
} else if (target instanceof ast.Subscript) {
8944-
// not implemented
8953+
} else if (target instanceof ast.Subscript && target.value instanceof ast.Name && target.value.id === '__annotations__') {
8954+
const name = target.slice.elts[0].value;
8955+
attributes.push({ name, value, annotation: stmt.value });
89458956
continue;
89468957
} else {
89478958
throw new python.Error('Unexpected statement kind in module metadata.');
@@ -9020,11 +9031,11 @@ python.Execution = class {
90209031
this._code_prefix = !pickle_dir_prefix && !tensor_dir_prefix ? 'code/' : '.data/ts_code/code/';
90219032
this._pickle_dir_prefix = pickle_dir_prefix || '';
90229033
this._tensor_dir_prefix = tensor_dir_prefix || '';
9034+
this._constant_table = [];
90239035
const SourceLoader = (qualifier) => {
90249036
return this.findSourceInArchiveFromQualifier(this._reader, this._code_prefix, qualifier);
90259037
};
9026-
this._source_importer = new torch.jit.SourceImporter(
9027-
this._compilation_unit, this._constants_table, SourceLoader, reader.version());
9038+
this._source_importer = new torch.jit.SourceImporter(this._compilation_unit, this._constant_table, SourceLoader, reader.version());
90289039
}
90299040
deserialize() {
90309041
const execution = this._compilation_unit.execution;
@@ -9061,7 +9072,7 @@ python.Execution = class {
90619072
];
90629073
for (const known_type of known_types) {
90639074
const prefix = new torch.jit.QualifiedName(known_type.name);
9064-
const type = new torch.ClassType(known_type.name, this._compilation_unit, false);
9075+
const type = torch.ClassType.create(known_type.name, this._compilation_unit, false);
90659076
for (const known_method of known_type.methods || []) {
90669077
const schema = new torch.FunctionSchema(known_method);
90679078
const name = new torch.jit.QualifiedName(prefix, schema.name);
@@ -9078,7 +9089,8 @@ python.Execution = class {
90789089
execution.builtins.CONSTANTS[`c${i}`] = constants[i];
90799090
}
90809091
const module = this.readArchive('data');
9081-
const type = new torch.ClassType(`${module.__class__.__module__}.${module.__class__.__name__}`, null, true);
9092+
const name = `${module.__class__.__module__}.${module.__class__.__name__}`;
9093+
const type = torch.ClassType.create(name, null, true);
90829094
const result = new torch.ScriptModule(type);
90839095
result.data = module;
90849096
return result;
@@ -9101,7 +9113,7 @@ python.Execution = class {
91019113
['INT32', 'Int'],
91029114
['INT64', 'Long']
91039115
]);
9104-
const constants = (model.tensors || []).map((constant) => {
9116+
const tensor_table = (model.tensors || []).map((constant) => {
91059117
const key = constant.data.key;
91069118
if (!tensorTypeMap.has(constant.dataType)) {
91079119
throw new python.Error(`Unsupported tensor data type '${constant.dataType}'.`);
@@ -9126,8 +9138,8 @@ python.Execution = class {
91269138
return tensor;
91279139
});
91289140
execution.builtins.CONSTANTS = {};
9129-
for (let i = 0; i < constants.length; i++) {
9130-
execution.builtins.CONSTANTS[`c${i}`] = constants[i];
9141+
for (let i = 0; i < tensor_table.length; i++) {
9142+
execution.builtins.CONSTANTS[`c${i}`] = tensor_table[i];
91319143
}
91329144
const attributes = [];
91339145
if (this._reader.has_record('attributes.pkl')) {
@@ -9137,6 +9149,14 @@ python.Execution = class {
91379149
const obj = unpickler.load();
91389150
attributes.push(...obj);
91399151
}
9152+
9153+
this._LEGACY_moduleStack = ['__torch__'];
9154+
// const module_def = model.mainModule;
9155+
for (const tensor of tensor_table) {
9156+
this._constant_table.push(tensor);
9157+
}
9158+
// this.LEGACY_convertModule(module_def);
9159+
91409160
while (queue.length > 0) {
91419161
const module = queue.shift();
91429162
module.__class__ = module.__class__ || { __module__: 'torch.nn.modules.module', __name__: 'Module' };
@@ -9161,7 +9181,7 @@ python.Execution = class {
91619181
delete module.arguments;
91629182
}
91639183
for (const parameter of parameters) {
9164-
const tensor = constants[parameter.tensorId];
9184+
const tensor = tensor_table[parameter.tensorId];
91659185
module[parameter.name] = tensor;
91669186
parameter.__class__ = parameter.__class__ || { __module__: 'torch', __name__: 'Tensor' };
91679187
}
@@ -9182,11 +9202,71 @@ python.Execution = class {
91829202
data.forward = module.forward;
91839203
}
91849204
}
9185-
const class_type = new torch.ClassType(data.name);
9205+
const class_type = torch.ClassType.create(data.name);
91869206
const result = new torch.ScriptModule(class_type);
91879207
result.data = data;
91889208
return result;
91899209
}
9210+
LEGACY_convertModule(module_def) {
9211+
const atoms = new torch.jit.QualifiedName(module_def.name).atoms();
9212+
const numPushed = atoms.length;
9213+
for (const atom of atoms) {
9214+
const sanitized = /^\d+$/.test(atom) ? `_${atom}` : atom;
9215+
this._LEGACY_moduleStack.push(sanitized);
9216+
}
9217+
const module = new torch.ScriptModule(new torch.jit.QualifiedName(this._LEGACY_moduleStack), this._compilation_unit);
9218+
for (const sub_def of module_def.submodules || []) {
9219+
const submodule = this.LEGACY_convertModule(sub_def);
9220+
module.register_module(sub_def.name, submodule);
9221+
}
9222+
for (const param_def of module_def.parameters || []) {
9223+
const tensor = this._constant_table[Number(param_def.tensorId)];
9224+
if (param_def.isBuffer) {
9225+
module.register_buffer(param_def.name, tensor);
9226+
} else {
9227+
module.register_parameter(param_def.name, tensor, false);
9228+
}
9229+
}
9230+
// const typeParser = new torch.jit.ScriptTypeParser(this._source_importer);
9231+
for (const attr_def of module_def.attributes || []) {
9232+
if (module.hasattr(attr_def.name)) {
9233+
continue;
9234+
}
9235+
// IValue ivalue;
9236+
// if (attr_def.id() >= 0) {
9237+
// ivalue = LEGACY_pickled_ivalues_.at(attr_def.id());
9238+
// }
9239+
// module.register_attribute(attr_def.name(), typeParser.parseType(attr_def.type()), ivalue);
9240+
}
9241+
/*
9242+
std::shared_ptr<SourceRangeUnpickler> gen_ranges = nullptr;
9243+
if (module_def.has_torchscript_debug_arena()) {
9244+
auto [data, size] = reader_->getRecord(module_def.torchscript_debug_arena().key());
9245+
gen_ranges = std::make_shared<ConcreteSourceRangeUnpickler>(std::move(data), size);
9246+
}
9247+
if (module_def.has_torchscript_arena()) {
9248+
auto [data, size] =
9249+
reader_->getRecord(module_def.torchscript_arena().key());
9250+
std::string data_str(static_cast<const char*>(data.get()), size);
9251+
auto src = std::make_shared<Source>(std::string(static_cast<const char*>(data.get()), size), module_def.torchscript_arena().key(), 1, std::move(gen_ranges));
9252+
source_importer_.LEGACY_import_methods(module, src);
9253+
}
9254+
if (module_def.has_get_state_attribute_id()) {
9255+
LEGACY_moduleSetState(module, LEGACY_pickled_ivalues_.at(module_def.get_state_attribute_id()));
9256+
}
9257+
const ClassTypePtr& module_type = module._ivalue()->type();
9258+
for (size_t i = 0, N = module_type->numAttributes(); i < N; ++i) {
9259+
const IValue& v = module._ivalue()->getSlot(i);
9260+
if (module_type->getAttribute(i)->kind() != TypeKind::OptionalType) {
9261+
TORCH_CHECK(!v.isNone(), "The field '", module_type->getAttributeName(i), "' was left unitialized after __setstate__, but expected a ", "value of type '", v.type()->repr_str(), "'");
9262+
}
9263+
}
9264+
*/
9265+
for (let i = 0; i < numPushed; i++) {
9266+
this._LEGACY_moduleStack.pop();
9267+
}
9268+
return module;
9269+
}
91909270
readArchive(archive_name) {
91919271
const type_resolver = null;
91929272
const obj_loader = null;
@@ -9478,6 +9558,14 @@ python.Execution = class {
94789558
}
94799559
});
94809560
this.registerType('torch.ScriptModule', class extends torch.ScriptObject {
9561+
constructor(...args) {
9562+
if (args[0] instanceof torch.jit.QualifiedName && args[1] instanceof torch.jit.CompilationUnit) {
9563+
const [class_name, cu, shouldMangle] = args;
9564+
super(...torch.ScriptModule.create_module_object(class_name, cu, shouldMangle));
9565+
} else {
9566+
super(...args);
9567+
}
9568+
}
94819569
get qualified_name() {
94829570
return this._type.qualified_name();
94839571
}
@@ -9579,6 +9667,24 @@ python.Execution = class {
95799667
}
95809668
return this._graph;
95819669
}
9670+
static create_module_object(class_name, cu, shouldMangle) {
9671+
shouldMangle = shouldMangle || false;
9672+
if (!class_name.prefix()) {
9673+
class_name = new torch.jit.QualifiedName('__torch__', class_name.name());
9674+
}
9675+
if (shouldMangle && cu.get_class(class_name)) {
9676+
class_name = cu.mangle(class_name);
9677+
}
9678+
const cls = torch.ClassType.create(class_name, cu, true);
9679+
cu.register_type(cls);
9680+
return [cls, cu];
9681+
}
9682+
register_module(/* name, module */) {
9683+
}
9684+
register_buffer(/* name, buffer */) {
9685+
}
9686+
register_parameter(/* name, parameter, is_buffer */) {
9687+
}
95829688
});
95839689
this.registerType('torch.ModuleDict', class {
95849690
constructor(module) {
@@ -9659,6 +9765,7 @@ python.Execution = class {
96599765
const graph = new torch.Graph();
96609766
graph.set_op_version(operator_set_version);
96619767
const fn = new torch.jit.GraphFunction(name, graph, creator);
9768+
fn.__ast__ = def;
96629769
if (shouldMangle && this.find_function(name)) {
96639770
// name = mangle(name);
96649771
}
@@ -9845,7 +9952,7 @@ python.Execution = class {
98459952
cls = this._cu.get_class(new torch.jit.QualifiedName(name));
98469953
if (!cls) {
98479954
const torch = this._torch;
9848-
cls = new torch.ClassType(name, this._cu, true);
9955+
cls = torch.ClassType.create(name, this._cu, true);
98499956
this._cu.register_type(cls);
98509957
}
98519958
} else {
@@ -9862,7 +9969,20 @@ python.Execution = class {
98629969
return cls;
98639970
}
98649971
});
9865-
this.registerType('torch.export.unflatten.UnflattenedModule', class extends torch.nn.modules.module.Module {});
9972+
this.registerType('torch.export.UnflattenedModule', class extends torch.nn.modules.module.Module {
9973+
constructor(export_module, flat_args_adapter) {
9974+
super();
9975+
const export_graph = copy.deepcopy(export_module.graph);
9976+
self.graph_signature = copy.deepcopy(export_module.graph_signature);
9977+
this.graph = torch.fx.Graph();
9978+
this.graph.owning_module = this;
9979+
this.module_call_graph = copy.deepcopy(export_module.module_call_graph);
9980+
this.flat_args_adapter = flat_args_adapter;
9981+
this.adapted = false;
9982+
// this._run_with_interpreter = RUN_WITH_INTERPRETER
9983+
this._inplace_buffer_mutations(export_graph, this.graph_signature);
9984+
}
9985+
});
98669986
this.registerType('torch.export.graph_signature.ExportGraphSignature', class {
98679987
constructor(input_specs, output_specs) {
98689988
this.input_specs = input_specs;
@@ -10017,7 +10137,10 @@ python.Execution = class {
1001710137
});
1001810138
this.registerType('torch.export.exported_program.ModuleCallEntry', class {});
1001910139
this.registerType('torch.export.exported_program.ModuleCallSignature', class {});
10020-
this.registerFunction('torch.export.unflatten');
10140+
this.registerFunction('torch.export.unflatten', (module, flat_args_adapter) => {
10141+
module = torch.export._remove_effect_tokens(module);
10142+
return new torch.export.UnflattenedModule(module, flat_args_adapter);
10143+
});
1002110144
this.registerFunction('torch._export.exported_program._create_graph_module_for_export', (root, graph) => {
1002210145
return new torch.fx.graph_module.GraphModule(root, graph);
1002310146
});

0 commit comments

Comments
 (0)