|
| 1 | +convert_dict = { |
| 2 | + "19.cv1": "19.conv", |
| 3 | + "16.cv1": "16.conv", |
| 4 | + ".7.cv1": ".7.conv", |
| 5 | + ".5.cv1": ".5.conv", |
| 6 | + ".3.cv1": ".3.conv", |
| 7 | + ".28.": ".29.", |
| 8 | + ".25.": ".26.", |
| 9 | + ".22.": ".23.", |
| 10 | + "cv": "conv", |
| 11 | + ".m.": ".bottleneck.", |
| 12 | +} |
| 13 | + |
| 14 | +HEAD_NUM = "29" |
| 15 | + |
| 16 | + |
1 | 17 | def convert_weight(old_state_dict, new_state_dict, model_size: int = 38): |
2 | | - # TODO: need to refactor |
3 | | - shift = 1 |
4 | | - for idx in range(model_size): |
5 | | - new_list, old_list = [], [] |
6 | | - for weight_name, weight_value in new_state_dict.items(): |
7 | | - if weight_name.split(".")[0] == str(idx): |
8 | | - new_list.append((weight_name, None)) |
9 | | - for weight_name, weight_value in old_state_dict.items(): |
10 | | - if f"model.{idx+shift}." in weight_name: |
11 | | - old_list.append((weight_name, weight_value)) |
12 | | - if len(new_list) == len(old_list): |
13 | | - for (weight_name, _), (_, weight_value) in zip(new_list, old_list): |
14 | | - new_state_dict[weight_name] = weight_value |
| 18 | + new_weight_set = set(new_state_dict.keys()) |
| 19 | + for weight_name, weight_value in old_state_dict.items(): |
| 20 | + if HEAD_NUM in weight_name: |
| 21 | + _, _, conv_name, conv_id, *post_fix = weight_name.split(".") |
| 22 | + head_id = 30 if conv_name in ["cv2", "cv3"] else 22 |
| 23 | + head_type = "anchor_conv" if conv_name in ["cv2", "cv4"] else "class_conv" |
| 24 | + weight_name = ".".join(["model", str(head_id), "heads", conv_id, head_type, *post_fix]) |
15 | 25 | else: |
16 | | - for weight_name, weight_value in old_list: |
17 | | - if "dfl" in weight_name: |
18 | | - continue |
19 | | - _, _, conv_name, conv_idx, *details = weight_name.split(".") |
20 | | - if conv_name == "cv4" or conv_name == "cv5": |
21 | | - layer_idx = 22 |
22 | | - shift = 2 |
23 | | - else: |
24 | | - layer_idx = 37 |
25 | | - |
26 | | - if conv_name == "cv2" or conv_name == "cv4": |
27 | | - conv_task = "anchor_conv" |
28 | | - if conv_name == "cv3" or conv_name == "cv5": |
29 | | - conv_task = "class_conv" |
30 | | - |
31 | | - weight_name = ".".join([str(layer_idx), "heads", conv_idx, conv_task, *details]) |
32 | | - new_state_dict[weight_name] = weight_value |
| 26 | + for old_name, new_name in convert_dict.items(): |
| 27 | + if old_name in weight_name: |
| 28 | + weight_name = weight_name.replace(old_name, new_name) |
| 29 | + if weight_name in new_weight_set: |
| 30 | + assert new_state_dict[weight_name].shape == weight_value.shape, "shape miss match" |
| 31 | + new_state_dict[weight_name] = weight_value |
| 32 | + new_weight_set.remove(weight_name) |
| 33 | + |
33 | 34 | return new_state_dict |
34 | 35 |
|
35 | 36 |
|
|
0 commit comments