@@ -27,6 +27,7 @@ def _get_processor_function(model_type: str) -> Callable:
2727 "paligemma" ,
2828 "paligemma2" ,
2929 "florence-2" ,
30+ "rfdetr" ,
3031 ]
3132
3233 if not any (supported_model in model_type for supported_model in supported_models ):
@@ -57,6 +58,9 @@ def _get_processor_function(model_type: str) -> Callable:
5758 if "yolonas" in model_type :
5859 return _process_yolonas
5960
61+ if "rfdetr" in model_type :
62+ return _process_rfdetr
63+
6064 return _process_yolo
6165
6266
@@ -220,6 +224,79 @@ def _process_yolo(model_type: str, model_path: str, filename: str) -> str:
220224 return zip_file_name
221225
222226
227+ def _process_rfdetr (model_type : str , model_path : str , filename : str ) -> str :
228+ _supported_types = ["rfdetr-base" , "rfdetr-large" ]
229+ if model_type not in _supported_types :
230+ raise ValueError (f"Model type { model_type } not supported. Supported types are { _supported_types } " )
231+
232+ if not os .path .exists (model_path ):
233+ raise FileNotFoundError (f"Model path { model_path } does not exist." )
234+
235+ model_files = os .listdir (model_path )
236+ pt_file = next ((f for f in model_files if f .endswith (".pt" ) or f .endswith (".pth" )), None )
237+
238+ if pt_file is None :
239+ raise RuntimeError ("No .pt or .pth model file found in the provided path" )
240+
241+ get_classnames_txt_for_rfdetr (model_path , pt_file )
242+
243+ # Copy the .pt file to weights.pt if not already named weights.pt
244+ if pt_file != "weights.pt" :
245+ shutil .copy (os .path .join (model_path , pt_file ), os .path .join (model_path , "weights.pt" ))
246+
247+ required_files = ["weights.pt" ]
248+
249+ optional_files = ["results.csv" , "results.png" , "model_artifacts.json" , "class_names.txt" ]
250+
251+ zip_file_name = "roboflow_deploy.zip"
252+ with zipfile .ZipFile (os .path .join (model_path , zip_file_name ), "w" ) as zipMe :
253+ for file in required_files :
254+ zipMe .write (os .path .join (model_path , file ), arcname = file , compress_type = zipfile .ZIP_DEFLATED )
255+
256+ for file in optional_files :
257+ if os .path .exists (os .path .join (model_path , file )):
258+ zipMe .write (os .path .join (model_path , file ), arcname = file , compress_type = zipfile .ZIP_DEFLATED )
259+
260+ return zip_file_name
261+
262+
263+ def get_classnames_txt_for_rfdetr (model_path : str , pt_file : str ):
264+ class_names_path = os .path .join (model_path , "class_names.txt" )
265+ if os .path .exists (class_names_path ):
266+ maybe_prepend_dummy_class (class_names_path )
267+ return class_names_path
268+
269+ import torch
270+
271+ model = torch .load (os .path .join (model_path , pt_file ), map_location = "cpu" , weights_only = False )
272+ args = vars (model ["args" ])
273+ if "class_names" in args :
274+ with open (class_names_path , "w" ) as f :
275+ for class_name in args ["class_names" ]:
276+ f .write (class_name + "\n " )
277+ maybe_prepend_dummy_class (class_names_path )
278+ return class_names_path
279+
280+ raise FileNotFoundError (
281+ f"No class_names.txt file found in model path { model_path } .\n "
282+ f"This should only happen on rfdetr models trained before version 1.1.0.\n "
283+ f"Please re-train your model with the latest version of the rfdetr library, or\n "
284+ f"please create a class_names.txt file in the model path with the class names\n "
285+ f"in new lines in the order of the classes in the model.\n "
286+ )
287+
288+
289+ def maybe_prepend_dummy_class (class_name_file : str ):
290+ with open (class_name_file ) as f :
291+ class_names = f .readlines ()
292+
293+ dummy_class = "background_class83422\n "
294+ if dummy_class not in class_names :
295+ class_names .insert (0 , dummy_class )
296+ with open (class_name_file , "w" ) as f :
297+ f .writelines (class_names )
298+
299+
223300def _process_huggingface (
224301 model_type : str , model_path : str , filename : str = "fine-tuned-paligemma-3b-pt-224.f16.npz"
225302) -> str :
0 commit comments