Skip to content

Commit 644837e

Browse files
committed
re-activated validator recent model demand loading mechanism
1 parent 62e1aae commit 644837e

File tree

3 files changed

+24
-16
lines changed

3 files changed

+24
-16
lines changed

tensorlink/ml/graphing.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ def __init__(self, user_memory: int = 0):
6464

6565
def create_distributed_config(
6666
self,
67-
model: nn.Module,
67+
model: Union[nn.Module, str],
6868
training: bool,
6969
trusted: bool,
7070
handle_layers: bool = True,

tensorlink/ml/validator.py

Lines changed: 22 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -254,17 +254,15 @@ def _get_popular_models(self) -> list:
254254
def _manage_auto_loaded_models(self):
255255
"""Manage auto-loaded models based on popularity from JSON cache, falling back to DEFAULT_MODELS"""
256256
popular_models = self._get_popular_models()
257-
# if not popular_models:
258-
# models_to_load = DEFAULT_MODELS[: self.MAX_AUTO_MODELS]
259-
# else:
260-
# models_to_load = popular_models[: self.MAX_AUTO_MODELS]
261-
# self.send_request(
262-
# "debug_print",
263-
# (f"Loading popular models: {models_to_load}", "blue", logging.INFO),
264-
# )
265-
266-
# If no popular models tracked yet, use DEFAULT_MODELS as fallback
267-
models_to_load = DEFAULT_MODELS[: self.MAX_AUTO_MODELS]
257+
258+
if not popular_models:
259+
models_to_load = DEFAULT_MODELS[: self.MAX_AUTO_MODELS]
260+
else:
261+
models_to_load = popular_models[: self.MAX_AUTO_MODELS]
262+
self.send_request(
263+
"debug_print",
264+
(f"Loading popular models: {models_to_load}", "blue", logging.INFO),
265+
)
268266

269267
# Load models up to the limit
270268
for model_name in models_to_load:
@@ -305,10 +303,10 @@ def _manage_auto_loaded_models(self):
305303
)
306304
self._remove_hosted_job(model_name)
307305

308-
def inspect_model(self, model_name: str, job_data: dict = None):
306+
def inspect_model(self, model_name: str, job_data: dict):
309307
"""Inspect a model to determine network requirements and store distribution in JSON cache"""
310308
parser = ModelParser()
311-
model_name = job_data.get("model_name", model_name)
309+
model_name: str = job_data.get("model_name", model_name)
312310

313311
# Load HF model, create and save distribution
314312
distribution = parser.create_distributed_config(
@@ -372,7 +370,7 @@ def check_node(self):
372370
job_data = self.send_request("get_jobs", None)
373371

374372
if isinstance(job_data, dict):
375-
model_name = job_data.get("model_name")
373+
model_name: str = job_data.get("model_name", "")
376374

377375
if job_data.get("api"):
378376
payment = job_data.get("payment", 0)
@@ -625,6 +623,16 @@ def _finalize_hosted_job(self, model_name: str):
625623
# Distribute the model across workers
626624
self.modules[module_id].distribute_model(distribution)
627625

626+
# Ensure workers are registered
627+
for dist_module_id, dist_module_info in distribution.items():
628+
if dist_module_id in self.modules and isinstance(
629+
self.modules[dist_module_id], dict
630+
):
631+
# Update workers list to ensure it's current
632+
self.modules[dist_module_id]["workers"] = dist_module_info.get(
633+
"workers", []
634+
)
635+
628636
# Load tokenizer
629637
self.tokenizers[model_name] = AutoTokenizer.from_pretrained(model_name)
630638

tensorlink/p2p/smart_node.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1457,7 +1457,7 @@ def send_to_node_from_file(self, n: Connection, file, tag):
14571457
def handle_message(self, node: Connection, data) -> None:
14581458
"""Callback method that handles incoming data from connections"""
14591459
self.debug_print(
1460-
f"handle_message from {node.host}:{node.port} -> {data.__sizeof__()/1e6}MB",
1460+
f"handle_message from {node.host}:{node.port} -> {data.__sizeof__() / 1e6}MB",
14611461
tag="Smartnode",
14621462
)
14631463

0 commit comments

Comments (0)