Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 11 additions & 3 deletions tensorlink/api/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ class NodeRequest(BaseModel):

class JobRequest(BaseModel):
    """Request body describing a model-hosting job.

    Only ``hf_name`` is required; ``time`` and ``payment`` fall back to the
    defaults below when the client omits them.
    """

    # Model identifier supplied by the client.
    # NOTE(review): presumably a Hugging Face repo name -- confirm with callers.
    hf_name: str
    # Each field is declared exactly once, together with its default, so the
    # schema has a single source of truth.
    # NOTE(review): unit of ``time`` is presumably seconds -- TODO confirm.
    time: int = 1800
    payment: int = 0


class GenerationRequest(BaseModel):
Expand Down Expand Up @@ -142,6 +142,7 @@ def request_model(job_request: JobRequest, request: Request):
# Trigger the loading process
job_data = {
"author": self.smart_node.rsa_key_hash,
"api": True,
"active": True,
"hosted": True,
"training": False,
Expand Down Expand Up @@ -345,7 +346,14 @@ def _start_server(self):
"""Start the FastAPI server in a separate thread"""

def run_server():
    """Serve ``self.app`` with uvicorn on ``self.host``:``self.port``."""
    # Start the server exactly once, using the configured call: a second
    # bare ``uvicorn.run`` would only start after the first one returned.
    uvicorn.run(
        self.app,
        host=self.host,
        port=self.port,
        timeout_keep_alive=20,
        limit_concurrency=100,
        lifespan="on",
    )

server_thread = Thread(target=run_server, daemon=True)
server_thread.start()
212 changes: 146 additions & 66 deletions tensorlink/ml/validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,9 +254,6 @@ def _get_popular_models(self) -> list:
def _manage_auto_loaded_models(self):
"""Manage auto-loaded models based on popularity from JSON cache, falling back to DEFAULT_MODELS"""
popular_models = self._get_popular_models()

# If no popular models tracked yet, use DEFAULT_MODELS as fallback
models_to_load = DEFAULT_MODELS[: self.MAX_AUTO_MODELS]
# if not popular_models:
# models_to_load = DEFAULT_MODELS[: self.MAX_AUTO_MODELS]
# else:
Expand All @@ -266,6 +263,9 @@ def _manage_auto_loaded_models(self):
# (f"Loading popular models: {models_to_load}", "blue", logging.INFO),
# )

# If no popular models tracked yet, use DEFAULT_MODELS as fallback
models_to_load = DEFAULT_MODELS[: self.MAX_AUTO_MODELS]

# Load models up to the limit
for model_name in models_to_load:
if (
Expand All @@ -279,30 +279,9 @@ def _manage_auto_loaded_models(self):
self.models_initializing.add(model_name)
self._initialize_hosted_job(model_name)

# Continue initialization for models that are in progress
for model_name in list(self.models_initializing):
if model_name in models_to_load: # Still wanted
# Try second initialization call
self._initialize_hosted_job(model_name)
# Check if initialization is complete
if model_name in self.models and isinstance(
self.models[model_name], str
):
# Model is fully initialized (module_id is now a string)
self.models_initializing.discard(model_name)
self.send_request(
"debug_print",
(
f"Completed auto-loading model: {model_name}",
"green",
logging.INFO,
),
)
else:
# Model no longer wanted, cancel initialization
self.models_initializing.discard(model_name)
if model_name in self.models:
self._remove_hosted_job(model_name)
# Try to finalize models that are initializing
if self.models_initializing:
self._try_finalize_initializing_models()

# Remove models not in the current priority list
currently_loaded = [
Expand All @@ -314,15 +293,17 @@ def _manage_auto_loaded_models(self):
model_name, days=1
) # Check last day
if recent_requests < 5: # Low recent activity
self.send_request(
"debug_print",
(
f"Removing unpopular model: {model_name}",
"yellow",
logging.INFO,
),
)
self._remove_hosted_job(model_name)
is_active = self.send_request("check_job", (model_name,))
if not is_active:
self.send_request(
"debug_print",
(
f"Removing unpopular model: {model_name}",
"yellow",
logging.INFO,
),
)
self._remove_hosted_job(model_name)

def inspect_model(self, model_name: str, job_data: dict = None):
"""Inspect a model to determine network requirements and store distribution in JSON cache"""
Expand Down Expand Up @@ -391,8 +372,29 @@ def check_node(self):
job_data = self.send_request("get_jobs", None)

if isinstance(job_data, dict):
# Offload model inspection to a background thread to avoid blocking
self.inspect_model(job_data.get("model_name"), job_data)
model_name = job_data.get("model_name")

if job_data.get("api"):
payment = job_data.get("payment", 0)
time_limit = job_data.get("time", 1800)

# Initialize if not already done
if (
model_name not in self.models
and model_name not in self.models_initializing
):
self.models_initializing.add(model_name)
self._initialize_hosted_job(
model_name, payment=payment, time_limit=time_limit
)

# Try to finalize if already initializing
if model_name in self.models_initializing:
self._finalize_hosted_job(model_name)

else:
# If request via user node, begin the model reqs inspection for the job request
self.inspect_model(model_name, job_data)

# Check for inference generate calls
for model_name, module_id in self.models.items():
Expand Down Expand Up @@ -517,48 +519,57 @@ def _handle_generate_request(self, request: GenerationRequest):
# Decode generated tokens
generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)

# Extract only the assistant's response from the generated text
clean_response = extract_assistant_response(generated_text, request.hf_name)
request.output = clean_response
# Many models echo the prompt, so remove it
if generated_text.startswith(formatted_prompt):
request.output = generated_text[len(formatted_prompt) :].strip()
else:
request.output = generated_text

# Return the clean response
self.send_request("update_api_request", (request,))

def _initialize_hosted_job(self, model_name: str):
"""Method that can be invoked twice, once to begin setup of the job, and a second
time to finalize the job init."""
args = self.send_request("check_module", None)
def _try_finalize_initializing_models(self):
    """Attempt to finalize all models that are currently initializing.

    Calls ``self._finalize_hosted_job`` for every name in
    ``self.models_initializing`` and sends a ``debug_print`` request for
    each one that reports success.
    """
    # Iterate over a snapshot: a finalize call may discard entries from the
    # set while this loop is still running.
    for model_name in list(self.models_initializing):
        if not self._finalize_hosted_job(model_name):
            continue
        self.send_request(
            "debug_print",
            (
                f"Successfully finalized model: {model_name}",
                "green",
                logging.INFO,
            ),
        )

# Check if the model loading is complete across workers and ready to go (second call)
if model_name in self.models and args:
if isinstance(args, tuple):
(
file_name,
module_id,
distribution,
module_name,
optimizer_name,
training,
) = args
self.modules[module_id] = self.models.pop(model_name)
self.models[model_name] = module_id
self.modules[module_id].distribute_model(distribution)
self.tokenizers[model_name] = AutoTokenizer.from_pretrained(model_name)

# If not, check if we can spin up the model (first call)
else:
# Small init sleep time
def _initialize_hosted_job(
self, model_name: str, payment: int = 0, time_limit: int = None
):
"""Initialize a hosted job by creating the distributed model and submitting inspection request."""
try:
# Check if already initialized
if model_name in self.models:
time.sleep(20)
self.send_request(
"debug_print",
(
f"Model {model_name} already initializing, skipping duplicate init",
"yellow",
logging.DEBUG,
),
)
return

# Create distributed model instance
distributed_model = DistributedModel(model_name, node=self.node)
self.models[model_name] = distributed_model

# Prepare job data for inspection
job_data = {
"author": None,
"active": True,
"hosted": True,
"training": False,
"payment": 0,
"payment": payment,
"time": time_limit,
"capacity": 0,
"n_pipelines": 1,
"dp_factor": 1,
Expand All @@ -567,8 +578,77 @@ def _initialize_hosted_job(self, model_name: str):
"model_name": model_name,
"seed_validators": [],
}

# Inspect model to determine network requirements
self.inspect_model(model_name, job_data)

self.send_request(
"debug_print",
(f"Initialized hosted job for {model_name}", "green", logging.INFO),
)

except Exception as e:
logging.error(f"Error initializing hosted job for {model_name}: {str(e)}")
self.models_initializing.discard(model_name)
if model_name in self.models:
del self.models[model_name]

def _finalize_hosted_job(self, model_name: str):
    """Finalize a hosted job by setting up the distributed model with workers.

    Asks for module information via the ``check_module`` request. When a
    tuple comes back, the pending model object stored in
    ``self.models[model_name]`` is moved to ``self.modules[module_id]``,
    ``self.models[model_name]`` is repointed at ``module_id``, the model is
    distributed, and its tokenizer is loaded into ``self.tokenizers``.

    Args:
        model_name: Name of the model to finalize; also passed to
            ``AutoTokenizer.from_pretrained``.

    Returns:
        True when the job was finalized. False when the model is not
        pending, the module information is not available yet, or a step
        raised; in the last case everything this call registered is
        removed again.
    """
    # Set only after this call has stored an entry in ``self.modules``, so
    # the error path knows what has to be rolled back.
    registered_module_id = None

    try:
        # Cheap local check first: nothing to finalize for an unknown
        # model, so skip the request entirely.
        if model_name not in self.models:
            return False

        # Check if we have module info ready
        args = self.send_request("check_module", None)
        if not args or not isinstance(args, tuple):
            # Module not ready yet
            return False

        # NOTE(review): the request is not keyed by ``model_name``; with
        # several models initializing at once this presumably relies on
        # the reply belonging to this model -- TODO confirm.
        (
            _file_name,
            module_id,
            distribution,
            _module_name,
            _optimizer_name,
            _training,
        ) = args

        # Move from initialization to active state
        distributed_model = self.models.pop(model_name)
        self.modules[module_id] = distributed_model
        registered_module_id = module_id
        self.models[model_name] = module_id

        # Distribute the model across workers
        distributed_model.distribute_model(distribution)

        # Load tokenizer
        self.tokenizers[model_name] = AutoTokenizer.from_pretrained(model_name)

        # Remove from initializing set
        self.models_initializing.discard(model_name)

        self.send_request(
            "debug_print",
            (
                f"Finalized hosted job for {model_name} with module_id {module_id}",
                "green",
                logging.INFO,
            ),
        )

        return True

    except Exception as e:
        logging.error(f"Error finalizing hosted job for {model_name}: {str(e)}")
        self.models_initializing.discard(model_name)
        self.models.pop(model_name, None)
        # Also drop the half-registered module so a failed finalize does
        # not leave an orphaned entry behind in ``self.modules``.
        if registered_module_id is not None:
            self.modules.pop(registered_module_id, None)
        return False

def _remove_hosted_job(self, model_name: str):
"""Remove a hosted job and clean up all associated resources"""
try:
Expand Down
Loading