@@ -6987,7 +6987,7 @@ python.Execution = class {
6987
6987
constructor(qualified_name, cu, is_module) {
6988
6988
super('ClassType', typeof qualified_name === 'string' ? qualified_name : qualified_name.qualifiedName());
6989
6989
this._is_module = is_module;
6990
- this._attributes = new Map() ;
6990
+ this._attributes = [] ;
6991
6991
this._methods = new Map();
6992
6992
this._staticmethods = new Map();
6993
6993
this._constants = new Map();
@@ -7023,14 +7023,43 @@ python.Execution = class {
7023
7023
findStaticMethod(name) {
7024
7024
return this._staticmethods.get(name);
7025
7025
}
7026
- addAttribute(name, type /*, is_parameter, is_buffer */) {
7027
- this._attributes.set(name, type);
7026
+ addAttribute(name, type, is_parameter, is_buffer) {
7027
+ is_parameter = is_parameter || false;
7028
+ is_buffer = is_buffer || false;
7029
+ const slot = this._attributes.length;
7030
+ this._attributes.push({ name, type, is_parameter, is_buffer });
7031
+ return slot;
7032
+ }
7033
+ addOrCheckAttribute(name, ty, is_parameter, is_buffer) {
7034
+ is_parameter = is_parameter || false;
7035
+ is_buffer = is_buffer || false;
7036
+ const slot_idx = this.findAttributeSlot(name);
7037
+ if (slot_idx === null) {
7038
+ return this.addAttribute(name, ty, is_parameter, is_buffer);
7039
+ }
7040
+ // TORCH_CHECK(is_parameter == this->is_parameter(*slot_idx), "Parameter field mismatch for the field '", name, "'");
7041
+ // const TypePtr& atype = getAttribute(*slot_idx);
7042
+ // TORCH_CHECK(ty->isSubtypeOf(*atype), ty->repr_str(), " is not compatible with the type ", atype->repr_str(), " for the field '", name, "'");
7043
+ return slot_idx;
7044
+ }
7045
+ findAttributeSlot(name) {
7046
+ for (let pos = 0; pos < this._attributes.length; pos++) {
7047
+ if (name === this._attributes[pos].name) {
7048
+ return pos;
7049
+ }
7050
+ }
7051
+ return null;
7028
7052
}
7029
7053
findAttribute(name) {
7030
- return this._attributes.get(name);
7054
+ const slot = this.findAttributeSlot(name);
7055
+ if (slot !== null) {
7056
+ return this._attributes[slot].type;
7057
+ }
7058
+ return null;
7031
7059
}
7032
7060
getAttribute(name) {
7033
- return this._attributes.get(name);
7061
+ const slot = this.findAttributeSlot(name);
7062
+ return this._attributes[slot].type;
7034
7063
}
7035
7064
hasConstant(/* name */) {
7036
7065
}
@@ -9100,28 +9129,28 @@ python.Execution = class {
9100
9129
}
9101
9130
LEGACY_deserialize() {
9102
9131
const execution = this._compilation_unit.execution;
9132
+ const caffe2 = execution.proto.caffe2;
9103
9133
const torch = execution.import('torch');
9104
9134
const stream = this._reader.get_record('model.json');
9105
9135
const buffer = stream.peek();
9106
9136
const decoder = new TextDecoder('utf-8');
9107
9137
const content = decoder.decode(buffer);
9108
- const model = JSON.parse(content);
9109
- const data = model.mainModule || {};
9110
- const queue = [data];
9138
+ const obj = JSON.parse(content);
9139
+ const model = execution.proto.torch.ModelDef.decodeJson(obj);
9111
9140
const tensorTypeMap = new Map([
9112
- [' FLOAT' , 'Float'],
9113
- [' FLOAT16' , 'Half'],
9114
- [' DOUBLE' , 'Double'],
9115
- [' INT8' , 'Char'],
9116
- [' INT32' , 'Int'],
9117
- [' INT64' , 'Long']
9141
+ [caffe2.TensorProto.DataType. FLOAT, 'Float'],
9142
+ [caffe2.TensorProto.DataType. FLOAT16, 'Half'],
9143
+ [caffe2.TensorProto.DataType. DOUBLE, 'Double'],
9144
+ [caffe2.TensorProto.DataType. INT8, 'Char'],
9145
+ [caffe2.TensorProto.DataType. INT32, 'Int'],
9146
+ [caffe2.TensorProto.DataType. INT64, 'Long']
9118
9147
]);
9119
9148
const tensor_table = (model.tensors || []).map((constant) => {
9120
9149
const key = constant.data.key;
9121
- if (!tensorTypeMap.has(constant.dataType )) {
9122
- throw new python.Error(`Unsupported tensor data type '${constant.dataType }'.`);
9150
+ if (!tensorTypeMap.has(constant.data_type )) {
9151
+ throw new python.Error(`Unsupported tensor data type '${constant.data_type }'.`);
9123
9152
}
9124
- const type = tensorTypeMap.get(constant.dataType );
9153
+ const type = tensorTypeMap.get(constant.data_type );
9125
9154
const shape = constant.dims ? constant.dims.map((dim) => parseInt(dim, 10)) : null;
9126
9155
const strides = constant.strides ? constant.strides.map((dim) => parseInt(dim, 10)) : null;
9127
9156
const storage_type = execution.resolve(`torch.${type}Storage`);
@@ -9137,7 +9166,7 @@ python.Execution = class {
9137
9166
storage._set_cdata(data);
9138
9167
}
9139
9168
const tensor = torch._utils._rebuild_tensor(storage, 0, shape, strides);
9140
- tensor.name = constant.data. key;
9169
+ tensor.name = key;
9141
9170
return tensor;
9142
9171
});
9143
9172
execution.builtins.CONSTANTS = {};
@@ -9152,14 +9181,14 @@ python.Execution = class {
9152
9181
const obj = unpickler.load();
9153
9182
attributes.push(...obj);
9154
9183
}
9155
-
9156
9184
this._LEGACY_moduleStack = ['__torch__'];
9157
- // const module_def = model.mainModule ;
9185
+ const module_def = model.main_module ;
9158
9186
for (const tensor of tensor_table) {
9159
9187
this._constant_table.push(tensor);
9160
9188
}
9161
- // this.LEGACY_convertModule(module_def);
9162
-
9189
+ const temp = this.LEGACY_convertModule(module_def);
9190
+ const data = obj.mainModule || {};
9191
+ const queue = [data];
9163
9192
while (queue.length > 0) {
9164
9193
const module = queue.shift();
9165
9194
module.__class__ = module.__class__ || { __module__: 'torch.nn.modules.module', __name__: 'Module' };
@@ -9205,8 +9234,7 @@ python.Execution = class {
9205
9234
data.forward = module.forward;
9206
9235
}
9207
9236
}
9208
- const class_type = torch.ClassType.create(data.name);
9209
- const result = new torch.ScriptModule(class_type);
9237
+ const result = new torch.ScriptModule(temp.type());
9210
9238
result.data = data;
9211
9239
return result;
9212
9240
}
@@ -9532,6 +9560,9 @@ python.Execution = class {
9532
9560
}
9533
9561
return new torch.ScriptObject(type);
9534
9562
}
9563
+ type() {
9564
+ return this._type;
9565
+ }
9535
9566
_type() {
9536
9567
return this._type; // torch.ClassType
9537
9568
}
@@ -9682,11 +9713,21 @@ python.Execution = class {
9682
9713
cu.register_type(cls);
9683
9714
return [cls, cu];
9684
9715
}
9685
- register_module(/* name, module */) {
9716
+ register_module(name, module) {
9717
+ this.type().addOrCheckAttribute(name, module.type());
9718
+ // _ivalue()->setAttr(name, module._ivalue());
9719
+ }
9720
+ register_buffer(name /* , v */) {
9721
+ this.type().addOrCheckAttribute(name, torch.TensorType.get(), false, true);
9722
+ // _ivalue()->setAttr(name, std::move(v));
9686
9723
}
9687
- register_buffer(/* name, buffer */) {
9724
+ register_parameter(name, v, is_buffer) {
9725
+ this.type().addOrCheckAttribute(name, torch.TensorType.get(), !is_buffer, is_buffer);
9726
+ // _ivalue()->setAttr(name, std::move(v));
9688
9727
}
9689
- register_parameter(/* name, parameter, is_buffer */) {
9728
+ register_attribute(name, t, v, is_param, is_buffer) {
9729
+ this.type().addOrCheckAttribute(name, t, is_param, is_buffer);
9730
+ // _ivalue()->setAttr(name, v);
9690
9731
}
9691
9732
});
9692
9733
this.registerType('torch.ModuleDict', class {
0 commit comments