diff --git a/main.py b/main.py
index 53fd2a3..5afb68a 100644
--- a/main.py
+++ b/main.py
@@ -3,7 +3,6 @@
 # Toolify: Empower any LLM with function calling capabilities.
 # Copyright (C) 2025 FunnyCups (https://github.com/funnycups)
 
-import os
 import re
 import json
 import uuid
@@ -18,6 +17,7 @@ import tiktoken
 
 from typing import List, Dict, Any, Optional, Literal, Union
 from collections import OrderedDict
+from contextlib import asynccontextmanager
 
 from fastapi import FastAPI, Request, Header, HTTPException, Depends
 from fastapi.responses import JSONResponse, StreamingResponse
@@ -97,11 +97,12 @@ def count_tokens(self, messages: list, model: str = "gpt-3.5-turbo") -> int:
 
     def _count_chat_tokens(self, messages: list, encoder, model: str) -> int:
         """Accurate token calculation for chat models
-        
+
         Based on OpenAI's token counting documentation:
         - Each message has a fixed overhead
         - Content tokens are counted per message
         - Special tokens for message formatting
+        - Tool calls are counted as structured JSON
         """
         # Token overhead varies by model
         if model.startswith(("gpt-3.5-turbo", "gpt-35-turbo")):
@@ -112,11 +113,11 @@ def _count_chat_tokens(self, messages: list, encoder, model: str) -> int:
             # Most models including gpt-4, gpt-4o, o1, etc.
             tokens_per_message = 3
             tokens_per_name = 1
-        
+
         num_tokens = 0
         for message in messages:
             num_tokens += tokens_per_message
-            
+
             # Count tokens for each field in the message
             for key, value in message.items():
                 if key == "content":
@@ -136,10 +137,25 @@ def _count_chat_tokens(self, messages: list, encoder, model: str) -> int:
                 elif key == "role":
                     # Role is already counted in tokens_per_message
                     pass
+                elif key == "tool_calls":
+                    # Count tokens for tool calls structure
+                    # Tool calls are serialized as JSON in the model's context
+                    if isinstance(value, list):
+                        # Convert tool_calls to JSON to count actual tokens
+                        tool_calls_json = json.dumps(value, separators=(',', ':'))
+                        num_tokens += len(encoder.encode(tool_calls_json))
+                        # Add overhead for tool call formatting (approximately 5 tokens per tool call)
+                        num_tokens += len(value) * 5
+                elif key == "tool_call_id":
+                    # Count tokens for tool call ID
+                    if isinstance(value, str):
+                        num_tokens += len(encoder.encode(value))
+                        # Add overhead for tool response formatting
+                        num_tokens += 3
                 elif isinstance(value, str):
                     # Other string fields
                     num_tokens += len(encoder.encode(value))
-        
+
         # Every reply is primed with assistant role
         num_tokens += 3
         return num_tokens
@@ -801,9 +817,14 @@ def parse_function_calls_xml(xml_string: str, trigger_signal: str) -> Optional[L
     def _coerce_value(v: str):
         try:
             return json.loads(v)
-        except Exception:
-            pass
-        return v
+        except json.JSONDecodeError:
+            # Not a JSON value, treat as string (this is expected for many parameters)
+            logger.debug(f"🔧 Parameter value is not JSON, treating as string: {v[:50]}{'...' if len(v) > 50 else ''}")
+            return v
+        except Exception as e:
+            # Unexpected error during parsing
+            logger.warning(f"⚠️ Unexpected error parsing parameter value: {e}")
+            return v
 
     for k, v in arg_matches:
         args[k] = _coerce_value(v)
@@ -861,8 +882,22 @@ def find_upstream(model_name: str) -> tuple[Dict[str, Any], str]:
 
     return service, actual_model_name
 
-app = FastAPI()
-http_client = httpx.AsyncClient()
+# Global HTTP client (will be initialized in lifespan)
+http_client = None
+
+@asynccontextmanager
+async def lifespan(app: FastAPI):
+    """Manage application lifespan - startup and shutdown"""
+    global http_client
+    # Startup
+    http_client = httpx.AsyncClient()
+    logger.info("✅ HTTP client initialized")
+    yield
+    # Shutdown
+    await http_client.aclose()
+    logger.info("✅ HTTP client closed")
+
+app = FastAPI(lifespan=lifespan)
 
 @app.middleware("http")
 async def debug_middleware(request: Request, call_next):
@@ -919,12 +954,24 @@ async def general_exception_handler(request: Request, exc: Exception):
 
 async def verify_api_key(authorization: str = Header(...)):
     """Dependency: verify client API key"""
-    client_key = authorization.replace("Bearer ", "")
+    # Check authorization header format (case-insensitive)
+    if not authorization.lower().startswith("bearer "):
+        raise HTTPException(status_code=401, detail="Invalid authorization header format")
+
+    # Extract the token (everything after "Bearer ")
+    client_key = authorization[7:].strip()
+
+    # Ensure the key is not empty
+    if not client_key:
+        raise HTTPException(status_code=401, detail="Missing API key")
+
     if app_config.features.key_passthrough:
         # In passthrough mode, skip allowed_keys check
         return client_key
+
     if client_key not in ALLOWED_CLIENT_KEYS:
         raise HTTPException(status_code=401, detail="Unauthorized")
+
     return client_key
 
 def preprocess_messages(messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
@@ -991,42 +1038,38 @@ async def chat_completions(
 ):
     """Main chat completion endpoint, proxy and inject function calling capabilities."""
     start_time = time.time()
-    
-    # Count input tokens
-    prompt_tokens = token_counter.count_tokens(body.messages, body.model)
-    logger.info(f"📊 Request to {body.model} - Input tokens: {prompt_tokens}")
-    
+
     try:
         logger.debug(f"🔧 Received request, model: {body.model}")
         logger.debug(f"🔧 Number of messages: {len(body.messages)}")
         logger.debug(f"🔧 Number of tools: {len(body.tools) if body.tools else 0}")
         logger.debug(f"🔧 Streaming: {body.stream}")
-        
+
         upstream, actual_model = find_upstream(body.model)
         upstream_url = f"{upstream['base_url']}/chat/completions"
-        
+
         logger.debug(f"🔧 Starting message preprocessing, original message count: {len(body.messages)}")
        processed_messages = preprocess_messages(body.messages)
         logger.debug(f"🔧 Preprocessing completed, processed message count: {len(processed_messages)}")
-        
+
         if not validate_message_structure(processed_messages):
             logger.error(f"❌ Message structure validation failed, but continuing processing")
-        
+
         request_body_dict = body.model_dump(exclude_unset=True)
         request_body_dict["model"] = actual_model
         request_body_dict["messages"] = processed_messages
         is_fc_enabled = app_config.features.enable_function_calling
         has_tools_in_request = bool(body.tools)
         has_function_call = is_fc_enabled and has_tools_in_request
-        
+
         logger.debug(f"🔧 Request body constructed, message count: {len(processed_messages)}")
-        
+
     except Exception as e:
         logger.error(f"❌ Request preprocessing failed: {str(e)}")
         logger.error(f"❌ Error type: {type(e).__name__}")
         if hasattr(app_config, 'debug') and app_config.debug:
             logger.error(f"❌ Error stack: {traceback.format_exc()}")
-        
+
         return JSONResponse(
             status_code=422,
             content={
@@ -1040,9 +1083,9 @@ async def chat_completions(
 
     if has_function_call:
         logger.debug(f"🔧 Using global trigger signal for this request: {GLOBAL_TRIGGER_SIGNAL}")
-        
+
         function_prompt, _ = generate_function_prompt(body.tools, GLOBAL_TRIGGER_SIGNAL)
-        
+
         tool_choice_prompt = safe_process_tool_choice(body.tool_choice)
         if tool_choice_prompt:
             function_prompt += tool_choice_prompt
@@ -1062,6 +1105,11 @@ async def chat_completions(
         if "tool_choice" in request_body_dict:
             del request_body_dict["tool_choice"]
 
+    # Count input tokens after preprocessing and system prompt injection
+    # This ensures we count the actual messages being sent to upstream
+    prompt_tokens = token_counter.count_tokens(request_body_dict["messages"], body.model)
+    logger.info(f"📊 Request to {body.model} - Input tokens: {prompt_tokens}")
+
     headers = {
         "Content-Type": "application/json",
         "Authorization": f"Bearer {_api_key}" if app_config.features.key_passthrough else f"Bearer {upstream['api_key']}",
@@ -1087,13 +1135,32 @@ async def chat_completions(
 
         # Count output tokens and handle usage
         completion_text = ""
+        completion_tool_calls = []
         if response_json.get("choices") and len(response_json["choices"]) > 0:
-            content = response_json["choices"][0].get("message", {}).get("content")
+            choice_message = response_json["choices"][0].get("message", {})
+            content = choice_message.get("content")
             if content:
                 completion_text = content
-            
+            # Also check for tool_calls in the response
+            if "tool_calls" in choice_message:
+                completion_tool_calls = choice_message.get("tool_calls", [])
+
         # Calculate our estimated tokens
-        estimated_completion_tokens = token_counter.count_text_tokens(completion_text, body.model) if completion_text else 0
+        # Count content tokens
+        content_tokens = token_counter.count_text_tokens(completion_text, body.model) if completion_text else 0
+
+        # Count tool_calls tokens if present
+        tool_calls_tokens = 0
+        if completion_tool_calls:
+            # Create a temporary message to count tool_calls tokens
+            temp_msg = {"role": "assistant", "tool_calls": completion_tool_calls}
+            # Count just the tool_calls part
+            tool_calls_tokens = token_counter.count_tokens([temp_msg], body.model)
+            # Subtract the message overhead to get just tool_calls tokens
+            tool_calls_tokens -= 3  # Remove assistant priming overhead
+            logger.debug(f"🔧 Counted {tool_calls_tokens} tokens for {len(completion_tool_calls)} tool calls")
+
+        estimated_completion_tokens = content_tokens + tool_calls_tokens
         estimated_prompt_tokens = prompt_tokens
         estimated_total_tokens = estimated_prompt_tokens + estimated_completion_tokens
         elapsed_time = time.time() - start_time
@@ -1239,9 +1306,10 @@ async def chat_completions(
     async def stream_with_token_count():
         completion_tokens = 0
         completion_text = ""
+        completion_tool_calls = []  # Track streamed tool calls
         done_received = False
         upstream_usage_chunk = None  # Store upstream usage chunk if any
-        
+
         async for chunk in stream_proxy_with_fc_transform(upstream_url, request_body_dict, headers, body.model, has_function_call, GLOBAL_TRIGGER_SIGNAL):
             # Check if this is the [DONE] marker
             if chunk.startswith(b"data: "):
@@ -1253,7 +1321,7 @@ async def stream_with_token_count():
                         break
                     elif line_data:
                         chunk_json = json.loads(line_data)
-                        
+
                         # Check if this chunk contains usage information
                         if "usage" in chunk_json:
                             upstream_usage_chunk = chunk_json
@@ -1262,21 +1330,44 @@ async def stream_with_token_count():
                             if not ("choices" in chunk_json and len(chunk_json["choices"]) > 0):
                                 # Don't yield upstream usage-only chunk now; we'll handle usage later
                                 continue
-                        
+
                         # Process regular content chunks
                         if "choices" in chunk_json and len(chunk_json["choices"]) > 0:
                             delta = chunk_json["choices"][0].get("delta", {})
                             content = delta.get("content", "")
                             if content:
                                 completion_text += content
+
+                            # Also check for tool_calls in the delta
+                            if "tool_calls" in delta:
+                                tool_calls_in_delta = delta.get("tool_calls", [])
+                                if tool_calls_in_delta:
+                                    # Tool calls might be streamed incrementally
+                                    # We'll store them for token counting
+                                    completion_tool_calls.extend(tool_calls_in_delta)
+                                    logger.debug(f"🔧 Detected {len(tool_calls_in_delta)} tool calls in stream chunk")
                 except (json.JSONDecodeError, KeyError, UnicodeDecodeError) as e:
                     logger.debug(f"Failed to parse chunk for token counting: {e}")
                     pass
-            
+
             yield chunk
-        
+
         # Calculate our estimated tokens
-        estimated_completion_tokens = token_counter.count_text_tokens(completion_text, body.model) if completion_text else 0
+        # Count content tokens
+        content_tokens = token_counter.count_text_tokens(completion_text, body.model) if completion_text else 0
+
+        # Count tool_calls tokens if present
+        tool_calls_tokens = 0
+        if completion_tool_calls:
+            # Create a temporary message to count tool_calls tokens
+            temp_msg = {"role": "assistant", "tool_calls": completion_tool_calls}
+            # Count just the tool_calls part
+            tool_calls_tokens = token_counter.count_tokens([temp_msg], body.model)
+            # Subtract the message overhead to get just tool_calls tokens
+            tool_calls_tokens -= 3  # Remove assistant priming overhead
+            logger.debug(f"🔧 Counted {tool_calls_tokens} tokens for {len(completion_tool_calls)} tool calls in stream")
+
+        estimated_completion_tokens = content_tokens + tool_calls_tokens
         estimated_prompt_tokens = prompt_tokens
         estimated_total_tokens = estimated_prompt_tokens + estimated_completion_tokens
         elapsed_time = time.time() - start_time
@@ -1389,7 +1480,7 @@ def _build_tool_call_sse_chunks(parsed_tools: List[Dict[str, Any]], model_id: st
 
     initial_chunk = {
         "id": f"chatcmpl-{uuid.uuid4().hex}", "object": "chat.completion.chunk",
-        "created": int(os.path.getmtime(__file__)), "model": model_id,
+        "created": int(time.time()), "model": model_id,
         "choices": [{"index": 0, "delta": {"role": "assistant", "content": None, "tool_calls": tool_calls}, "finish_reason": None}],
     }
     chunks.append(f"data: {json.dumps(initial_chunk)}\n\n")
@@ -1397,7 +1488,7 @@ def _build_tool_call_sse_chunks(parsed_tools: List[Dict[str, Any]], model_id: st
 
     final_chunk = {
         "id": f"chatcmpl-{uuid.uuid4().hex}", "object": "chat.completion.chunk",
-        "created": int(os.path.getmtime(__file__)), "model": model_id,
+        "created": int(time.time()), "model": model_id,
         "choices": [{"index": 0, "delta": {}, "finish_reason": "tool_calls"}],
     }
     chunks.append(f"data: {json.dumps(final_chunk)}\n\n")
@@ -1472,7 +1563,7 @@ def _build_tool_call_sse_chunks(parsed_tools: List[Dict[str, Any]], model_id: st
 
                     yield_chunk = {
                         "id": f"chatcmpl-passthrough-{uuid.uuid4().hex}", "object": "chat.completion.chunk",
-                        "created": int(os.path.getmtime(__file__)),
+                        "created": int(time.time()),
                         "model": model,
                         "choices": [{"index": 0, "delta": {"content": content_to_yield}}]
                     }
@@ -1513,7 +1604,7 @@ def _build_tool_call_sse_chunks(parsed_tools: List[Dict[str, Any]], model_id: st
             # If stream has ended but buffer still has remaining characters insufficient to form signal, output them
             final_yield_chunk = {
                 "id": f"chatcmpl-finalflush-{uuid.uuid4().hex}", "object": "chat.completion.chunk",
-                "created": int(os.path.getmtime(__file__)), "model": model,
+                "created": int(time.time()), "model": model,
                 "choices": [{"index": 0, "delta": {"content": detector.content_buffer}}]
             }
             yield f"data: {json.dumps(final_yield_chunk)}\n\n".encode('utf-8')
diff --git a/test_token_counting.py b/test_token_counting.py
new file mode 100644
index 0000000..32d584f
--- /dev/null
+++ b/test_token_counting.py
@@ -0,0 +1,145 @@
+#!/usr/bin/env python3
+"""
+Test script to verify token counting accuracy for tool calls
+"""
+
+import json
+import sys
+import os
+
+# Add parent directory to path
+sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
+
+from main import TokenCounter
+
+def test_token_counting():
+    """Test token counting with various message types including tool calls"""
+
+    token_counter = TokenCounter()
+
+    # Test 1: Simple message without tool calls
+    print("=" * 60)
+    print("Test 1: Simple message without tool calls")
+    messages1 = [
+        {"role": "user", "content": "Hello, how are you?"},
+        {"role": "assistant", "content": "I'm doing well, thank you!"}
+    ]
+    tokens1 = token_counter.count_tokens(messages1, "gpt-4")
+    print(f"Messages: {json.dumps(messages1, indent=2)}")
+    print(f"Token count: {tokens1}")
+    print()
+
+    # Test 2: Message with tool calls
+    print("=" * 60)
+    print("Test 2: Message with tool calls")
+    messages2 = [
+        {"role": "user", "content": "What's the weather like in New York?"},
+        {
+            "role": "assistant",
+            "content": None,
+            "tool_calls": [
+                {
+                    "id": "call_abc123",
+                    "type": "function",
+                    "function": {
+                        "name": "get_weather",
+                        "arguments": json.dumps({"location": "New York", "units": "celsius"})
+                    }
+                }
+            ]
+        }
+    ]
+    tokens2 = token_counter.count_tokens(messages2, "gpt-4")
+    print(f"Messages: {json.dumps(messages2, indent=2)}")
+    print(f"Token count: {tokens2}")
+    print()
+
+    # Test 3: Message with tool response
+    print("=" * 60)
+    print("Test 3: Message with tool response")
+    messages3 = [
+        {"role": "user", "content": "What's the weather like in New York?"},
+        {
+            "role": "assistant",
+            "content": None,
+            "tool_calls": [
+                {
+                    "id": "call_abc123",
+                    "type": "function",
+                    "function": {
+                        "name": "get_weather",
+                        "arguments": json.dumps({"location": "New York", "units": "celsius"})
+                    }
+                }
+            ]
+        },
+        {
+            "role": "tool",
+            "tool_call_id": "call_abc123",
+            "content": "Temperature: 22°C, Condition: Partly cloudy"
+        }
+    ]
+    tokens3 = token_counter.count_tokens(messages3, "gpt-4")
+    print(f"Messages: {json.dumps(messages3, indent=2)}")
+    print(f"Token count: {tokens3}")
+    print()
+
+    # Test 4: Multiple tool calls
+    print("=" * 60)
+    print("Test 4: Multiple tool calls")
+    messages4 = [
+        {"role": "user", "content": "Compare the weather in New York and London"},
+        {
+            "role": "assistant",
+            "content": "I'll check the weather in both cities for you.",
+            "tool_calls": [
+                {
+                    "id": "call_123",
+                    "type": "function",
+                    "function": {
+                        "name": "get_weather",
+                        "arguments": json.dumps({"location": "New York", "units": "celsius"})
+                    }
+                },
+                {
+                    "id": "call_456",
+                    "type": "function",
+                    "function": {
+                        "name": "get_weather",
+                        "arguments": json.dumps({"location": "London", "units": "celsius"})
+                    }
+                }
+            ]
+        }
+    ]
+    tokens4 = token_counter.count_tokens(messages4, "gpt-4")
+    print(f"Messages: {json.dumps(messages4, indent=2)}")
+    print(f"Token count: {tokens4}")
+    print()
+
+    # Compare token counts
+    print("=" * 60)
+    print("Summary:")
+    print(f"Test 1 (no tools): {tokens1} tokens")
+    print(f"Test 2 (1 tool call): {tokens2} tokens")
+    print(f"Test 3 (tool + response): {tokens3} tokens")
+    print(f"Test 4 (2 tool calls + content): {tokens4} tokens")
+    print()
+    print("Token difference analysis:")
+    print(f"Tool call overhead (Test 2 - Test 1): {tokens2 - tokens1} tokens")
+    print(f"Tool response overhead (Test 3 - Test 2): {tokens3 - tokens2} tokens")
+    print(f"Multiple tools overhead (Test 4 vs Test 2): {tokens4 - tokens2} tokens")
+
+    # Test with different models
+    print()
+    print("=" * 60)
+    print("Model comparison for Test 2:")
+    for model in ["gpt-4", "gpt-3.5-turbo", "o1", "o3"]:
+        try:
+            tokens = token_counter.count_tokens(messages2, model)
+            print(f"  {model}: {tokens} tokens")
+        except Exception as e:
+            print(f"  {model}: Error - {e}")
+
+if __name__ == "__main__":
+    test_token_counting()
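
Note (reviewer addition, not part of the patch): a minimal standalone sketch of the tool-call counting rule this diff adds to _count_chat_tokens. The tool_calls list is serialized as compact JSON, encoded with tiktoken, and given a flat overhead of roughly 5 tokens per call. The encoding name and the example call below are illustrative assumptions, not taken from the patch.

import json
import tiktoken

def estimate_tool_call_tokens(tool_calls: list, encoding_name: str = "cl100k_base") -> int:
    """Approximate the token cost of an assistant message's tool_calls list."""
    encoder = tiktoken.get_encoding(encoding_name)
    # Compact JSON, mirroring json.dumps(value, separators=(',', ':')) in the patch
    tool_calls_json = json.dumps(tool_calls, separators=(',', ':'))
    tokens = len(encoder.encode(tool_calls_json))
    # Flat formatting overhead, ~5 tokens per tool call (same heuristic as the patch)
    tokens += len(tool_calls) * 5
    return tokens

# Hypothetical example call, shaped like the ones in test_token_counting.py
example = [{
    "id": "call_abc123",
    "type": "function",
    "function": {"name": "get_weather", "arguments": "{\"location\": \"New York\"}"},
}]
print(estimate_tool_call_tokens(example))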