diff --git a/latencypredictor/manifests/dual-server-deployment.yaml b/latencypredictor/manifests/dual-server-deployment.yaml
index 761cf79d2..de19aed4d 100644
--- a/latencypredictor/manifests/dual-server-deployment.yaml
+++ b/latencypredictor/manifests/dual-server-deployment.yaml
@@ -13,6 +13,8 @@ data:
   LATENCY_TPOT_MODEL_PATH: "/models/tpot.joblib"
   LATENCY_TTFT_SCALER_PATH: "/models/ttft_scaler.joblib"
   LATENCY_TPOT_SCALER_PATH: "/models/tpot_scaler.joblib"
+  LATENCY_TTFT_TREELITE_PATH: "/models/ttft_treelite.so"
+  LATENCY_TPOT_TREELITE_PATH: "/models/tpot_treelite.so"
   LATENCY_MODEL_TYPE: "xgboost"
 
 ---
@@ -22,7 +24,7 @@ metadata:
   name: prediction-server-config
   namespace: default
 data:
-  MODEL_SYNC_INTERVAL_SEC: "10"  # Download models every 5 seconds
+  MODEL_SYNC_INTERVAL_SEC: "10"  # Download models every 10 seconds
   LATENCY_MODEL_TYPE: "xgboost"
   PREDICT_HOST: "0.0.0.0"
   PREDICT_PORT: "8001"
@@ -31,6 +33,9 @@ data:
   LOCAL_TPOT_MODEL_PATH: "/local_models/tpot.joblib"
   LOCAL_TTFT_SCALER_PATH: "/local_models/ttft_scaler.joblib"
   LOCAL_TPOT_SCALER_PATH: "/local_models/tpot_scaler.joblib"
+  LOCAL_TTFT_TREELITE_PATH: "/local_models/ttft_treelite.so"
+  LOCAL_TPOT_TREELITE_PATH: "/local_models/tpot_treelite.so"
+  USE_TREELITE: "true"  # Enable TreeLite for faster inference
   HTTP_TIMEOUT: "30"
 
 ---
diff --git a/latencypredictor/prediction_server.py b/latencypredictor/prediction_server.py
index 581f83421..ebf000f65 100644
--- a/latencypredictor/prediction_server.py
+++ b/latencypredictor/prediction_server.py
@@ -42,12 +42,19 @@ except ImportError:
     LIGHTGBM_AVAILABLE = False
     logging.warning("LightGBM not available. Install with: pip install lightgbm")
 
+
+try:
+    import treelite_runtime
+    TREELITE_AVAILABLE = True
+except ImportError:
+    TREELITE_AVAILABLE = False
+    logging.warning("TreeLite runtime not available. Install with: pip install treelite_runtime")
 
 class ModelType(str, Enum):
     BAYESIAN_RIDGE = "bayesian_ridge"
     XGBOOST = "xgboost"
     LIGHTGBM = "lightgbm"
-
+    TREELITE = "treelite"
 
 class PredictSettings:
     """Configuration for the prediction server."""
@@ -60,6 +67,11 @@ class PredictSettings:
     LOCAL_TPOT_MODEL_PATH: str = os.getenv("LOCAL_TPOT_MODEL_PATH", "/local_models/tpot.joblib")
     LOCAL_TTFT_SCALER_PATH: str = os.getenv("LOCAL_TTFT_SCALER_PATH", "/local_models/ttft_scaler.joblib")
     LOCAL_TPOT_SCALER_PATH: str = os.getenv("LOCAL_TPOT_SCALER_PATH", "/local_models/tpot_scaler.joblib")
+    LOCAL_TTFT_TREELITE_PATH: str = os.getenv("LOCAL_TTFT_TREELITE_PATH", "/local_models/ttft_treelite.so")
+    LOCAL_TPOT_TREELITE_PATH: str = os.getenv("LOCAL_TPOT_TREELITE_PATH", "/local_models/tpot_treelite.so")
+
+    # Use TreeLite for inference (preferred for production)
+    USE_TREELITE: bool = os.getenv("USE_TREELITE", "true").lower() == "true"
 
     # Sync interval and model type
     MODEL_SYNC_INTERVAL_SEC: int = int(os.getenv("MODEL_SYNC_INTERVAL_SEC", "10"))
@@ -94,6 +106,8 @@ def __init__(self):
             settings.LOCAL_TPOT_MODEL_PATH,
             settings.LOCAL_TTFT_SCALER_PATH,
             settings.LOCAL_TPOT_SCALER_PATH,
+            settings.LOCAL_TTFT_TREELITE_PATH,
+            settings.LOCAL_TPOT_TREELITE_PATH,
         ]:
             os.makedirs(os.path.dirname(path), exist_ok=True)
 
@@ -150,11 +164,21 @@ def sync_models(self) -> bool:
             ("ttft", settings.LOCAL_TTFT_MODEL_PATH),
             ("tpot", settings.LOCAL_TPOT_MODEL_PATH),
         ]
+
+        # Sync TreeLite models if enabled
+        if settings.USE_TREELITE and TREELITE_AVAILABLE:
+            to_sync += [
+                ("ttft_treelite", settings.LOCAL_TTFT_TREELITE_PATH),
+                ("tpot_treelite", settings.LOCAL_TPOT_TREELITE_PATH),
+            ]
+
+        # Sync scalers only for Bayesian Ridge
         if settings.MODEL_TYPE == ModelType.BAYESIAN_RIDGE:
             to_sync += [
                 ("ttft_scaler", settings.LOCAL_TTFT_SCALER_PATH),
                 ("tpot_scaler", settings.LOCAL_TPOT_SCALER_PATH),
             ]
+
         for name, path in to_sync:
             if self._download_model_if_newer(name, path):
                 updated = True
@@ -189,7 +213,7 @@ class LightweightPredictor:
     def __init__(self):
         mt = settings.MODEL_TYPE
         self.prefix_buckets = 4
-        
+
         # Add LightGBM fallback logic
         if mt == ModelType.XGBOOST and not XGBOOST_AVAILABLE:
             logging.warning("XGBoost not available. Falling back to Bayesian Ridge")
@@ -197,23 +221,34 @@ def __init__(self):
         elif mt == ModelType.LIGHTGBM and not LIGHTGBM_AVAILABLE:
             logging.warning("LightGBM not available. Falling back to Bayesian Ridge")
             mt = ModelType.BAYESIAN_RIDGE
-        
+
         self.model_type = mt
         self.quantile = settings.QUANTILE_ALPHA
+        self.use_treelite = settings.USE_TREELITE and TREELITE_AVAILABLE and mt in [ModelType.XGBOOST, ModelType.LIGHTGBM]
+
+        # Model storage
         self.ttft_model = None
         self.tpot_model = None
         self.ttft_scaler = None
         self.tpot_scaler = None
+
+        # TreeLite predictors (lightweight, compiled models)
+        self.ttft_predictor = None
+        self.tpot_predictor = None
+
         self.lock = threading.RLock()
         self.last_load: Optional[datetime] = None
-        logging.info(f"Predictor type: {self.model_type}, quantile: {self.quantile}")
+        logging.info(f"Predictor type: {self.model_type}, quantile: {self.quantile}, use_treelite: {self.use_treelite}")
 
     @property
     def is_ready(self) -> bool:
         with self.lock:
             if self.model_type == ModelType.BAYESIAN_RIDGE:
                 return all([self.ttft_model, self.tpot_model, self.ttft_scaler, self.tpot_scaler])
-            else:  # XGBoost or LightGBM
+            elif self.use_treelite:
+                # For TreeLite, we need the compiled predictors
+                return all([self.ttft_predictor, self.tpot_predictor])
+            else:  # XGBoost or LightGBM without TreeLite
                 return all([self.ttft_model, self.tpot_model])
 
     def _prepare_features_with_interaction(self, df: pd.DataFrame, model_type: str) -> pd.DataFrame:
@@ -268,21 +303,38 @@ def _prepare_features_with_interaction(self, df: pd.DataFrame, model_type: str)
     def load_models(self) -> bool:
         try:
             with self.lock:
-                new_ttft = joblib.load(settings.LOCAL_TTFT_MODEL_PATH) if os.path.exists(settings.LOCAL_TTFT_MODEL_PATH) else None
-                new_tpot = joblib.load(settings.LOCAL_TPOT_MODEL_PATH) if os.path.exists(settings.LOCAL_TPOT_MODEL_PATH) else None
-                if self.model_type == ModelType.BAYESIAN_RIDGE:
-                    new_ttft_scaler = joblib.load(settings.LOCAL_TTFT_SCALER_PATH) if os.path.exists(settings.LOCAL_TTFT_SCALER_PATH) else None
-                    new_tpot_scaler = joblib.load(settings.LOCAL_TPOT_SCALER_PATH) if os.path.exists(settings.LOCAL_TPOT_SCALER_PATH) else None
+                # Load TreeLite models if enabled
+                if self.use_treelite:
+                    if os.path.exists(settings.LOCAL_TTFT_TREELITE_PATH):
+                        self.ttft_predictor = treelite_runtime.Predictor(settings.LOCAL_TTFT_TREELITE_PATH, nthread=8)
+                        logging.info("TTFT TreeLite model loaded")
+                    else:
+                        logging.warning(f"TreeLite model not found: {settings.LOCAL_TTFT_TREELITE_PATH}")
+
+                    if os.path.exists(settings.LOCAL_TPOT_TREELITE_PATH):
+                        self.tpot_predictor = treelite_runtime.Predictor(settings.LOCAL_TPOT_TREELITE_PATH, nthread=8)
+                        logging.info("TPOT TreeLite model loaded")
+                    else:
+                        logging.warning(f"TreeLite model not found: {settings.LOCAL_TPOT_TREELITE_PATH}")
                 else:
-                    new_ttft_scaler = new_tpot_scaler = None
+                    # Load XGBoost/LightGBM/BayesianRidge models
+                    new_ttft = joblib.load(settings.LOCAL_TTFT_MODEL_PATH) if os.path.exists(settings.LOCAL_TTFT_MODEL_PATH) else None
+                    new_tpot = joblib.load(settings.LOCAL_TPOT_MODEL_PATH) if os.path.exists(settings.LOCAL_TPOT_MODEL_PATH) else None
+
+                    if self.model_type == ModelType.BAYESIAN_RIDGE:
+                        new_ttft_scaler = joblib.load(settings.LOCAL_TTFT_SCALER_PATH) if os.path.exists(settings.LOCAL_TTFT_SCALER_PATH) else None
+                        new_tpot_scaler = joblib.load(settings.LOCAL_TPOT_SCALER_PATH) if os.path.exists(settings.LOCAL_TPOT_SCALER_PATH) else None
+                    else:
+                        new_ttft_scaler = new_tpot_scaler = None
+
+                    if new_ttft: self.ttft_model = new_ttft
+                    if new_tpot: self.tpot_model = new_tpot
+                    if new_ttft_scaler: self.ttft_scaler = new_ttft_scaler
+                    if new_tpot_scaler: self.tpot_scaler = new_tpot_scaler
 
-                if new_ttft: self.ttft_model = new_ttft
-                if new_tpot: self.tpot_model = new_tpot
-                if new_ttft_scaler: self.ttft_scaler = new_ttft_scaler
-                if new_tpot_scaler: self.tpot_scaler = new_tpot_scaler
                 self.last_load = datetime.now(timezone.utc)
                 if self.is_ready:
-                    logging.info("Models loaded")
+                    logging.info(f"Models loaded successfully (TreeLite: {self.use_treelite})")
                     return True
                 logging.warning("Models missing after load")
                 return False
@@ -296,9 +348,9 @@ def predict(self, features: dict) -> Tuple[float, float]:
         with self.lock:
             if not self.is_ready:
                 raise HTTPException(status_code=503, detail="Models not ready")
-            
+
             # Validation
-            required = ['kv_cache_percentage', 'input_token_length', 'num_request_waiting', 
+            required = ['kv_cache_percentage', 'input_token_length', 'num_request_waiting',
                         'num_request_running', 'num_tokens_generated', 'prefix_cache_score']
             for f in required:
                 if f not in features:
@@ -314,7 +366,7 @@ def predict(self, features: dict) -> Tuple[float, float]:
                 'num_request_running': features['num_request_running'],
                 'prefix_cache_score': features['prefix_cache_score']
             }
-            
+
             tpot_raw_data = {
                 'kv_cache_percentage': features['kv_cache_percentage'],
                 'input_token_length': features['input_token_length'],
@@ -322,18 +374,30 @@ def predict(self, features: dict) -> Tuple[float, float]:
                 'num_request_running': features['num_request_running'],
                 'num_tokens_generated': features['num_tokens_generated']
             }
-            
+
             # Prepare features with interactions
             df_ttft_raw = pd.DataFrame([ttft_raw_data])
             df_ttft = self._prepare_features_with_interaction(df_ttft_raw, "ttft")
-            
-            
+
             df_tpot_raw = pd.DataFrame([tpot_raw_data])
             df_tpot = self._prepare_features_with_interaction(df_tpot_raw, "tpot")
-            #df_tpot = pd.DataFrame([tpot_raw_data])
-            if self.model_type == ModelType.BAYESIAN_RIDGE:
-                
+            # Use TreeLite for inference if enabled
+            if self.use_treelite:
+                # TreeLite expects numpy arrays
+                ttft_arr = df_ttft.values.astype('float32')
+                tpot_arr = df_tpot.values.astype('float32')
+
+                # Create DMatrix for TreeLite
+                ttft_dmat = treelite_runtime.DMatrix(ttft_arr)
+                tpot_dmat = treelite_runtime.DMatrix(tpot_arr)
+
+                ttft_pred = self.ttft_predictor.predict(ttft_dmat)
+                tpot_pred = self.tpot_predictor.predict(tpot_dmat)
+
+                return float(ttft_pred[0]), float(tpot_pred[0])
+
+            elif self.model_type == ModelType.BAYESIAN_RIDGE:
                 ttft_for_scale = df_ttft.drop(columns=['prefill_score_bucket'], errors='ignore')
                 ttft_scaled = self.ttft_scaler.transform(ttft_for_scale)
                 tpot_scaled = self.tpot_scaler.transform(df_tpot)
@@ -344,19 +408,19 @@ def predict(self, features: dict) -> Tuple[float, float]:
                 std_factor = 1.28 if self.quantile == 0.9 else (2.0 if self.quantile == 0.95 else 0.674)
                 ttft_pred = ttft_pred_mean[0] + std_factor * ttft_std[0]
                 tpot_pred = tpot_pred_mean[0] + std_factor * tpot_std[0]
-                
+
                 return ttft_pred, tpot_pred
-                
+
             elif self.model_type == ModelType.XGBOOST:
                 ttft_pred = self.ttft_model.predict(df_ttft)
                 tpot_pred = self.tpot_model.predict(df_tpot)
-                
+
                 return ttft_pred[0], tpot_pred[0]
-                
+
             else:  # LightGBM
                 ttft_pred = self.ttft_model.predict(df_ttft)
                 tpot_pred = self.tpot_model.predict(df_tpot)
-                
+
                 return ttft_pred[0], tpot_pred[0]
 
         except ValueError as ve:
@@ -374,9 +438,9 @@ def predict_batch(self, features_list: List[dict]) -> Tuple[np.ndarray, np.ndarr
         with self.lock:
             if not self.is_ready:
                 raise HTTPException(status_code=503, detail="Models not ready")
-            
+
             # Validation
-            required = ['kv_cache_percentage', 'input_token_length', 'num_request_waiting', 
+            required = ['kv_cache_percentage', 'input_token_length', 'num_request_waiting',
                         'num_request_running', 'num_tokens_generated', 'prefix_cache_score']
             for i, features in enumerate(features_list):
                 for f in required:
@@ -388,7 +452,7 @@ def predict_batch(self, features_list: List[dict]) -> Tuple[np.ndarray, np.ndarr
             # Create raw feature data (without interaction)
             ttft_raw_data = []
             tpot_raw_data = []
-            
+
             for features in features_list:
                 ttft_raw_data.append({
                     'kv_cache_percentage': features['kv_cache_percentage'],
@@ -397,7 +461,7 @@ def predict_batch(self, features_list: List[dict]) -> Tuple[np.ndarray, np.ndarr
                     'num_request_running': features['num_request_running'],
                     'prefix_cache_score': features['prefix_cache_score']
                 })
-                
+
                 tpot_raw_data.append({
                     'kv_cache_percentage': features['kv_cache_percentage'],
                     'input_token_length': features['input_token_length'],
@@ -405,40 +469,53 @@ def predict_batch(self, features_list: List[dict]) -> Tuple[np.ndarray, np.ndarr
                     'num_request_running': features['num_request_running'],
                     'num_tokens_generated': features['num_tokens_generated']
                 })
-            
+
             # Prepare features with interactions
             df_ttft_raw = pd.DataFrame(ttft_raw_data)
             df_ttft_batch = self._prepare_features_with_interaction(df_ttft_raw, "ttft")
-            #df_ttft_batch = pd.DataFrame(ttft_raw_data)
-            
+
             df_tpot_raw = pd.DataFrame(tpot_raw_data)
             df_tpot_batch = self._prepare_features_with_interaction(df_tpot_raw, "tpot")
-            #df_tpot_batch = pd.DataFrame(tpot_raw_data)
-            if self.model_type == ModelType.BAYESIAN_RIDGE:
+            # Use TreeLite for batch inference if enabled
+            if self.use_treelite:
+                # TreeLite expects numpy arrays
+                ttft_arr = df_ttft_batch.values.astype('float32')
+                tpot_arr = df_tpot_batch.values.astype('float32')
+
+                # Create DMatrix for TreeLite
+                ttft_dmat = treelite_runtime.DMatrix(ttft_arr)
+                tpot_dmat = treelite_runtime.DMatrix(tpot_arr)
+
+                ttft_pred = self.ttft_predictor.predict(ttft_dmat)
+                tpot_pred = self.tpot_predictor.predict(tpot_dmat)
+
+                return ttft_pred, tpot_pred
+
+            elif self.model_type == ModelType.BAYESIAN_RIDGE:
                 ttft_for_scale = df_ttft_batch.drop(columns=['prefill_score_bucket'], errors='ignore')
                 ttft_scaled = self.ttft_scaler.transform(ttft_for_scale)
                 tpot_scaled = self.tpot_scaler.transform(df_tpot_batch)
 
                 ttft_pred_mean, ttft_std = self.ttft_model.predict(ttft_scaled, return_std=True)
                 tpot_pred_mean, tpot_std = self.tpot_model.predict(tpot_scaled, return_std=True)
-                
+
                 std_factor = 1.28 if self.quantile == 0.9 else (2.0 if self.quantile == 0.95 else 0.674)
                 ttft_pred = ttft_pred_mean + std_factor * ttft_std
                 tpot_pred = tpot_pred_mean + std_factor * tpot_std
-                
+
                 return ttft_pred, tpot_pred
-                
+
             elif self.model_type == ModelType.XGBOOST:
                 ttft_pred = self.ttft_model.predict(df_ttft_batch)
                 tpot_pred = self.tpot_model.predict(df_tpot_batch)
-                
+
                 return ttft_pred, tpot_pred
-                
+
             else:  # LightGBM
                 ttft_pred = self.ttft_model.predict(df_ttft_batch)
                 tpot_pred = self.tpot_model.predict(df_tpot_batch)
-                
+
                 return ttft_pred, tpot_pred
 
         except ValueError as ve:
@@ -773,4 +850,4 @@ async def shutdown():
 
 
 if __name__ == "__main__":
-    uvicorn.run("__main__:app", host=settings.HOST, port=settings.PORT, reload=True)
\ No newline at end of file
+    uvicorn.run("__main__:app", host=settings.HOST, port=settings.PORT, reload=True)
diff --git a/latencypredictor/requirements.txt b/latencypredictor/requirements.txt
index 30b6c54dd..ca8bc2c11 100644
--- a/latencypredictor/requirements.txt
+++ b/latencypredictor/requirements.txt
@@ -10,3 +10,5 @@ requests
 xgboost
 aiohttp
 lightgbm
+treelite
+treelite_runtime
diff --git a/latencypredictor/test_dual_server_client.py b/latencypredictor/test_dual_server_client.py
index 5ab96699a..adaff21c3 100644
--- a/latencypredictor/test_dual_server_client.py
+++ b/latencypredictor/test_dual_server_client.py
@@ -119,17 +119,20 @@ def test_training_server_models_list():
     """Test training server models list endpoint."""
     r = requests.get(f"{TRAINING_URL}/models/list")
     assert r.status_code == 200
-    
+
     data = r.json()
     assert "models" in data
     assert "model_type" in data
     assert "server_time" in data
-    
+
     models = data["models"]
     expected_models = ["ttft", "tpot"]
     if data["model_type"] == "bayesian_ridge":
         expected_models.extend(["ttft_scaler", "tpot_scaler"])
-    
+    elif data["model_type"] in ["xgboost", "lightgbm"]:
+        # TreeLite models should also be available for XGBoost and LightGBM
+        expected_models.extend(["ttft_treelite", "tpot_treelite"])
+
     for model_name in expected_models:
         assert model_name in models, f"Model {model_name} should be listed"
         print(f"Model {model_name}: exists={models[model_name]['exists']}, size={models[model_name]['size_bytes']} bytes")
@@ -140,7 +143,8 @@ def test_model_download_from_training_server():
     # First check what models are available
     models_r = requests.get(f"{TRAINING_URL}/models/list")
     models_data = models_r.json()
-    
+
+    # Test basic models (ttft, tpot)
     for model_name in ["ttft", "tpot"]:
         if models_data["models"][model_name]["exists"]:
             # Test model info endpoint
@@ -149,13 +153,13 @@ def test_model_download_from_training_server():
             info_data = info_r.json()
             assert info_data["exists"] == True
             assert info_data["size_bytes"] > 0
-            
+
             # Test model download with retry and streaming
             max_retries = 3
             for attempt in range(max_retries):
                 try:
                     download_r = requests.get(
-                        f"{TRAINING_URL}/model/{model_name}/download", 
+                        f"{TRAINING_URL}/model/{model_name}/download",
                         timeout=30,
                         stream=True  # Use streaming to handle large files better
                     )
@@ -164,7 +168,7 @@ def test_model_download_from_training_server():
                         content_length = 0
                         for chunk in download_r.iter_content(chunk_size=8192):
                             content_length += len(chunk)
-                        
+
                         assert content_length > 0, f"Downloaded {model_name} model is empty"
                         print(f"Successfully downloaded {model_name} model ({content_length} bytes)")
                         break
@@ -176,6 +180,79 @@ def test_model_download_from_training_server():
                         continue
                     time.sleep(2)  # Wait before retry
 
+    # Test TreeLite models for XGBoost and LightGBM
+    model_type = models_data["model_type"]
+    if model_type in ["xgboost", "lightgbm"]:
+        for model_name in ["ttft_treelite", "tpot_treelite"]:
+            if models_data["models"].get(model_name, {}).get("exists"):
+                # Test model info endpoint
+                info_r = requests.get(f"{TRAINING_URL}/model/{model_name}/info")
+                assert info_r.status_code == 200
+                info_data = info_r.json()
+                assert info_data["exists"] == True
+                assert info_data["size_bytes"] > 0
+
+                # Test model download with retry and streaming
+                max_retries = 3
+                for attempt in range(max_retries):
+                    try:
+                        download_r = requests.get(
+                            f"{TRAINING_URL}/model/{model_name}/download",
+                            timeout=30,
+                            stream=True
+                        )
+                        if download_r.status_code == 200:
+                            # Read content in chunks to avoid memory issues
+                            content_length = 0
+                            for chunk in download_r.iter_content(chunk_size=8192):
+                                content_length += len(chunk)
+
+                            assert content_length > 0, f"Downloaded {model_name} model is empty"
+                            print(f"Successfully downloaded {model_name} TreeLite model ({content_length} bytes)")
+                            break
+                    except requests.exceptions.ChunkedEncodingError as e:
+                        print(f"Download attempt {attempt + 1}/{max_retries} failed for {model_name}: {e}")
+                        if attempt == max_retries - 1:
+                            print(f"⚠️ TreeLite model download test skipped for {model_name} due to connection issues")
+                            continue
+                        time.sleep(2)  # Wait before retry
+
+
+def test_treelite_models_on_training_server():
+    """Test TreeLite model endpoints on training server for XGBoost and LightGBM."""
+    model_info_r = requests.get(f"{TRAINING_URL}/model/download/info")
+    model_type = model_info_r.json().get("model_type")
+
+    if model_type not in ["xgboost", "lightgbm"]:
+        print(f"Skipping TreeLite tests - model type is {model_type}")
+        return
+
+    print(f"Testing TreeLite models for {model_type}...")
+
+    # Test TTFT TreeLite model
+    ttft_info_r = requests.get(f"{TRAINING_URL}/model/ttft_treelite/info")
+    if ttft_info_r.status_code == 200:
+        ttft_info = ttft_info_r.json()
+        if ttft_info.get("exists"):
+            print(f"✓ TTFT TreeLite model available ({ttft_info['size_bytes']} bytes)")
+            assert ttft_info["size_bytes"] > 0, "TTFT TreeLite model should have non-zero size"
+        else:
+            print(f"TTFT TreeLite model not yet generated")
+    else:
+        print(f"TTFT TreeLite model endpoint returned status {ttft_info_r.status_code}")
+
+    # Test TPOT TreeLite model
+    tpot_info_r = requests.get(f"{TRAINING_URL}/model/tpot_treelite/info")
+    if tpot_info_r.status_code == 200:
+        tpot_info = tpot_info_r.json()
+        if tpot_info.get("exists"):
+            print(f"✓ TPOT TreeLite model available ({tpot_info['size_bytes']} bytes)")
+            assert tpot_info["size_bytes"] > 0, "TPOT TreeLite model should have non-zero size"
+        else:
+            print(f"TPOT TreeLite model not yet generated")
+    else:
+        print(f"TPOT TreeLite model endpoint returned status {tpot_info_r.status_code}")
+
+
 def test_lightgbm_endpoints_on_training_server():
     """Test LightGBM endpoints on training server if LightGBM is being used."""
     model_info_r = requests.get(f"{TRAINING_URL}/model/download/info")
@@ -1370,6 +1447,7 @@ def test_training_server_flush_error_handling():
         ("Training Server Model Info", test_training_server_model_info),
         ("Training Server Models List", test_training_server_models_list),
         ("Model Download", test_model_download_from_training_server),
+        ("TreeLite Models", test_treelite_models_on_training_server),
         ("Send Training Data", test_add_training_data_to_training_server),
         ("Model Sync", test_prediction_server_model_sync),
         ("Predictions", test_prediction_via_prediction_server),
@@ -1380,9 +1458,9 @@ def test_training_server_flush_error_handling():
         ("Training Metrics", test_training_server_metrics),
         ("Model Consistency", test_model_consistency_between_servers),
         ("XGBoost Trees", test_model_specific_endpoints_on_training_server),
-        ("Flush API", test_training_server_flush_api),  
+        ("Flush API", test_training_server_flush_api),
         ("Flush Error Handling", test_training_server_flush_error_handling),
-        
+
         ("Dual Server Model Learns Equation", test_dual_server_quantile_regression_learns_distribution),
         ("End-to-End Workflow", test_end_to_end_workflow),
         ("Prediction Stress Test", test_prediction_server_stress_test),
diff --git a/latencypredictor/training_server.py b/latencypredictor/training_server.py
index 3e1e2751f..0502595dc 100644
--- a/latencypredictor/training_server.py
+++ b/latencypredictor/training_server.py
@@ -52,11 +52,20 @@
     LIGHTGBM_AVAILABLE = False
     logging.warning("LightGBM not available. Please install with: pip install lightgbm")
 
+try:
+    import treelite
+    import treelite.sklearn
+    TREELITE_AVAILABLE = True
+except ImportError:
+    TREELITE_AVAILABLE = False
+    logging.warning("TreeLite not available. Please install with: pip install treelite treelite_runtime")
+
 
 class ModelType(str, Enum):
     BAYESIAN_RIDGE = "bayesian_ridge"
     XGBOOST = "xgboost"
     LIGHTGBM = "lightgbm"
+    TREELITE = "treelite"
 
 
 class RandomDropDeque(deque):
@@ -97,6 +106,8 @@ class Settings:
     TPOT_MODEL_PATH: str = os.getenv("LATENCY_TPOT_MODEL_PATH", "/tmp/models/tpot.joblib")
     TTFT_SCALER_PATH: str = os.getenv("LATENCY_TTFT_SCALER_PATH", "/tmp/models/ttft_scaler.joblib")
     TPOT_SCALER_PATH: str = os.getenv("LATENCY_TPOT_SCALER_PATH", "/tmp/models/tpot_scaler.joblib")
+    TTFT_TREELITE_PATH: str = os.getenv("LATENCY_TTFT_TREELITE_PATH", "/tmp/models/ttft_treelite.so")
+    TPOT_TREELITE_PATH: str = os.getenv("LATENCY_TPOT_TREELITE_PATH", "/tmp/models/tpot_treelite.so")
     RETRAINING_INTERVAL_SEC: int = int(os.getenv("LATENCY_RETRAINING_INTERVAL_SEC", 1800))
     MIN_SAMPLES_FOR_RETRAIN_FRESH: int = int(os.getenv("LATENCY_MIN_SAMPLES_FOR_RETRAIN_FRESH", 10))
     MIN_SAMPLES_FOR_RETRAIN: int = int(os.getenv("LATENCY_MIN_SAMPLES_FOR_RETRAIN", 1000))
@@ -211,16 +222,16 @@ def __init__(self, model_type: str = None):
         if model_type is None:
             model_type = settings.MODEL_TYPE
 
-        if model_type not in [ModelType.BAYESIAN_RIDGE, ModelType.XGBOOST, ModelType.LIGHTGBM]:
-            raise ValueError(f"Invalid model_type: {model_type}. Must be one of {list(ModelType)}")
-        
-        if model_type == ModelType.XGBOOST and not XGBOOST_AVAILABLE:
+        if model_type not in [e.value for e in ModelType]:
+            raise ValueError(f"Invalid model_type: {model_type}. Must be one of {[e.value for e in ModelType]}")
+
+        if model_type == ModelType.XGBOOST.value and not XGBOOST_AVAILABLE:
             logging.warning("XGBoost requested but not available. Falling back to Bayesian Ridge.")
-            model_type = ModelType.BAYESIAN_RIDGE
+            model_type = ModelType.BAYESIAN_RIDGE.value
 
-        if model_type == ModelType.LIGHTGBM and not LIGHTGBM_AVAILABLE:
+        if model_type == ModelType.LIGHTGBM.value and not LIGHTGBM_AVAILABLE:
             logging.warning("LightGBM requested but not available. Falling back to Bayesian Ridge.")
-            model_type = ModelType.BAYESIAN_RIDGE
+            model_type = ModelType.BAYESIAN_RIDGE.value
 
         self.model_type = ModelType(model_type)
         self.quantile = settings.QUANTILE_ALPHA
@@ -395,8 +406,11 @@ def is_ready(self) -> bool:
         """Checks if all models and scalers are loaded/trained."""
         if self.model_type == ModelType.BAYESIAN_RIDGE:
             return all([self.ttft_model, self.tpot_model, self.ttft_scaler, self.tpot_scaler])
-        else:  # XGBoost or LightGBM
+        elif self.model_type in (ModelType.XGBOOST, ModelType.LIGHTGBM):
             return all([self.ttft_model, self.tpot_model])
+        else:
+            # TREELITE is not a valid training model type
+            raise ValueError(f"Invalid model_type: {self.model_type}. Use XGBOOST, LIGHTGBM, or BAYESIAN_RIDGE.")
 
     @is_ready.setter
     def is_ready(self, value: bool):
@@ -928,18 +942,32 @@ def _save_models_unlocked(self):
             os.makedirs(os.path.dirname(settings.TTFT_MODEL_PATH), exist_ok=True)
             joblib.dump(self.ttft_model, settings.TTFT_MODEL_PATH)
             logging.info("TTFT model saved.")
-            
+
             # Save model-specific exports
             if self.model_type == ModelType.XGBOOST:
                 try:
                     booster = self.ttft_model.get_booster()
                     raw_trees = booster.get_dump(dump_format="json")
                     trees = [json.loads(t) for t in raw_trees]
-                    
+
                     ttft_json_path = settings.TTFT_MODEL_PATH.replace('.joblib', '_trees.json')
                     with open(ttft_json_path, 'w') as f:
                         json.dump(trees, f, indent=2)
                     logging.info(f"TTFT XGBoost trees saved to {ttft_json_path}")
+
+                    # Export to TreeLite for production inference
+                    if TREELITE_AVAILABLE:
+                        try:
+                            tl_model = treelite.frontend.from_xgboost(booster)
+                            tl_model.export_lib(
+                                toolchain='gcc',
+                                libpath=settings.TTFT_TREELITE_PATH,
+                                params={'parallel_comp': 8},
+                                verbose=False
+                            )
+                            logging.info(f"TTFT TreeLite model exported to {settings.TTFT_TREELITE_PATH}")
+                        except Exception as e:
+                            logging.error(f"Error exporting TTFT to TreeLite: {e}", exc_info=True)
                 except Exception as e:
                     logging.error(f"Error saving TTFT XGBoost trees: {e}", exc_info=True)
@@ -948,18 +976,32 @@ def _save_models_unlocked(self):
                     # Save LightGBM model as text format
                     ttft_txt_path = settings.TTFT_MODEL_PATH.replace('.joblib', '_lgb.txt')
                     self.ttft_model.booster_.save_model(ttft_txt_path)
-                    
+
                     # Save feature importances as JSON
-                    feature_names = ['kv_cache_percentage', 'input_token_length', 
+                    feature_names = ['kv_cache_percentage', 'input_token_length',
                                      'num_request_waiting', 'num_request_running', 'prefix_cache_score',
                                      'effective_input_tokens', 'prefill_score_bucket']
                     importances = dict(zip(feature_names, self.ttft_model.feature_importances_))
-                    
+
                     ttft_imp_path = settings.TTFT_MODEL_PATH.replace('.joblib', '_importances.json')
                     with open(ttft_imp_path, 'w') as f:
                         json.dump(importances, f, indent=2)
-                    
+
                     logging.info(f"TTFT LightGBM model saved to {ttft_txt_path}")
                     logging.info(f"TTFT LightGBM importances saved to {ttft_imp_path}")
+
+                    # Export to TreeLite for production inference
+                    if TREELITE_AVAILABLE:
+                        try:
+                            tl_model = treelite.frontend.from_lightgbm(self.ttft_model.booster_)
+                            tl_model.export_lib(
+                                toolchain='gcc',
+                                libpath=settings.TTFT_TREELITE_PATH,
+                                params={'parallel_comp': 8},
+                                verbose=False
+                            )
+                            logging.info(f"TTFT TreeLite model exported to {settings.TTFT_TREELITE_PATH}")
+                        except Exception as e:
+                            logging.error(f"Error exporting TTFT to TreeLite: {e}", exc_info=True)
                 except Exception as e:
                     logging.error(f"Error saving TTFT LightGBM exports: {e}", exc_info=True)
@@ -972,18 +1014,32 @@ def _save_models_unlocked(self):
             os.makedirs(os.path.dirname(settings.TPOT_MODEL_PATH), exist_ok=True)
             joblib.dump(self.tpot_model, settings.TPOT_MODEL_PATH)
             logging.info("TPOT model saved.")
-            
+
             # Save model-specific exports
             if self.model_type == ModelType.XGBOOST:
                 try:
                     booster = self.tpot_model.get_booster()
                     raw_trees = booster.get_dump(dump_format="json")
                     trees = [json.loads(t) for t in raw_trees]
-                    
+
                     tpot_json_path = settings.TPOT_MODEL_PATH.replace('.joblib', '_trees.json')
                     with open(tpot_json_path, 'w') as f:
                         json.dump(trees, f, indent=2)
                     logging.info(f"TPOT XGBoost trees saved to {tpot_json_path}")
+
+                    # Export to TreeLite for production inference
+                    if TREELITE_AVAILABLE:
+                        try:
+                            tl_model = treelite.frontend.from_xgboost(booster)
+                            tl_model.export_lib(
+                                toolchain='gcc',
+                                libpath=settings.TPOT_TREELITE_PATH,
+                                params={'parallel_comp': 8},
+                                verbose=False
+                            )
+                            logging.info(f"TPOT TreeLite model exported to {settings.TPOT_TREELITE_PATH}")
+                        except Exception as e:
+                            logging.error(f"Error exporting TPOT to TreeLite: {e}", exc_info=True)
                 except Exception as e:
                     logging.error(f"Error saving TPOT XGBoost trees: {e}", exc_info=True)
@@ -992,18 +1048,32 @@ def _save_models_unlocked(self):
                     # Save LightGBM model as text format
                     tpot_txt_path = settings.TPOT_MODEL_PATH.replace('.joblib', '_lgb.txt')
                     self.tpot_model.booster_.save_model(tpot_txt_path)
-                    
+
                     # Save feature importances as JSON
-                    feature_names = ['kv_cache_percentage', 'input_token_length', 
+                    feature_names = ['kv_cache_percentage', 'input_token_length',
                                      'num_request_waiting', 'num_request_running', 'num_tokens_generated']
                     importances = dict(zip(feature_names, self.tpot_model.feature_importances_))
-                    
+
                     tpot_imp_path = settings.TPOT_MODEL_PATH.replace('.joblib', '_importances.json')
                     with open(tpot_imp_path, 'w') as f:
                         json.dump(importances, f, indent=2)
-                    
+
                     logging.info(f"TPOT LightGBM model saved to {tpot_txt_path}")
                     logging.info(f"TPOT LightGBM importances saved to {tpot_imp_path}")
+
+                    # Export to TreeLite for production inference
+                    if TREELITE_AVAILABLE:
+                        try:
+                            tl_model = treelite.frontend.from_lightgbm(self.tpot_model.booster_)
+                            tl_model.export_lib(
+                                toolchain='gcc',
+                                libpath=settings.TPOT_TREELITE_PATH,
+                                params={'parallel_comp': 8},
+                                verbose=False
+                            )
+                            logging.info(f"TPOT TreeLite model exported to {settings.TPOT_TREELITE_PATH}")
+                        except Exception as e:
+                            logging.error(f"Error exporting TPOT to TreeLite: {e}", exc_info=True)
                 except Exception as e:
                     logging.error(f"Error saving TPOT LightGBM exports: {e}", exc_info=True)
@@ -1630,7 +1700,9 @@ async def model_info(model_name: str):
         "ttft": settings.TTFT_MODEL_PATH,
         "tpot": settings.TPOT_MODEL_PATH,
         "ttft_scaler": settings.TTFT_SCALER_PATH,
-        "tpot_scaler": settings.TPOT_SCALER_PATH
+        "tpot_scaler": settings.TPOT_SCALER_PATH,
+        "ttft_treelite": settings.TTFT_TREELITE_PATH,
+        "tpot_treelite": settings.TPOT_TREELITE_PATH,
     }
 
     if model_name not in model_paths:
@@ -1663,22 +1735,30 @@ async def download_model(model_name: str):
         "ttft": settings.TTFT_MODEL_PATH,
         "tpot": settings.TPOT_MODEL_PATH,
         "ttft_scaler": settings.TTFT_SCALER_PATH,
-        "tpot_scaler": settings.TPOT_SCALER_PATH
+        "tpot_scaler": settings.TPOT_SCALER_PATH,
+        "ttft_treelite": settings.TTFT_TREELITE_PATH,
+        "tpot_treelite": settings.TPOT_TREELITE_PATH,
     }
-    
+
     if model_name not in model_paths:
         raise HTTPException(status_code=404, detail=f"Unknown model: {model_name}")
-    
+
     model_path = model_paths[model_name]
-    
+
     if not os.path.exists(model_path):
         raise HTTPException(status_code=404, detail=f"Model {model_name} not found")
-    
+
     # Return the file
-    filename = f"{model_name}.joblib"
+    if model_name.endswith('_treelite'):
+        filename = f"{model_name}.so"
+        media_type = 'application/octet-stream'
+    else:
+        filename = f"{model_name}.joblib"
+        media_type = 'application/octet-stream'
+
     return FileResponse(
         model_path,
-        media_type='application/octet-stream',
+        media_type=media_type,
         filename=filename
     )
@@ -1863,4 +1943,4 @@ async def prefix_distribution():
     }
 
 if __name__ == "__main__":
-    uvicorn.run("__main__:app", host="0.0.0.0", port=8000, reload=True)
\ No newline at end of file
+    uvicorn.run("__main__:app", host="0.0.0.0", port=8000, reload=True)