Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion latencypredictor/manifests/dual-server-deployment.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ data:
LATENCY_TPOT_MODEL_PATH: "/models/tpot.joblib"
LATENCY_TTFT_SCALER_PATH: "/models/ttft_scaler.joblib"
LATENCY_TPOT_SCALER_PATH: "/models/tpot_scaler.joblib"
LATENCY_TTFT_TREELITE_PATH: "/models/ttft_treelite.so"
LATENCY_TPOT_TREELITE_PATH: "/models/tpot_treelite.so"
LATENCY_MODEL_TYPE: "xgboost"

---
Expand All @@ -22,7 +24,7 @@ metadata:
name: prediction-server-config
namespace: default
data:
MODEL_SYNC_INTERVAL_SEC: "10" # Download models every 5 seconds
MODEL_SYNC_INTERVAL_SEC: "10" # Download models every 10 seconds
LATENCY_MODEL_TYPE: "xgboost"
PREDICT_HOST: "0.0.0.0"
PREDICT_PORT: "8001"
Expand All @@ -31,6 +33,9 @@ data:
LOCAL_TPOT_MODEL_PATH: "/local_models/tpot.joblib"
LOCAL_TTFT_SCALER_PATH: "/local_models/ttft_scaler.joblib"
LOCAL_TPOT_SCALER_PATH: "/local_models/tpot_scaler.joblib"
LOCAL_TTFT_TREELITE_PATH: "/local_models/ttft_treelite.so"
LOCAL_TPOT_TREELITE_PATH: "/local_models/tpot_treelite.so"
USE_TREELITE: "true" # Enable TreeLite for faster inference
HTTP_TIMEOUT: "30"

---
Expand Down
169 changes: 123 additions & 46 deletions latencypredictor/prediction_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,12 +42,19 @@
except ImportError:
LIGHTGBM_AVAILABLE = False
logging.warning("LightGBM not available. Install with: pip install lightgbm")

try:
import treelite_runtime
TREELITE_AVAILABLE = True
except ImportError:
TREELITE_AVAILABLE = False
logging.warning("TreeLite runtime not available. Install with: pip install treelite_runtime")

class ModelType(str, Enum):
BAYESIAN_RIDGE = "bayesian_ridge"
XGBOOST = "xgboost"
LIGHTGBM = "lightgbm"

TREELITE = "treelite"

class PredictSettings:
"""Configuration for the prediction server."""
Expand All @@ -60,6 +67,11 @@ class PredictSettings:
LOCAL_TPOT_MODEL_PATH: str = os.getenv("LOCAL_TPOT_MODEL_PATH", "/local_models/tpot.joblib")
LOCAL_TTFT_SCALER_PATH: str = os.getenv("LOCAL_TTFT_SCALER_PATH", "/local_models/ttft_scaler.joblib")
LOCAL_TPOT_SCALER_PATH: str = os.getenv("LOCAL_TPOT_SCALER_PATH", "/local_models/tpot_scaler.joblib")
LOCAL_TTFT_TREELITE_PATH: str = os.getenv("LOCAL_TTFT_TREELITE_PATH", "/local_models/ttft_treelite.so")
LOCAL_TPOT_TREELITE_PATH: str = os.getenv("LOCAL_TPOT_TREELITE_PATH", "/local_models/tpot_treelite.so")

# Use TreeLite for inference (preferred for production)
USE_TREELITE: bool = os.getenv("USE_TREELITE", "true").lower() == "true"

# Sync interval and model type
MODEL_SYNC_INTERVAL_SEC: int = int(os.getenv("MODEL_SYNC_INTERVAL_SEC", "10"))
Expand Down Expand Up @@ -94,6 +106,8 @@ def __init__(self):
settings.LOCAL_TPOT_MODEL_PATH,
settings.LOCAL_TTFT_SCALER_PATH,
settings.LOCAL_TPOT_SCALER_PATH,
settings.LOCAL_TTFT_TREELITE_PATH,
settings.LOCAL_TPOT_TREELITE_PATH,
]:
os.makedirs(os.path.dirname(path), exist_ok=True)

Expand Down Expand Up @@ -150,11 +164,21 @@ def sync_models(self) -> bool:
("ttft", settings.LOCAL_TTFT_MODEL_PATH),
("tpot", settings.LOCAL_TPOT_MODEL_PATH),
]

# Sync TreeLite models if enabled
if settings.USE_TREELITE and TREELITE_AVAILABLE:
to_sync += [
("ttft_treelite", settings.LOCAL_TTFT_TREELITE_PATH),
("tpot_treelite", settings.LOCAL_TPOT_TREELITE_PATH),
]

# Sync scalers only for Bayesian Ridge
if settings.MODEL_TYPE == ModelType.BAYESIAN_RIDGE:
to_sync += [
("ttft_scaler", settings.LOCAL_TTFT_SCALER_PATH),
("tpot_scaler", settings.LOCAL_TPOT_SCALER_PATH),
]

for name, path in to_sync:
if self._download_model_if_newer(name, path):
updated = True
Expand Down Expand Up @@ -189,31 +213,42 @@ class LightweightPredictor:
def __init__(self):
mt = settings.MODEL_TYPE
self.prefix_buckets = 4

# Add LightGBM fallback logic
if mt == ModelType.XGBOOST and not XGBOOST_AVAILABLE:
logging.warning("XGBoost not available. Falling back to Bayesian Ridge")
mt = ModelType.BAYESIAN_RIDGE
elif mt == ModelType.LIGHTGBM and not LIGHTGBM_AVAILABLE:
logging.warning("LightGBM not available. Falling back to Bayesian Ridge")
mt = ModelType.BAYESIAN_RIDGE

self.model_type = mt
self.quantile = settings.QUANTILE_ALPHA
self.use_treelite = settings.USE_TREELITE and TREELITE_AVAILABLE and mt in [ModelType.XGBOOST, ModelType.LIGHTGBM]

# Model storage
self.ttft_model = None
self.tpot_model = None
self.ttft_scaler = None
self.tpot_scaler = None

# TreeLite predictors (lightweight, compiled models)
self.ttft_predictor = None
self.tpot_predictor = None

self.lock = threading.RLock()
self.last_load: Optional[datetime] = None
logging.info(f"Predictor type: {self.model_type}, quantile: {self.quantile}")
logging.info(f"Predictor type: {self.model_type}, quantile: {self.quantile}, use_treelite: {self.use_treelite}")

@property
def is_ready(self) -> bool:
with self.lock:
if self.model_type == ModelType.BAYESIAN_RIDGE:
return all([self.ttft_model, self.tpot_model, self.ttft_scaler, self.tpot_scaler])
else: # XGBoost or LightGBM
elif self.use_treelite:
# For TreeLite, we need the compiled predictors
return all([self.ttft_predictor, self.tpot_predictor])
else: # XGBoost or LightGBM without TreeLite
return all([self.ttft_model, self.tpot_model])

def _prepare_features_with_interaction(self, df: pd.DataFrame, model_type: str) -> pd.DataFrame:
Expand Down Expand Up @@ -268,21 +303,38 @@ def _prepare_features_with_interaction(self, df: pd.DataFrame, model_type: str)
def load_models(self) -> bool:
try:
with self.lock:
new_ttft = joblib.load(settings.LOCAL_TTFT_MODEL_PATH) if os.path.exists(settings.LOCAL_TTFT_MODEL_PATH) else None
new_tpot = joblib.load(settings.LOCAL_TPOT_MODEL_PATH) if os.path.exists(settings.LOCAL_TPOT_MODEL_PATH) else None
if self.model_type == ModelType.BAYESIAN_RIDGE:
new_ttft_scaler = joblib.load(settings.LOCAL_TTFT_SCALER_PATH) if os.path.exists(settings.LOCAL_TTFT_SCALER_PATH) else None
new_tpot_scaler = joblib.load(settings.LOCAL_TPOT_SCALER_PATH) if os.path.exists(settings.LOCAL_TPOT_SCALER_PATH) else None
# Load TreeLite models if enabled
if self.use_treelite:
if os.path.exists(settings.LOCAL_TTFT_TREELITE_PATH):
self.ttft_predictor = treelite_runtime.Predictor(settings.LOCAL_TTFT_TREELITE_PATH, nthread=8)
logging.info("TTFT TreeLite model loaded")
else:
logging.warning(f"TreeLite model not found: {settings.LOCAL_TTFT_TREELITE_PATH}")

if os.path.exists(settings.LOCAL_TPOT_TREELITE_PATH):
self.tpot_predictor = treelite_runtime.Predictor(settings.LOCAL_TPOT_TREELITE_PATH, nthread=8)
logging.info("TPOT TreeLite model loaded")
else:
logging.warning(f"TreeLite model not found: {settings.LOCAL_TPOT_TREELITE_PATH}")
else:
new_ttft_scaler = new_tpot_scaler = None
# Load XGBoost/LightGBM/BayesianRidge models
new_ttft = joblib.load(settings.LOCAL_TTFT_MODEL_PATH) if os.path.exists(settings.LOCAL_TTFT_MODEL_PATH) else None
new_tpot = joblib.load(settings.LOCAL_TPOT_MODEL_PATH) if os.path.exists(settings.LOCAL_TPOT_MODEL_PATH) else None

if self.model_type == ModelType.BAYESIAN_RIDGE:
new_ttft_scaler = joblib.load(settings.LOCAL_TTFT_SCALER_PATH) if os.path.exists(settings.LOCAL_TTFT_SCALER_PATH) else None
new_tpot_scaler = joblib.load(settings.LOCAL_TPOT_SCALER_PATH) if os.path.exists(settings.LOCAL_TPOT_SCALER_PATH) else None
else:
new_ttft_scaler = new_tpot_scaler = None

if new_ttft: self.ttft_model = new_ttft
if new_tpot: self.tpot_model = new_tpot
if new_ttft_scaler: self.ttft_scaler = new_ttft_scaler
if new_tpot_scaler: self.tpot_scaler = new_tpot_scaler

if new_ttft: self.ttft_model = new_ttft
if new_tpot: self.tpot_model = new_tpot
if new_ttft_scaler: self.ttft_scaler = new_ttft_scaler
if new_tpot_scaler: self.tpot_scaler = new_tpot_scaler
self.last_load = datetime.now(timezone.utc)
if self.is_ready:
logging.info("Models loaded")
logging.info(f"Models loaded successfully (TreeLite: {self.use_treelite})")
return True
logging.warning("Models missing after load")
return False
Expand All @@ -296,9 +348,9 @@ def predict(self, features: dict) -> Tuple[float, float]:
with self.lock:
if not self.is_ready:
raise HTTPException(status_code=503, detail="Models not ready")

# Validation
required = ['kv_cache_percentage', 'input_token_length', 'num_request_waiting',
required = ['kv_cache_percentage', 'input_token_length', 'num_request_waiting',
'num_request_running', 'num_tokens_generated', 'prefix_cache_score']
for f in required:
if f not in features:
Expand All @@ -314,26 +366,38 @@ def predict(self, features: dict) -> Tuple[float, float]:
'num_request_running': features['num_request_running'],
'prefix_cache_score': features['prefix_cache_score']
}

tpot_raw_data = {
'kv_cache_percentage': features['kv_cache_percentage'],
'input_token_length': features['input_token_length'],
'num_request_waiting': features['num_request_waiting'],
'num_request_running': features['num_request_running'],
'num_tokens_generated': features['num_tokens_generated']
}

# Prepare features with interactions
df_ttft_raw = pd.DataFrame([ttft_raw_data])
df_ttft = self._prepare_features_with_interaction(df_ttft_raw, "ttft")



df_tpot_raw = pd.DataFrame([tpot_raw_data])
df_tpot = self._prepare_features_with_interaction(df_tpot_raw, "tpot")
#df_tpot = pd.DataFrame([tpot_raw_data])

if self.model_type == ModelType.BAYESIAN_RIDGE:

# Use TreeLite for inference if enabled
if self.use_treelite:
# TreeLite expects numpy arrays
ttft_arr = df_ttft.values.astype('float32')
tpot_arr = df_tpot.values.astype('float32')

# Create DMatrix for TreeLite
ttft_dmat = treelite_runtime.DMatrix(ttft_arr)
tpot_dmat = treelite_runtime.DMatrix(tpot_arr)

ttft_pred = self.ttft_predictor.predict(ttft_dmat)
tpot_pred = self.tpot_predictor.predict(tpot_dmat)

return float(ttft_pred[0]), float(tpot_pred[0])

elif self.model_type == ModelType.BAYESIAN_RIDGE:
ttft_for_scale = df_ttft.drop(columns=['prefill_score_bucket'], errors='ignore')
ttft_scaled = self.ttft_scaler.transform(ttft_for_scale)
tpot_scaled = self.tpot_scaler.transform(df_tpot)
Expand All @@ -344,19 +408,19 @@ def predict(self, features: dict) -> Tuple[float, float]:
std_factor = 1.28 if self.quantile == 0.9 else (2.0 if self.quantile == 0.95 else 0.674)
ttft_pred = ttft_pred_mean[0] + std_factor * ttft_std[0]
tpot_pred = tpot_pred_mean[0] + std_factor * tpot_std[0]

return ttft_pred, tpot_pred

elif self.model_type == ModelType.XGBOOST:
ttft_pred = self.ttft_model.predict(df_ttft)
tpot_pred = self.tpot_model.predict(df_tpot)

return ttft_pred[0], tpot_pred[0]

else: # LightGBM
ttft_pred = self.ttft_model.predict(df_ttft)
tpot_pred = self.tpot_model.predict(df_tpot)

return ttft_pred[0], tpot_pred[0]

except ValueError as ve:
Expand All @@ -374,9 +438,9 @@ def predict_batch(self, features_list: List[dict]) -> Tuple[np.ndarray, np.ndarr
with self.lock:
if not self.is_ready:
raise HTTPException(status_code=503, detail="Models not ready")

# Validation
required = ['kv_cache_percentage', 'input_token_length', 'num_request_waiting',
required = ['kv_cache_percentage', 'input_token_length', 'num_request_waiting',
'num_request_running', 'num_tokens_generated', 'prefix_cache_score']
for i, features in enumerate(features_list):
for f in required:
Expand All @@ -388,7 +452,7 @@ def predict_batch(self, features_list: List[dict]) -> Tuple[np.ndarray, np.ndarr
# Create raw feature data (without interaction)
ttft_raw_data = []
tpot_raw_data = []

for features in features_list:
ttft_raw_data.append({
'kv_cache_percentage': features['kv_cache_percentage'],
Expand All @@ -397,48 +461,61 @@ def predict_batch(self, features_list: List[dict]) -> Tuple[np.ndarray, np.ndarr
'num_request_running': features['num_request_running'],
'prefix_cache_score': features['prefix_cache_score']
})

tpot_raw_data.append({
'kv_cache_percentage': features['kv_cache_percentage'],
'input_token_length': features['input_token_length'],
'num_request_waiting': features['num_request_waiting'],
'num_request_running': features['num_request_running'],
'num_tokens_generated': features['num_tokens_generated']
})

# Prepare features with interactions
df_ttft_raw = pd.DataFrame(ttft_raw_data)
df_ttft_batch = self._prepare_features_with_interaction(df_ttft_raw, "ttft")
#df_ttft_batch = pd.DataFrame(ttft_raw_data)


df_tpot_raw = pd.DataFrame(tpot_raw_data)
df_tpot_batch = self._prepare_features_with_interaction(df_tpot_raw, "tpot")
#df_tpot_batch = pd.DataFrame(tpot_raw_data)

if self.model_type == ModelType.BAYESIAN_RIDGE:
# Use TreeLite for batch inference if enabled
if self.use_treelite:
# TreeLite expects numpy arrays
ttft_arr = df_ttft_batch.values.astype('float32')
tpot_arr = df_tpot_batch.values.astype('float32')

# Create DMatrix for TreeLite
ttft_dmat = treelite_runtime.DMatrix(ttft_arr)
tpot_dmat = treelite_runtime.DMatrix(tpot_arr)

ttft_pred = self.ttft_predictor.predict(ttft_dmat)
tpot_pred = self.tpot_predictor.predict(tpot_dmat)

return ttft_pred, tpot_pred

elif self.model_type == ModelType.BAYESIAN_RIDGE:
ttft_for_scale = df_ttft_batch.drop(columns=['prefill_score_bucket'], errors='ignore')
ttft_scaled = self.ttft_scaler.transform(ttft_for_scale)
tpot_scaled = self.tpot_scaler.transform(df_tpot_batch)

ttft_pred_mean, ttft_std = self.ttft_model.predict(ttft_scaled, return_std=True)
tpot_pred_mean, tpot_std = self.tpot_model.predict(tpot_scaled, return_std=True)

std_factor = 1.28 if self.quantile == 0.9 else (2.0 if self.quantile == 0.95 else 0.674)
ttft_pred = ttft_pred_mean + std_factor * ttft_std
tpot_pred = tpot_pred_mean + std_factor * tpot_std

return ttft_pred, tpot_pred

elif self.model_type == ModelType.XGBOOST:
ttft_pred = self.ttft_model.predict(df_ttft_batch)
tpot_pred = self.tpot_model.predict(df_tpot_batch)

return ttft_pred, tpot_pred

else: # LightGBM
ttft_pred = self.ttft_model.predict(df_ttft_batch)
tpot_pred = self.tpot_model.predict(df_tpot_batch)

return ttft_pred, tpot_pred

except ValueError as ve:
Expand Down Expand Up @@ -773,4 +850,4 @@ async def shutdown():


if __name__ == "__main__":
uvicorn.run("__main__:app", host=settings.HOST, port=settings.PORT, reload=True)
uvicorn.run("__main__:app", host=settings.HOST, port=settings.PORT, reload=True)
2 changes: 2 additions & 0 deletions latencypredictor/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,5 @@ requests
xgboost
aiohttp
lightgbm
treelite
treelite_runtime
Loading