-
Notifications
You must be signed in to change notification settings - Fork 26
fix: critical issues - resource leak, security vulnerability, and tok… #6
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -3,7 +3,6 @@ | |||||||||||||||
| # Toolify: Empower any LLM with function calling capabilities. | ||||||||||||||||
| # Copyright (C) 2025 FunnyCups (https://github.com/funnycups) | ||||||||||||||||
|
|
||||||||||||||||
| import os | ||||||||||||||||
| import re | ||||||||||||||||
| import json | ||||||||||||||||
| import uuid | ||||||||||||||||
|
|
@@ -18,6 +17,7 @@ | |||||||||||||||
| import tiktoken | ||||||||||||||||
| from typing import List, Dict, Any, Optional, Literal, Union | ||||||||||||||||
| from collections import OrderedDict | ||||||||||||||||
| from contextlib import asynccontextmanager | ||||||||||||||||
|
|
||||||||||||||||
| from fastapi import FastAPI, Request, Header, HTTPException, Depends | ||||||||||||||||
| from fastapi.responses import JSONResponse, StreamingResponse | ||||||||||||||||
|
|
@@ -97,11 +97,12 @@ def count_tokens(self, messages: list, model: str = "gpt-3.5-turbo") -> int: | |||||||||||||||
|
|
||||||||||||||||
| def _count_chat_tokens(self, messages: list, encoder, model: str) -> int: | ||||||||||||||||
| """Accurate token calculation for chat models | ||||||||||||||||
|
|
||||||||||||||||
| Based on OpenAI's token counting documentation: | ||||||||||||||||
| - Each message has a fixed overhead | ||||||||||||||||
| - Content tokens are counted per message | ||||||||||||||||
| - Special tokens for message formatting | ||||||||||||||||
| - Tool calls are counted as structured JSON | ||||||||||||||||
| """ | ||||||||||||||||
| # Token overhead varies by model | ||||||||||||||||
| if model.startswith(("gpt-3.5-turbo", "gpt-35-turbo")): | ||||||||||||||||
|
|
@@ -112,11 +113,11 @@ def _count_chat_tokens(self, messages: list, encoder, model: str) -> int: | |||||||||||||||
| # Most models including gpt-4, gpt-4o, o1, etc. | ||||||||||||||||
| tokens_per_message = 3 | ||||||||||||||||
| tokens_per_name = 1 | ||||||||||||||||
|
|
||||||||||||||||
| num_tokens = 0 | ||||||||||||||||
| for message in messages: | ||||||||||||||||
| num_tokens += tokens_per_message | ||||||||||||||||
|
|
||||||||||||||||
| # Count tokens for each field in the message | ||||||||||||||||
| for key, value in message.items(): | ||||||||||||||||
| if key == "content": | ||||||||||||||||
|
|
@@ -136,10 +137,25 @@ def _count_chat_tokens(self, messages: list, encoder, model: str) -> int: | |||||||||||||||
| elif key == "role": | ||||||||||||||||
| # Role is already counted in tokens_per_message | ||||||||||||||||
| pass | ||||||||||||||||
| elif key == "tool_calls": | ||||||||||||||||
| # Count tokens for tool calls structure | ||||||||||||||||
| # Tool calls are serialized as JSON in the model's context | ||||||||||||||||
| if isinstance(value, list): | ||||||||||||||||
| # Convert tool_calls to JSON to count actual tokens | ||||||||||||||||
| tool_calls_json = json.dumps(value, separators=(',', ':')) | ||||||||||||||||
| num_tokens += len(encoder.encode(tool_calls_json)) | ||||||||||||||||
| # Add overhead for tool call formatting (approximately 5 tokens per tool call) | ||||||||||||||||
| num_tokens += len(value) * 5 | ||||||||||||||||
| elif key == "tool_call_id": | ||||||||||||||||
| # Count tokens for tool call ID | ||||||||||||||||
| if isinstance(value, str): | ||||||||||||||||
| num_tokens += len(encoder.encode(value)) | ||||||||||||||||
| # Add overhead for tool response formatting | ||||||||||||||||
| num_tokens += 3 | ||||||||||||||||
| elif isinstance(value, str): | ||||||||||||||||
| # Other string fields | ||||||||||||||||
| num_tokens += len(encoder.encode(value)) | ||||||||||||||||
|
|
||||||||||||||||
| # Every reply is primed with assistant role | ||||||||||||||||
| num_tokens += 3 | ||||||||||||||||
| return num_tokens | ||||||||||||||||
|
|
@@ -801,9 +817,14 @@ def parse_function_calls_xml(xml_string: str, trigger_signal: str) -> Optional[L | |||||||||||||||
| def _coerce_value(v: str): | ||||||||||||||||
| try: | ||||||||||||||||
| return json.loads(v) | ||||||||||||||||
| except Exception: | ||||||||||||||||
| pass | ||||||||||||||||
| return v | ||||||||||||||||
| except json.JSONDecodeError: | ||||||||||||||||
| # Not a JSON value, treat as string (this is expected for many parameters) | ||||||||||||||||
| logger.debug(f"🔧 Parameter value is not JSON, treating as string: {v[:50]}{'...' if len(v) > 50 else ''}") | ||||||||||||||||
| return v | ||||||||||||||||
| except Exception as e: | ||||||||||||||||
| # Unexpected error during parsing | ||||||||||||||||
| logger.warning(f"⚠️ Unexpected error parsing parameter value: {e}") | ||||||||||||||||
| return v | ||||||||||||||||
|
|
||||||||||||||||
| for k, v in arg_matches: | ||||||||||||||||
| args[k] = _coerce_value(v) | ||||||||||||||||
|
|
@@ -861,8 +882,22 @@ def find_upstream(model_name: str) -> tuple[Dict[str, Any], str]: | |||||||||||||||
|
|
||||||||||||||||
| return service, actual_model_name | ||||||||||||||||
|
|
||||||||||||||||
| app = FastAPI() | ||||||||||||||||
| http_client = httpx.AsyncClient() | ||||||||||||||||
| # Global HTTP client (will be initialized in lifespan) | ||||||||||||||||
| http_client = None | ||||||||||||||||
|
|
||||||||||||||||
| @asynccontextmanager | ||||||||||||||||
| async def lifespan(app: FastAPI): | ||||||||||||||||
| """Manage application lifespan - startup and shutdown""" | ||||||||||||||||
| global http_client | ||||||||||||||||
| # Startup | ||||||||||||||||
| http_client = httpx.AsyncClient() | ||||||||||||||||
| logger.info("✅ HTTP client initialized") | ||||||||||||||||
| yield | ||||||||||||||||
| # Shutdown | ||||||||||||||||
| await http_client.aclose() | ||||||||||||||||
| logger.info("✅ HTTP client closed") | ||||||||||||||||
|
|
||||||||||||||||
| app = FastAPI(lifespan=lifespan) | ||||||||||||||||
|
|
||||||||||||||||
| @app.middleware("http") | ||||||||||||||||
| async def debug_middleware(request: Request, call_next): | ||||||||||||||||
|
|
@@ -919,12 +954,24 @@ async def general_exception_handler(request: Request, exc: Exception): | |||||||||||||||
|
|
||||||||||||||||
| async def verify_api_key(authorization: str = Header(...)): | ||||||||||||||||
| """Dependency: verify client API key""" | ||||||||||||||||
| client_key = authorization.replace("Bearer ", "") | ||||||||||||||||
| # Check authorization header format (case-insensitive) | ||||||||||||||||
| if not authorization.lower().startswith("bearer "): | ||||||||||||||||
| raise HTTPException(status_code=401, detail="Invalid authorization header format") | ||||||||||||||||
|
|
||||||||||||||||
| # Extract the token (everything after "Bearer ") | ||||||||||||||||
| client_key = authorization[7:].strip() | ||||||||||||||||
|
|
||||||||||||||||
| # Ensure the key is not empty | ||||||||||||||||
| if not client_key: | ||||||||||||||||
| raise HTTPException(status_code=401, detail="Missing API key") | ||||||||||||||||
|
|
||||||||||||||||
| if app_config.features.key_passthrough: | ||||||||||||||||
| # In passthrough mode, skip allowed_keys check | ||||||||||||||||
| return client_key | ||||||||||||||||
|
|
||||||||||||||||
| if client_key not in ALLOWED_CLIENT_KEYS: | ||||||||||||||||
| raise HTTPException(status_code=401, detail="Unauthorized") | ||||||||||||||||
|
|
||||||||||||||||
| return client_key | ||||||||||||||||
|
|
||||||||||||||||
| def preprocess_messages(messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]: | ||||||||||||||||
|
|
@@ -991,42 +1038,38 @@ async def chat_completions( | |||||||||||||||
| ): | ||||||||||||||||
| """Main chat completion endpoint, proxy and inject function calling capabilities.""" | ||||||||||||||||
| start_time = time.time() | ||||||||||||||||
|
|
||||||||||||||||
| # Count input tokens | ||||||||||||||||
| prompt_tokens = token_counter.count_tokens(body.messages, body.model) | ||||||||||||||||
| logger.info(f"📊 Request to {body.model} - Input tokens: {prompt_tokens}") | ||||||||||||||||
|
|
||||||||||||||||
|
|
||||||||||||||||
| try: | ||||||||||||||||
| logger.debug(f"🔧 Received request, model: {body.model}") | ||||||||||||||||
| logger.debug(f"🔧 Number of messages: {len(body.messages)}") | ||||||||||||||||
| logger.debug(f"🔧 Number of tools: {len(body.tools) if body.tools else 0}") | ||||||||||||||||
| logger.debug(f"🔧 Streaming: {body.stream}") | ||||||||||||||||
|
|
||||||||||||||||
| upstream, actual_model = find_upstream(body.model) | ||||||||||||||||
| upstream_url = f"{upstream['base_url']}/chat/completions" | ||||||||||||||||
|
|
||||||||||||||||
| logger.debug(f"🔧 Starting message preprocessing, original message count: {len(body.messages)}") | ||||||||||||||||
| processed_messages = preprocess_messages(body.messages) | ||||||||||||||||
| logger.debug(f"🔧 Preprocessing completed, processed message count: {len(processed_messages)}") | ||||||||||||||||
|
|
||||||||||||||||
| if not validate_message_structure(processed_messages): | ||||||||||||||||
| logger.error(f"❌ Message structure validation failed, but continuing processing") | ||||||||||||||||
|
|
||||||||||||||||
| request_body_dict = body.model_dump(exclude_unset=True) | ||||||||||||||||
| request_body_dict["model"] = actual_model | ||||||||||||||||
| request_body_dict["messages"] = processed_messages | ||||||||||||||||
| is_fc_enabled = app_config.features.enable_function_calling | ||||||||||||||||
| has_tools_in_request = bool(body.tools) | ||||||||||||||||
| has_function_call = is_fc_enabled and has_tools_in_request | ||||||||||||||||
|
|
||||||||||||||||
| logger.debug(f"🔧 Request body constructed, message count: {len(processed_messages)}") | ||||||||||||||||
|
|
||||||||||||||||
| except Exception as e: | ||||||||||||||||
| logger.error(f"❌ Request preprocessing failed: {str(e)}") | ||||||||||||||||
| logger.error(f"❌ Error type: {type(e).__name__}") | ||||||||||||||||
| if hasattr(app_config, 'debug') and app_config.debug: | ||||||||||||||||
| logger.error(f"❌ Error stack: {traceback.format_exc()}") | ||||||||||||||||
|
|
||||||||||||||||
| return JSONResponse( | ||||||||||||||||
| status_code=422, | ||||||||||||||||
| content={ | ||||||||||||||||
|
|
@@ -1040,9 +1083,9 @@ async def chat_completions( | |||||||||||||||
|
|
||||||||||||||||
| if has_function_call: | ||||||||||||||||
| logger.debug(f"🔧 Using global trigger signal for this request: {GLOBAL_TRIGGER_SIGNAL}") | ||||||||||||||||
|
|
||||||||||||||||
| function_prompt, _ = generate_function_prompt(body.tools, GLOBAL_TRIGGER_SIGNAL) | ||||||||||||||||
|
|
||||||||||||||||
| tool_choice_prompt = safe_process_tool_choice(body.tool_choice) | ||||||||||||||||
| if tool_choice_prompt: | ||||||||||||||||
| function_prompt += tool_choice_prompt | ||||||||||||||||
|
|
@@ -1062,6 +1105,11 @@ async def chat_completions( | |||||||||||||||
| if "tool_choice" in request_body_dict: | ||||||||||||||||
| del request_body_dict["tool_choice"] | ||||||||||||||||
|
|
||||||||||||||||
| # Count input tokens after preprocessing and system prompt injection | ||||||||||||||||
| # This ensures we count the actual messages being sent to upstream | ||||||||||||||||
| prompt_tokens = token_counter.count_tokens(request_body_dict["messages"], body.model) | ||||||||||||||||
| logger.info(f"📊 Request to {body.model} - Input tokens: {prompt_tokens}") | ||||||||||||||||
|
|
||||||||||||||||
| headers = { | ||||||||||||||||
| "Content-Type": "application/json", | ||||||||||||||||
| "Authorization": f"Bearer {_api_key}" if app_config.features.key_passthrough else f"Bearer {upstream['api_key']}", | ||||||||||||||||
|
|
@@ -1087,13 +1135,32 @@ async def chat_completions( | |||||||||||||||
|
|
||||||||||||||||
| # Count output tokens and handle usage | ||||||||||||||||
| completion_text = "" | ||||||||||||||||
| completion_tool_calls = [] | ||||||||||||||||
| if response_json.get("choices") and len(response_json["choices"]) > 0: | ||||||||||||||||
| content = response_json["choices"][0].get("message", {}).get("content") | ||||||||||||||||
| choice_message = response_json["choices"][0].get("message", {}) | ||||||||||||||||
| content = choice_message.get("content") | ||||||||||||||||
| if content: | ||||||||||||||||
| completion_text = content | ||||||||||||||||
|
|
||||||||||||||||
| # Also check for tool_calls in the response | ||||||||||||||||
| if "tool_calls" in choice_message: | ||||||||||||||||
| completion_tool_calls = choice_message.get("tool_calls", []) | ||||||||||||||||
|
|
||||||||||||||||
| # Calculate our estimated tokens | ||||||||||||||||
| estimated_completion_tokens = token_counter.count_text_tokens(completion_text, body.model) if completion_text else 0 | ||||||||||||||||
| # Count content tokens | ||||||||||||||||
| content_tokens = token_counter.count_text_tokens(completion_text, body.model) if completion_text else 0 | ||||||||||||||||
|
|
||||||||||||||||
| # Count tool_calls tokens if present | ||||||||||||||||
| tool_calls_tokens = 0 | ||||||||||||||||
| if completion_tool_calls: | ||||||||||||||||
| # Create a temporary message to count tool_calls tokens | ||||||||||||||||
| temp_msg = {"role": "assistant", "tool_calls": completion_tool_calls} | ||||||||||||||||
| # Count just the tool_calls part | ||||||||||||||||
| tool_calls_tokens = token_counter.count_tokens([temp_msg], body.model) | ||||||||||||||||
| # Subtract the message overhead to get just tool_calls tokens | ||||||||||||||||
| tool_calls_tokens -= 3 # Remove assistant priming overhead | ||||||||||||||||
| logger.debug(f"🔧 Counted {tool_calls_tokens} tokens for {len(completion_tool_calls)} tool calls") | ||||||||||||||||
|
|
||||||||||||||||
| estimated_completion_tokens = content_tokens + tool_calls_tokens | ||||||||||||||||
|
Comment on lines
+1149
to
+1163
|
||||||||||||||||
| estimated_prompt_tokens = prompt_tokens | ||||||||||||||||
| estimated_total_tokens = estimated_prompt_tokens + estimated_completion_tokens | ||||||||||||||||
| elapsed_time = time.time() - start_time | ||||||||||||||||
|
|
@@ -1239,9 +1306,10 @@ async def chat_completions( | |||||||||||||||
| async def stream_with_token_count(): | ||||||||||||||||
| completion_tokens = 0 | ||||||||||||||||
| completion_text = "" | ||||||||||||||||
| completion_tool_calls = [] # Track streamed tool calls | ||||||||||||||||
| done_received = False | ||||||||||||||||
| upstream_usage_chunk = None # Store upstream usage chunk if any | ||||||||||||||||
|
|
||||||||||||||||
| async for chunk in stream_proxy_with_fc_transform(upstream_url, request_body_dict, headers, body.model, has_function_call, GLOBAL_TRIGGER_SIGNAL): | ||||||||||||||||
| # Check if this is the [DONE] marker | ||||||||||||||||
| if chunk.startswith(b"data: "): | ||||||||||||||||
|
|
@@ -1253,7 +1321,7 @@ async def stream_with_token_count(): | |||||||||||||||
| break | ||||||||||||||||
| elif line_data: | ||||||||||||||||
| chunk_json = json.loads(line_data) | ||||||||||||||||
|
|
||||||||||||||||
| # Check if this chunk contains usage information | ||||||||||||||||
| if "usage" in chunk_json: | ||||||||||||||||
| upstream_usage_chunk = chunk_json | ||||||||||||||||
|
|
@@ -1262,21 +1330,44 @@ async def stream_with_token_count(): | |||||||||||||||
| if not ("choices" in chunk_json and len(chunk_json["choices"]) > 0): | ||||||||||||||||
| # Don't yield upstream usage-only chunk now; we'll handle usage later | ||||||||||||||||
| continue | ||||||||||||||||
|
|
||||||||||||||||
| # Process regular content chunks | ||||||||||||||||
| if "choices" in chunk_json and len(chunk_json["choices"]) > 0: | ||||||||||||||||
| delta = chunk_json["choices"][0].get("delta", {}) | ||||||||||||||||
| content = delta.get("content", "") | ||||||||||||||||
| if content: | ||||||||||||||||
| completion_text += content | ||||||||||||||||
|
|
||||||||||||||||
| # Also check for tool_calls in the delta | ||||||||||||||||
| if "tool_calls" in delta: | ||||||||||||||||
| tool_calls_in_delta = delta.get("tool_calls", []) | ||||||||||||||||
| if tool_calls_in_delta: | ||||||||||||||||
| # Tool calls might be streamed incrementally | ||||||||||||||||
| # We'll store them for token counting | ||||||||||||||||
| completion_tool_calls.extend(tool_calls_in_delta) | ||||||||||||||||
| logger.debug(f"🔧 Detected {len(tool_calls_in_delta)} tool calls in stream chunk") | ||||||||||||||||
| except (json.JSONDecodeError, KeyError, UnicodeDecodeError) as e: | ||||||||||||||||
| logger.debug(f"Failed to parse chunk for token counting: {e}") | ||||||||||||||||
| pass | ||||||||||||||||
|
|
||||||||||||||||
| yield chunk | ||||||||||||||||
|
|
||||||||||||||||
| # Calculate our estimated tokens | ||||||||||||||||
| estimated_completion_tokens = token_counter.count_text_tokens(completion_text, body.model) if completion_text else 0 | ||||||||||||||||
| # Count content tokens | ||||||||||||||||
| content_tokens = token_counter.count_text_tokens(completion_text, body.model) if completion_text else 0 | ||||||||||||||||
|
|
||||||||||||||||
| # Count tool_calls tokens if present | ||||||||||||||||
| tool_calls_tokens = 0 | ||||||||||||||||
| if completion_tool_calls: | ||||||||||||||||
| # Create a temporary message to count tool_calls tokens | ||||||||||||||||
| temp_msg = {"role": "assistant", "tool_calls": completion_tool_calls} | ||||||||||||||||
| # Count just the tool_calls part | ||||||||||||||||
| tool_calls_tokens = token_counter.count_tokens([temp_msg], body.model) | ||||||||||||||||
| # Subtract the message overhead to get just tool_calls tokens | ||||||||||||||||
| tool_calls_tokens -= 3 # Remove assistant priming overhead | ||||||||||||||||
| logger.debug(f"🔧 Counted {tool_calls_tokens} tokens for {len(completion_tool_calls)} tool calls in stream") | ||||||||||||||||
|
Comment on lines
+1367
to
+1368
|
||||||||||||||||
| tool_calls_tokens -= 3 # Remove assistant priming overhead | |
| logger.debug(f"🔧 Counted {tool_calls_tokens} tokens for {len(completion_tool_calls)} tool calls in stream") | |
| # Subtract both per-message overhead and reply priming (+3) | |
| # Most models: per-message overhead is 3, gpt-3.5-turbo/gpt-35-turbo: 4 | |
| per_message_overhead = 4 if body.model in ("gpt-3.5-turbo", "gpt-35-turbo") else 3 | |
| tool_calls_tokens -= (per_message_overhead + 3) | |
| logger.debug(f"🔧 Counted {tool_calls_tokens} tokens for {len(completion_tool_calls)} tool calls in stream (subtracted {per_message_overhead}+3 overhead)") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Parsing Authorization by slicing index 7 assumes a single space and no extra whitespace. Use robust parsing to handle multiple spaces/tabs: scheme, _, token = authorization.partition(' '); if scheme.lower() != 'bearer' or not token.strip(): raise HTTPException(...). This avoids edge cases and keeps the check case-insensitive.