-
Notifications
You must be signed in to change notification settings - Fork 0
/
OFAResNets18.patch
29 lines (28 loc) · 963 Bytes
/
OFAResNets18.patch
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
--- /home/rick/nas_rram/ofa/once-for-all/ofa/imagenet_classification/elastic_nn/networks/ofa_resnets.py
+++ /home/rick/nas_rram/ofa/once-for-all/ofa/imagenet_classification/elastic_nn/networks/ofa_resnets.py
@@ -135,8 +135,6 @@
model_dict = self.state_dict()
for key in state_dict:
new_key = key
- # import pdb
- # pdb.set_trace()
if new_key in model_dict:
pass
elif '.linear.' in new_key:
@@ -148,7 +146,6 @@
else:
raise ValueError(new_key)
assert new_key in model_dict, '%s' % new_key
-
model_dict[new_key] = state_dict[key]
super(OFAResNets18, self).load_state_dict(model_dict)
@@ -161,8 +158,7 @@
depth = val2list(d, len(ResNets.BASE_DEPTH_LIST) + 1)
expand_ratio = val2list(e, len(self.blocks))
width_mult = val2list(w, len(ResNets.BASE_DEPTH_LIST) + 2)
- # import pdb
- # pdb.set_trace()
+
for block, e in zip(self.blocks, expand_ratio):
if e is not None:
block.active_expand_ratio = e