Skip to content

Commit 7ff178f

Browse files
committed
Update python.js (#1061)
1 parent 2545ad7 commit 7ff178f

File tree

2 files changed

+21
-17
lines changed

2 files changed

+21
-17
lines changed

source/python.js

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -9100,28 +9100,28 @@ python.Execution = class {
91009100
}
91019101
LEGACY_deserialize() {
91029102
const execution = this._compilation_unit.execution;
9103+
const caffe2 = execution.proto.caffe2;
91039104
const torch = execution.import('torch');
91049105
const stream = this._reader.get_record('model.json');
91059106
const buffer = stream.peek();
91069107
const decoder = new TextDecoder('utf-8');
91079108
const content = decoder.decode(buffer);
9108-
const model = JSON.parse(content);
9109-
const data = model.mainModule || {};
9110-
const queue = [data];
9109+
const obj = JSON.parse(content);
9110+
const model = execution.proto.torch.ModelDef.decodeJson(obj);
91119111
const tensorTypeMap = new Map([
9112-
['FLOAT', 'Float'],
9113-
['FLOAT16', 'Half'],
9114-
['DOUBLE', 'Double'],
9115-
['INT8', 'Char'],
9116-
['INT32', 'Int'],
9117-
['INT64', 'Long']
9112+
[caffe2.TensorProto.DataType.FLOAT, 'Float'],
9113+
[caffe2.TensorProto.DataType.FLOAT16, 'Half'],
9114+
[caffe2.TensorProto.DataType.DOUBLE, 'Double'],
9115+
[caffe2.TensorProto.DataType.INT8, 'Char'],
9116+
[caffe2.TensorProto.DataType.INT32, 'Int'],
9117+
[caffe2.TensorProto.DataType.INT64, 'Long']
91189118
]);
91199119
const tensor_table = (model.tensors || []).map((constant) => {
91209120
const key = constant.data.key;
9121-
if (!tensorTypeMap.has(constant.dataType)) {
9122-
throw new python.Error(`Unsupported tensor data type '${constant.dataType}'.`);
9121+
if (!tensorTypeMap.has(constant.data_type)) {
9122+
throw new python.Error(`Unsupported tensor data type '${constant.data_type}'.`);
91239123
}
9124-
const type = tensorTypeMap.get(constant.dataType);
9124+
const type = tensorTypeMap.get(constant.data_type);
91259125
const shape = constant.dims ? constant.dims.map((dim) => parseInt(dim, 10)) : null;
91269126
const strides = constant.strides ? constant.strides.map((dim) => parseInt(dim, 10)) : null;
91279127
const storage_type = execution.resolve(`torch.${type}Storage`);
@@ -9137,7 +9137,7 @@ python.Execution = class {
91379137
storage._set_cdata(data);
91389138
}
91399139
const tensor = torch._utils._rebuild_tensor(storage, 0, shape, strides);
9140-
tensor.name = constant.data.key;
9140+
tensor.name = key;
91419141
return tensor;
91429142
});
91439143
execution.builtins.CONSTANTS = {};
@@ -9152,14 +9152,14 @@ python.Execution = class {
91529152
const obj = unpickler.load();
91539153
attributes.push(...obj);
91549154
}
9155-
91569155
this._LEGACY_moduleStack = ['__torch__'];
9157-
// const module_def = model.mainModule;
9156+
const module_def = model.main_module;
91589157
for (const tensor of tensor_table) {
91599158
this._constant_table.push(tensor);
91609159
}
9161-
// this.LEGACY_convertModule(module_def);
9162-
9160+
this.LEGACY_convertModule(module_def);
9161+
const data = obj.mainModule || {};
9162+
const queue = [data];
91639163
while (queue.length > 0) {
91649164
const module = queue.shift();
91659165
module.__class__ = module.__class__ || { __module__: 'torch.nn.modules.module', __name__: 'Module' };

source/pytorch.js

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1345,6 +1345,9 @@ pytorch.Container.ModelJson = class extends pytorch.Container {
13451345
}
13461346

13471347
async read(metadata) {
1348+
if (this._entries.has('model.json')) {
1349+
pytorch.proto = await this._context.require('./pytorch-proto');
1350+
}
13481351
const keys = [
13491352
'attributes.pkl',
13501353
'version',
@@ -1360,6 +1363,7 @@ pytorch.Container.ModelJson = class extends pytorch.Container {
13601363
}
13611364
}
13621365
this.execution = new pytorch.Execution(null, metadata);
1366+
this.execution.proto = pytorch.proto;
13631367
for (const event of this._events) {
13641368
this.execution.on(event[0], event[1]);
13651369
}

0 commit comments

Comments
 (0)