@@ -254,17 +254,15 @@ def _get_popular_models(self) -> list:
254254 def _manage_auto_loaded_models (self ):
255255 """Manage auto-loaded models based on popularity from JSON cache, falling back to DEFAULT_MODELS"""
256256 popular_models = self ._get_popular_models ()
257- # if not popular_models:
258- # models_to_load = DEFAULT_MODELS[: self.MAX_AUTO_MODELS]
259- # else:
260- # models_to_load = popular_models[: self.MAX_AUTO_MODELS]
261- # self.send_request(
262- # "debug_print",
263- # (f"Loading popular models: {models_to_load}", "blue", logging.INFO),
264- # )
265-
266- # If no popular models tracked yet, use DEFAULT_MODELS as fallback
267- models_to_load = DEFAULT_MODELS [: self .MAX_AUTO_MODELS ]
257+
258+ if not popular_models :
259+ models_to_load = DEFAULT_MODELS [: self .MAX_AUTO_MODELS ]
260+ else :
261+ models_to_load = popular_models [: self .MAX_AUTO_MODELS ]
262+ self .send_request (
263+ "debug_print" ,
264+ (f"Loading popular models: { models_to_load } " , "blue" , logging .INFO ),
265+ )
268266
269267 # Load models up to the limit
270268 for model_name in models_to_load :
@@ -305,10 +303,10 @@ def _manage_auto_loaded_models(self):
305303 )
306304 self ._remove_hosted_job (model_name )
307305
308- def inspect_model (self , model_name : str , job_data : dict = None ):
306+ def inspect_model (self , model_name : str , job_data : dict ):
309307 """Inspect a model to determine network requirements and store distribution in JSON cache"""
310308 parser = ModelParser ()
311- model_name = job_data .get ("model_name" , model_name )
309+ model_name : str = job_data .get ("model_name" , model_name )
312310
313311 # Load HF model, create and save distribution
314312 distribution = parser .create_distributed_config (
@@ -372,7 +370,7 @@ def check_node(self):
372370 job_data = self .send_request ("get_jobs" , None )
373371
374372 if isinstance (job_data , dict ):
375- model_name = job_data .get ("model_name" )
373+ model_name : str = job_data .get ("model_name" , " " )
376374
377375 if job_data .get ("api" ):
378376 payment = job_data .get ("payment" , 0 )
@@ -625,6 +623,16 @@ def _finalize_hosted_job(self, model_name: str):
625623 # Distribute the model across workers
626624 self .modules [module_id ].distribute_model (distribution )
627625
626+ # Ensure workers are registered
627+ for dist_module_id , dist_module_info in distribution .items ():
628+ if dist_module_id in self .modules and isinstance (
629+ self .modules [dist_module_id ], dict
630+ ):
631+ # Update workers list to ensure it's current
632+ self .modules [dist_module_id ]["workers" ] = dist_module_info .get (
633+ "workers" , []
634+ )
635+
628636 # Load tokenizer
629637 self .tokenizers [model_name ] = AutoTokenizer .from_pretrained (model_name )
630638
0 commit comments