@@ -9100,28 +9100,28 @@ python.Execution = class {
9100
9100
}
9101
9101
LEGACY_deserialize() {
9102
9102
const execution = this._compilation_unit.execution;
9103
+ const caffe2 = execution.proto.caffe2;
9103
9104
const torch = execution.import('torch');
9104
9105
const stream = this._reader.get_record('model.json');
9105
9106
const buffer = stream.peek();
9106
9107
const decoder = new TextDecoder('utf-8');
9107
9108
const content = decoder.decode(buffer);
9108
- const model = JSON.parse(content);
9109
- const data = model.mainModule || {};
9110
- const queue = [data];
9109
+ const obj = JSON.parse(content);
9110
+ const model = execution.proto.torch.ModelDef.decodeJson(obj);
9111
9111
const tensorTypeMap = new Map([
9112
- [' FLOAT' , 'Float'],
9113
- [' FLOAT16' , 'Half'],
9114
- [' DOUBLE' , 'Double'],
9115
- [' INT8' , 'Char'],
9116
- [' INT32' , 'Int'],
9117
- [' INT64' , 'Long']
9112
+ [caffe2.TensorProto.DataType. FLOAT, 'Float'],
9113
+ [caffe2.TensorProto.DataType. FLOAT16, 'Half'],
9114
+ [caffe2.TensorProto.DataType. DOUBLE, 'Double'],
9115
+ [caffe2.TensorProto.DataType. INT8, 'Char'],
9116
+ [caffe2.TensorProto.DataType. INT32, 'Int'],
9117
+ [caffe2.TensorProto.DataType. INT64, 'Long']
9118
9118
]);
9119
9119
const tensor_table = (model.tensors || []).map((constant) => {
9120
9120
const key = constant.data.key;
9121
- if (!tensorTypeMap.has(constant.dataType )) {
9122
- throw new python.Error(`Unsupported tensor data type '${constant.dataType }'.`);
9121
+ if (!tensorTypeMap.has(constant.data_type )) {
9122
+ throw new python.Error(`Unsupported tensor data type '${constant.data_type }'.`);
9123
9123
}
9124
- const type = tensorTypeMap.get(constant.dataType );
9124
+ const type = tensorTypeMap.get(constant.data_type );
9125
9125
const shape = constant.dims ? constant.dims.map((dim) => parseInt(dim, 10)) : null;
9126
9126
const strides = constant.strides ? constant.strides.map((dim) => parseInt(dim, 10)) : null;
9127
9127
const storage_type = execution.resolve(`torch.${type}Storage`);
@@ -9137,7 +9137,7 @@ python.Execution = class {
9137
9137
storage._set_cdata(data);
9138
9138
}
9139
9139
const tensor = torch._utils._rebuild_tensor(storage, 0, shape, strides);
9140
- tensor.name = constant.data. key;
9140
+ tensor.name = key;
9141
9141
return tensor;
9142
9142
});
9143
9143
execution.builtins.CONSTANTS = {};
@@ -9152,14 +9152,14 @@ python.Execution = class {
9152
9152
const obj = unpickler.load();
9153
9153
attributes.push(...obj);
9154
9154
}
9155
-
9156
9155
this._LEGACY_moduleStack = ['__torch__'];
9157
- // const module_def = model.mainModule ;
9156
+ const module_def = model.main_module ;
9158
9157
for (const tensor of tensor_table) {
9159
9158
this._constant_table.push(tensor);
9160
9159
}
9161
- // this.LEGACY_convertModule(module_def);
9162
-
9160
+ this.LEGACY_convertModule(module_def);
9161
+ const data = obj.mainModule || {};
9162
+ const queue = [data];
9163
9163
while (queue.length > 0) {
9164
9164
const module = queue.shift();
9165
9165
module.__class__ = module.__class__ || { __module__: 'torch.nn.modules.module', __name__: 'Module' };
0 commit comments