@@ -27,7 +27,7 @@ def convert_weight(old_state_dict, new_state_dict, model_size: int = 38):
2727 if old_name in weight_name :
2828 weight_name = weight_name .replace (old_name , new_name )
2929 if weight_name in new_weight_set :
30- assert new_state_dict [weight_name ].shape == weight_value .shape , "shape miss match"
30+ assert new_state_dict [weight_name ].shape == weight_value .shape , f "shape miss match { weight_name } "
3131 new_state_dict [weight_name ] = weight_value
3232 new_weight_set .remove (weight_name )
3333
@@ -136,3 +136,35 @@ def convert_weight_seg(old_state_dict, new_state_dict):
136136 print (f"{ new_state_dict [new_weight_name ].shape } { old_state_dict [old_weight_name ].shape } " )
137137 new_state_dict [new_weight_name ] = old_state_dict [old_weight_name ]
138138 return new_state_dict
139+
140+
141+ import sys
142+ from pathlib import Path
143+
144+ import hydra
145+ import torch
146+
147+ project_root = Path (__file__ ).resolve ().parent .parent .parent
148+ sys .path .append (str (project_root ))
149+
150+ from yolo .config .config import Config
151+ from yolo .tools .solver import BaseModel
152+
153+
154+ @hydra .main (config_path = "../config" , config_name = "config" , version_base = None )
155+ def main (cfg : Config ):
156+ old_weight_path = getattr (cfg , "old_weight" , "v9t.pt" )
157+ new_weight_path = getattr (cfg , "new_weight" , "ait.pt" )
158+ print (f"Changing { old_weight_path } -> { new_weight_path } " )
159+ cfg .weight = None
160+ model = BaseModel (cfg )
161+ old_weight = torch .load (old_weight_path , weights_only = False )
162+ new_weight = convert_weight (old_weight , model .model .state_dict ())
163+ model .model .load_state_dict (new_weight )
164+ torch .save (model .model .model .state_dict (), new_weight_path )
165+ cfg .weight = new_weight_path
166+ BaseModel (cfg )
167+
168+
169+ if __name__ == "__main__" :
170+ main ()
0 commit comments