Skip to content

Commit 83c2634

Browse files
committed
🔀 [Merge] remote-tracking branch 'origin/fix/clean-load-weight-log'
2 parents 9a5ffe2 + 16ba3d1 commit 83c2634

File tree

3 files changed

+43
-36
lines changed

3 files changed

+43
-36
lines changed

‎yolo/config/task/train.yaml‎

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,4 +51,4 @@ scheduler:
5151

5252
ema:
5353
enable: true
54-
decay: 0.995
54+
decay: 0.9999

‎yolo/model/yolo.py‎

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -133,9 +133,8 @@ def save_load_weights(self, weights: Union[Path, OrderedDict]):
133133
"""
134134
if isinstance(weights, Path):
135135
weights = torch.load(weights, map_location=torch.device("cpu"), weights_only=False)
136-
if "model_state_dict" in weights:
137-
weights = weights["model_state_dict"]
138-
136+
if "state_dict" in weights:
137+
weights = {name.removeprefix("model.model."): key for name, key in weights["state_dict"].items()}
139138
model_state_dict = self.model.state_dict()
140139

141140
# TODO1: autoload old version weight
@@ -152,8 +151,15 @@ def save_load_weights(self, weights: Union[Path, OrderedDict]):
152151
model_state_dict[model_key] = weights[model_key]
153152

154153
for error_name, error_set in error_dict.items():
155-
for weight_name in error_set:
156-
logger.warning(f":warning: Weight {error_name} for key: {'.'.join(weight_name)}")
154+
error_dict = dict()
155+
for layer_idx, *layer_name in error_set:
156+
if layer_idx not in error_dict:
157+
error_dict[layer_idx] = [".".join(layer_name)]
158+
else:
159+
error_dict[layer_idx].append(".".join(layer_name))
160+
for layer_idx, layer_name in error_dict.items():
161+
layer_name.sort()
162+
logger.warning(f":warning: Weight {error_name} for Layer {layer_idx}: {', '.join(layer_name)}")
157163

158164
self.model.load_state_dict(model_state_dict)
159165

‎yolo/tools/format_converters.py‎

Lines changed: 31 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,35 +1,36 @@
1+
convert_dict = {
2+
"19.cv1": "19.conv",
3+
"16.cv1": "16.conv",
4+
".7.cv1": ".7.conv",
5+
".5.cv1": ".5.conv",
6+
".3.cv1": ".3.conv",
7+
".28.": ".29.",
8+
".25.": ".26.",
9+
".22.": ".23.",
10+
"cv": "conv",
11+
".m.": ".bottleneck.",
12+
}
13+
14+
HEAD_NUM = "29"
15+
16+
117
def convert_weight(old_state_dict, new_state_dict, model_size: int = 38):
2-
# TODO: need to refactor
3-
shift = 1
4-
for idx in range(model_size):
5-
new_list, old_list = [], []
6-
for weight_name, weight_value in new_state_dict.items():
7-
if weight_name.split(".")[0] == str(idx):
8-
new_list.append((weight_name, None))
9-
for weight_name, weight_value in old_state_dict.items():
10-
if f"model.{idx+shift}." in weight_name:
11-
old_list.append((weight_name, weight_value))
12-
if len(new_list) == len(old_list):
13-
for (weight_name, _), (_, weight_value) in zip(new_list, old_list):
14-
new_state_dict[weight_name] = weight_value
18+
new_weight_set = set(new_state_dict.keys())
19+
for weight_name, weight_value in old_state_dict.items():
20+
if HEAD_NUM in weight_name:
21+
_, _, conv_name, conv_id, *post_fix = weight_name.split(".")
22+
head_id = 30 if conv_name in ["cv2", "cv3"] else 22
23+
head_type = "anchor_conv" if conv_name in ["cv2", "cv4"] else "class_conv"
24+
weight_name = ".".join(["model", str(head_id), "heads", conv_id, head_type, *post_fix])
1525
else:
16-
for weight_name, weight_value in old_list:
17-
if "dfl" in weight_name:
18-
continue
19-
_, _, conv_name, conv_idx, *details = weight_name.split(".")
20-
if conv_name == "cv4" or conv_name == "cv5":
21-
layer_idx = 22
22-
shift = 2
23-
else:
24-
layer_idx = 37
25-
26-
if conv_name == "cv2" or conv_name == "cv4":
27-
conv_task = "anchor_conv"
28-
if conv_name == "cv3" or conv_name == "cv5":
29-
conv_task = "class_conv"
30-
31-
weight_name = ".".join([str(layer_idx), "heads", conv_idx, conv_task, *details])
32-
new_state_dict[weight_name] = weight_value
26+
for old_name, new_name in convert_dict.items():
27+
if old_name in weight_name:
28+
weight_name = weight_name.replace(old_name, new_name)
29+
if weight_name in new_weight_set:
30+
assert new_state_dict[weight_name].shape == weight_value.shape, "shape miss match"
31+
new_state_dict[weight_name] = weight_value
32+
new_weight_set.remove(weight_name)
33+
3334
return new_state_dict
3435

3536

0 commit comments

Comments
 (0)