Skip to content

Commit 7612049

Browse files
committed
🔀 [Merge] branch 'feature/weight-transfer-from-other-repo'
2 parents 83c2634 + 1b54259 commit 7612049

File tree

1 file changed

+33
-1
lines changed

1 file changed

+33
-1
lines changed

‎yolo/tools/format_converters.py‎

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ def convert_weight(old_state_dict, new_state_dict, model_size: int = 38):
2727
if old_name in weight_name:
2828
weight_name = weight_name.replace(old_name, new_name)
2929
if weight_name in new_weight_set:
30-
assert new_state_dict[weight_name].shape == weight_value.shape, "shape miss match"
30+
assert new_state_dict[weight_name].shape == weight_value.shape, f"shape miss match {weight_name}"
3131
new_state_dict[weight_name] = weight_value
3232
new_weight_set.remove(weight_name)
3333

@@ -136,3 +136,35 @@ def convert_weight_seg(old_state_dict, new_state_dict):
136136
print(f"{new_state_dict[new_weight_name].shape} {old_state_dict[old_weight_name].shape}")
137137
new_state_dict[new_weight_name] = old_state_dict[old_weight_name]
138138
return new_state_dict
139+
140+
141+
import sys
142+
from pathlib import Path
143+
144+
import hydra
145+
import torch
146+
147+
project_root = Path(__file__).resolve().parent.parent.parent
148+
sys.path.append(str(project_root))
149+
150+
from yolo.config.config import Config
151+
from yolo.tools.solver import BaseModel
152+
153+
154+
@hydra.main(config_path="../config", config_name="config", version_base=None)
155+
def main(cfg: Config):
156+
old_weight_path = getattr(cfg, "old_weight", "v9t.pt")
157+
new_weight_path = getattr(cfg, "new_weight", "ait.pt")
158+
print(f"Changing {old_weight_path} -> {new_weight_path}")
159+
cfg.weight = None
160+
model = BaseModel(cfg)
161+
old_weight = torch.load(old_weight_path, weights_only=False)
162+
new_weight = convert_weight(old_weight, model.model.state_dict())
163+
model.model.load_state_dict(new_weight)
164+
torch.save(model.model.model.state_dict(), new_weight_path)
165+
cfg.weight = new_weight_path
166+
BaseModel(cfg)
167+
168+
169+
if __name__ == "__main__":
170+
main()

0 commit comments

Comments
 (0)