diff --git a/yolov9/gen_wts.py b/yolov9/gen_wts.py index a507e1f4..0420d721 100644 --- a/yolov9/gen_wts.py +++ b/yolov9/gen_wts.py @@ -32,7 +32,7 @@ def parse_args(): print(f'Loading {pt_file}') device = select_device('cpu') model = torch.load(pt_file, map_location=device) # Load FP32 weights -model.model.float() +model = model['model'].float() if m_type in ['detect', 'seg']: # update anchor_grid info @@ -67,4 +67,4 @@ def parse_args(): for k, v in model.state_dict().items(): vr = v.reshape(-1).cpu().numpy() f.write('{} {} '.format(k, len(vr))) - f.write('\n') + f.write('\n') \ No newline at end of file