Skip to content

Commit cee0b8c

Browse files
committed
Update python.js (#1061)
1 parent 2545ad7 commit cee0b8c

File tree

2 files changed

+72
-27
lines changed

2 files changed

+72
-27
lines changed

source/python.js

Lines changed: 68 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -6987,7 +6987,7 @@ python.Execution = class {
69876987
constructor(qualified_name, cu, is_module) {
69886988
super('ClassType', typeof qualified_name === 'string' ? qualified_name : qualified_name.qualifiedName());
69896989
this._is_module = is_module;
6990-
this._attributes = new Map();
6990+
this._attributes = [];
69916991
this._methods = new Map();
69926992
this._staticmethods = new Map();
69936993
this._constants = new Map();
@@ -7023,14 +7023,43 @@ python.Execution = class {
70237023
findStaticMethod(name) {
70247024
return this._staticmethods.get(name);
70257025
}
7026-
addAttribute(name, type /*, is_parameter, is_buffer */) {
7027-
this._attributes.set(name, type);
7026+
addAttribute(name, type, is_parameter, is_buffer) {
7027+
is_parameter = is_parameter || false;
7028+
is_buffer = is_buffer || false;
7029+
const slot = this._attributes.length;
7030+
this._attributes.push({ name, type, is_parameter, is_buffer });
7031+
return slot;
7032+
}
7033+
addOrCheckAttribute(name, ty, is_parameter, is_buffer) {
7034+
is_parameter = is_parameter || false;
7035+
is_buffer = is_buffer || false;
7036+
const slot_idx = this.findAttributeSlot(name);
7037+
if (slot_idx === null) {
7038+
return this.addAttribute(name, ty, is_parameter, is_buffer);
7039+
}
7040+
// TORCH_CHECK(is_parameter == this->is_parameter(*slot_idx), "Parameter field mismatch for the field '", name, "'");
7041+
// const TypePtr& atype = getAttribute(*slot_idx);
7042+
// TORCH_CHECK(ty->isSubtypeOf(*atype), ty->repr_str(), " is not compatible with the type ", atype->repr_str(), " for the field '", name, "'");
7043+
return slot_idx;
7044+
}
7045+
findAttributeSlot(name) {
7046+
for (let pos = 0; pos < this._attributes.length; pos++) {
7047+
if (name === this._attributes[pos].name) {
7048+
return pos;
7049+
}
7050+
}
7051+
return null;
70287052
}
70297053
findAttribute(name) {
7030-
return this._attributes.get(name);
7054+
const slot = this.findAttributeSlot(name);
7055+
if (slot !== null) {
7056+
return this._attributes[slot].type;
7057+
}
7058+
return null;
70317059
}
70327060
getAttribute(name) {
7033-
return this._attributes.get(name);
7061+
const slot = this.findAttributeSlot(name);
7062+
return this._attributes[slot].type;
70347063
}
70357064
hasConstant(/* name */) {
70367065
}
@@ -9100,28 +9129,28 @@ python.Execution = class {
91009129
}
91019130
LEGACY_deserialize() {
91029131
const execution = this._compilation_unit.execution;
9132+
const caffe2 = execution.proto.caffe2;
91039133
const torch = execution.import('torch');
91049134
const stream = this._reader.get_record('model.json');
91059135
const buffer = stream.peek();
91069136
const decoder = new TextDecoder('utf-8');
91079137
const content = decoder.decode(buffer);
9108-
const model = JSON.parse(content);
9109-
const data = model.mainModule || {};
9110-
const queue = [data];
9138+
const obj = JSON.parse(content);
9139+
const model = execution.proto.torch.ModelDef.decodeJson(obj);
91119140
const tensorTypeMap = new Map([
9112-
['FLOAT', 'Float'],
9113-
['FLOAT16', 'Half'],
9114-
['DOUBLE', 'Double'],
9115-
['INT8', 'Char'],
9116-
['INT32', 'Int'],
9117-
['INT64', 'Long']
9141+
[caffe2.TensorProto.DataType.FLOAT, 'Float'],
9142+
[caffe2.TensorProto.DataType.FLOAT16, 'Half'],
9143+
[caffe2.TensorProto.DataType.DOUBLE, 'Double'],
9144+
[caffe2.TensorProto.DataType.INT8, 'Char'],
9145+
[caffe2.TensorProto.DataType.INT32, 'Int'],
9146+
[caffe2.TensorProto.DataType.INT64, 'Long']
91189147
]);
91199148
const tensor_table = (model.tensors || []).map((constant) => {
91209149
const key = constant.data.key;
9121-
if (!tensorTypeMap.has(constant.dataType)) {
9122-
throw new python.Error(`Unsupported tensor data type '${constant.dataType}'.`);
9150+
if (!tensorTypeMap.has(constant.data_type)) {
9151+
throw new python.Error(`Unsupported tensor data type '${constant.data_type}'.`);
91239152
}
9124-
const type = tensorTypeMap.get(constant.dataType);
9153+
const type = tensorTypeMap.get(constant.data_type);
91259154
const shape = constant.dims ? constant.dims.map((dim) => parseInt(dim, 10)) : null;
91269155
const strides = constant.strides ? constant.strides.map((dim) => parseInt(dim, 10)) : null;
91279156
const storage_type = execution.resolve(`torch.${type}Storage`);
@@ -9137,7 +9166,7 @@ python.Execution = class {
91379166
storage._set_cdata(data);
91389167
}
91399168
const tensor = torch._utils._rebuild_tensor(storage, 0, shape, strides);
9140-
tensor.name = constant.data.key;
9169+
tensor.name = key;
91419170
return tensor;
91429171
});
91439172
execution.builtins.CONSTANTS = {};
@@ -9152,14 +9181,14 @@ python.Execution = class {
91529181
const obj = unpickler.load();
91539182
attributes.push(...obj);
91549183
}
9155-
91569184
this._LEGACY_moduleStack = ['__torch__'];
9157-
// const module_def = model.mainModule;
9185+
const module_def = model.main_module;
91589186
for (const tensor of tensor_table) {
91599187
this._constant_table.push(tensor);
91609188
}
9161-
// this.LEGACY_convertModule(module_def);
9162-
9189+
const temp = this.LEGACY_convertModule(module_def);
9190+
const data = obj.mainModule || {};
9191+
const queue = [data];
91639192
while (queue.length > 0) {
91649193
const module = queue.shift();
91659194
module.__class__ = module.__class__ || { __module__: 'torch.nn.modules.module', __name__: 'Module' };
@@ -9205,8 +9234,7 @@ python.Execution = class {
92059234
data.forward = module.forward;
92069235
}
92079236
}
9208-
const class_type = torch.ClassType.create(data.name);
9209-
const result = new torch.ScriptModule(class_type);
9237+
const result = new torch.ScriptModule(temp.type());
92109238
result.data = data;
92119239
return result;
92129240
}
@@ -9532,6 +9560,9 @@ python.Execution = class {
95329560
}
95339561
return new torch.ScriptObject(type);
95349562
}
9563+
type() {
9564+
return this._type;
9565+
}
95359566
_type() {
95369567
return this._type; // torch.ClassType
95379568
}
@@ -9682,11 +9713,21 @@ python.Execution = class {
96829713
cu.register_type(cls);
96839714
return [cls, cu];
96849715
}
9685-
register_module(/* name, module */) {
9716+
register_module(name, module) {
9717+
this.type().addOrCheckAttribute(name, module.type());
9718+
// _ivalue()->setAttr(name, module._ivalue());
9719+
}
9720+
register_buffer(name /* , v */) {
9721+
this.type().addOrCheckAttribute(name, torch.TensorType.get(), false, true);
9722+
// _ivalue()->setAttr(name, std::move(v));
96869723
}
9687-
register_buffer(/* name, buffer */) {
9724+
register_parameter(name, v, is_buffer) {
9725+
this.type().addOrCheckAttribute(name, torch.TensorType.get(), !is_buffer, is_buffer);
9726+
// _ivalue()->setAttr(name, std::move(v));
96889727
}
9689-
register_parameter(/* name, parameter, is_buffer */) {
9728+
register_attribute(name, t, v, is_param, is_buffer) {
9729+
this.type().addOrCheckAttribute(name, t, is_param, is_buffer);
9730+
// _ivalue()->setAttr(name, v);
96909731
}
96919732
});
96929733
this.registerType('torch.ModuleDict', class {

source/pytorch.js

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1345,6 +1345,9 @@ pytorch.Container.ModelJson = class extends pytorch.Container {
13451345
}
13461346

13471347
async read(metadata) {
1348+
if (this._entries.has('model.json')) {
1349+
pytorch.proto = await this._context.require('./pytorch-proto');
1350+
}
13481351
const keys = [
13491352
'attributes.pkl',
13501353
'version',
@@ -1360,6 +1363,7 @@ pytorch.Container.ModelJson = class extends pytorch.Container {
13601363
}
13611364
}
13621365
this.execution = new pytorch.Execution(null, metadata);
1366+
this.execution.proto = pytorch.proto;
13631367
for (const event of this._events) {
13641368
this.execution.on(event[0], event[1]);
13651369
}

0 commit comments

Comments
 (0)