167 changes: 129 additions & 38 deletions main.py
@@ -3,7 +3,6 @@
# Toolify: Empower any LLM with function calling capabilities.
# Copyright (C) 2025 FunnyCups (https://github.com/funnycups)

import os
import re
import json
import uuid
@@ -18,6 +17,7 @@
import tiktoken
from typing import List, Dict, Any, Optional, Literal, Union
from collections import OrderedDict
from contextlib import asynccontextmanager

from fastapi import FastAPI, Request, Header, HTTPException, Depends
from fastapi.responses import JSONResponse, StreamingResponse
@@ -97,11 +97,12 @@ def count_tokens(self, messages: list, model: str = "gpt-3.5-turbo") -> int:

def _count_chat_tokens(self, messages: list, encoder, model: str) -> int:
"""Accurate token calculation for chat models

Based on OpenAI's token counting documentation:
- Each message has a fixed overhead
- Content tokens are counted per message
- Special tokens for message formatting
- Tool calls are counted as structured JSON
"""
# Token overhead varies by model
if model.startswith(("gpt-3.5-turbo", "gpt-35-turbo")):
@@ -112,11 +113,11 @@ def _count_chat_tokens(self, messages: list, encoder, model: str) -> int:
# Most models including gpt-4, gpt-4o, o1, etc.
tokens_per_message = 3
tokens_per_name = 1

num_tokens = 0
for message in messages:
num_tokens += tokens_per_message

# Count tokens for each field in the message
for key, value in message.items():
if key == "content":
@@ -136,10 +137,25 @@ def _count_chat_tokens(self, messages: list, encoder, model: str) -> int:
elif key == "role":
# Role is already counted in tokens_per_message
pass
elif key == "tool_calls":
# Count tokens for tool calls structure
# Tool calls are serialized as JSON in the model's context
if isinstance(value, list):
# Convert tool_calls to JSON to count actual tokens
tool_calls_json = json.dumps(value, separators=(',', ':'))
num_tokens += len(encoder.encode(tool_calls_json))
# Add overhead for tool call formatting (approximately 5 tokens per tool call)
num_tokens += len(value) * 5
elif key == "tool_call_id":
# Count tokens for tool call ID
if isinstance(value, str):
num_tokens += len(encoder.encode(value))
# Add overhead for tool response formatting
num_tokens += 3
elif isinstance(value, str):
# Other string fields
num_tokens += len(encoder.encode(value))

# Every reply is primed with assistant role
num_tokens += 3
return num_tokens
@@ -801,9 +817,14 @@ def parse_function_calls_xml(xml_string: str, trigger_signal: str) -> Optional[L
def _coerce_value(v: str):
try:
return json.loads(v)
except Exception:
pass
return v
except json.JSONDecodeError:
# Not a JSON value, treat as string (this is expected for many parameters)
logger.debug(f"🔧 Parameter value is not JSON, treating as string: {v[:50]}{'...' if len(v) > 50 else ''}")
return v
except Exception as e:
# Unexpected error during parsing
logger.warning(f"⚠️ Unexpected error parsing parameter value: {e}")
return v

for k, v in arg_matches:
args[k] = _coerce_value(v)
@@ -861,8 +882,22 @@ def find_upstream(model_name: str) -> tuple[Dict[str, Any], str]:

return service, actual_model_name

app = FastAPI()
http_client = httpx.AsyncClient()
# Global HTTP client (will be initialized in lifespan)
http_client = None

@asynccontextmanager
async def lifespan(app: FastAPI):
"""Manage application lifespan - startup and shutdown"""
global http_client
# Startup
http_client = httpx.AsyncClient()
logger.info("✅ HTTP client initialized")
yield
# Shutdown
await http_client.aclose()
logger.info("✅ HTTP client closed")

app = FastAPI(lifespan=lifespan)

@app.middleware("http")
async def debug_middleware(request: Request, call_next):
@@ -919,12 +954,24 @@ async def general_exception_handler(request: Request, exc: Exception):

async def verify_api_key(authorization: str = Header(...)):
"""Dependency: verify client API key"""
client_key = authorization.replace("Bearer ", "")
# Check authorization header format (case-insensitive)
if not authorization.lower().startswith("bearer "):
raise HTTPException(status_code=401, detail="Invalid authorization header format")

# Extract the token (everything after "Bearer ")
client_key = authorization[7:].strip()

# Ensure the key is not empty
if not client_key:
raise HTTPException(status_code=401, detail="Missing API key")
Comment on lines +957 to +966
Copilot AI, Oct 19, 2025

Parsing Authorization by slicing index 7 assumes a single space and no extra whitespace. Use robust parsing to handle multiple spaces/tabs: scheme, _, token = authorization.partition(' '); if scheme.lower() != 'bearer' or not token.strip(): raise HTTPException(...). This avoids edge cases and keeps the check case-insensitive.
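
A minimal sketch of the partition-based check, assuming the same FastAPI dependency signature as in the diff (the error detail strings are illustrative):

```python
from fastapi import Header, HTTPException

async def verify_api_key(authorization: str = Header(...)):
    """Dependency: verify client API key with tolerant Bearer parsing."""
    # Split on the first space; the scheme check is case-insensitive,
    # and strip() tolerates extra spaces around the token.
    scheme, _, token = authorization.partition(" ")
    token = token.strip()
    if scheme.lower() != "bearer" or not token:
        raise HTTPException(status_code=401, detail="Invalid authorization header format")
    return token
```

`authorization.split(maxsplit=1)` would additionally tolerate a tab separator, if that matters here.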


if app_config.features.key_passthrough:
# In passthrough mode, skip allowed_keys check
return client_key

if client_key not in ALLOWED_CLIENT_KEYS:
raise HTTPException(status_code=401, detail="Unauthorized")

return client_key

def preprocess_messages(messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
@@ -991,42 +1038,38 @@ async def chat_completions(
):
"""Main chat completion endpoint, proxy and inject function calling capabilities."""
start_time = time.time()

# Count input tokens
prompt_tokens = token_counter.count_tokens(body.messages, body.model)
logger.info(f"📊 Request to {body.model} - Input tokens: {prompt_tokens}")


try:
logger.debug(f"🔧 Received request, model: {body.model}")
logger.debug(f"🔧 Number of messages: {len(body.messages)}")
logger.debug(f"🔧 Number of tools: {len(body.tools) if body.tools else 0}")
logger.debug(f"🔧 Streaming: {body.stream}")

upstream, actual_model = find_upstream(body.model)
upstream_url = f"{upstream['base_url']}/chat/completions"

logger.debug(f"🔧 Starting message preprocessing, original message count: {len(body.messages)}")
processed_messages = preprocess_messages(body.messages)
logger.debug(f"🔧 Preprocessing completed, processed message count: {len(processed_messages)}")

if not validate_message_structure(processed_messages):
logger.error(f"❌ Message structure validation failed, but continuing processing")

request_body_dict = body.model_dump(exclude_unset=True)
request_body_dict["model"] = actual_model
request_body_dict["messages"] = processed_messages
is_fc_enabled = app_config.features.enable_function_calling
has_tools_in_request = bool(body.tools)
has_function_call = is_fc_enabled and has_tools_in_request

logger.debug(f"🔧 Request body constructed, message count: {len(processed_messages)}")

except Exception as e:
logger.error(f"❌ Request preprocessing failed: {str(e)}")
logger.error(f"❌ Error type: {type(e).__name__}")
if hasattr(app_config, 'debug') and app_config.debug:
logger.error(f"❌ Error stack: {traceback.format_exc()}")

return JSONResponse(
status_code=422,
content={
@@ -1040,9 +1083,9 @@

if has_function_call:
logger.debug(f"🔧 Using global trigger signal for this request: {GLOBAL_TRIGGER_SIGNAL}")

function_prompt, _ = generate_function_prompt(body.tools, GLOBAL_TRIGGER_SIGNAL)

tool_choice_prompt = safe_process_tool_choice(body.tool_choice)
if tool_choice_prompt:
function_prompt += tool_choice_prompt
@@ -1062,6 +1105,11 @@
if "tool_choice" in request_body_dict:
del request_body_dict["tool_choice"]

# Count input tokens after preprocessing and system prompt injection
# This ensures we count the actual messages being sent to upstream
prompt_tokens = token_counter.count_tokens(request_body_dict["messages"], body.model)
logger.info(f"📊 Request to {body.model} - Input tokens: {prompt_tokens}")

headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {_api_key}" if app_config.features.key_passthrough else f"Bearer {upstream['api_key']}",
@@ -1087,13 +1135,32 @@

# Count output tokens and handle usage
completion_text = ""
completion_tool_calls = []
if response_json.get("choices") and len(response_json["choices"]) > 0:
content = response_json["choices"][0].get("message", {}).get("content")
choice_message = response_json["choices"][0].get("message", {})
content = choice_message.get("content")
if content:
completion_text = content

# Also check for tool_calls in the response
if "tool_calls" in choice_message:
completion_tool_calls = choice_message.get("tool_calls", [])

# Calculate our estimated tokens
estimated_completion_tokens = token_counter.count_text_tokens(completion_text, body.model) if completion_text else 0
# Count content tokens
content_tokens = token_counter.count_text_tokens(completion_text, body.model) if completion_text else 0

# Count tool_calls tokens if present
tool_calls_tokens = 0
if completion_tool_calls:
# Create a temporary message to count tool_calls tokens
temp_msg = {"role": "assistant", "tool_calls": completion_tool_calls}
# Count just the tool_calls part
tool_calls_tokens = token_counter.count_tokens([temp_msg], body.model)
# Subtract the message overhead to get just tool_calls tokens
tool_calls_tokens -= 3 # Remove assistant priming overhead
logger.debug(f"🔧 Counted {tool_calls_tokens} tokens for {len(completion_tool_calls)} tool calls")

estimated_completion_tokens = content_tokens + tool_calls_tokens
Comment on lines +1149 to +1163
Copilot AI, Oct 19, 2025

The tool_calls token isolation subtracts only the reply priming (+3) but not the per-message overhead (3 for most models, 4 for gpt-3.5-turbo/gpt-35-turbo), so it overestimates tool_calls tokens by 3–4 tokens. Subtract both the per-message overhead and the reply priming, e.g. compute msg_overhead = 4 if model.startswith(('gpt-3.5-turbo', 'gpt-35-turbo')) else 3, then tool_calls_tokens -= (msg_overhead + 3). Better yet, add a TokenCounter helper that returns these overheads so the model-specific logic isn't duplicated.
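
One possible shape for such a helper, with hypothetical function names, sketched against the overhead values _count_chat_tokens already uses:

```python
def message_overhead(model: str) -> int:
    """Per-message token overhead, mirroring _count_chat_tokens."""
    return 4 if model.startswith(("gpt-3.5-turbo", "gpt-35-turbo")) else 3

def count_tool_calls_tokens(token_counter, tool_calls: list, model: str) -> int:
    """Tokens attributable to tool_calls alone, excluding message framing.

    Assumes token_counter.count_tokens() adds the per-message overhead
    plus the +3 assistant reply priming, as in the current implementation.
    """
    temp_msg = {"role": "assistant", "tool_calls": tool_calls}
    total = token_counter.count_tokens([temp_msg], model)
    return max(total - message_overhead(model) - 3, 0)
```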

Comment on lines +1149 to +1163
Copilot AI, Oct 19, 2025

The logic to estimate completion tokens (content + tool_calls with model-specific overhead corrections) is duplicated here and in the streaming path. Consider extracting a single helper (e.g., TokenCounter.count_tool_calls_tokens(tool_calls, model) and a shared estimate_completion_tokens(content, tool_calls, model)) to centralize the model-specific overhead handling and avoid divergence.
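
A shared estimator along those lines might look like the following (hypothetical name; it builds on the count_tool_calls_tokens sketch in the previous comment):

```python
def estimate_completion_tokens(token_counter, content: str,
                               tool_calls: list, model: str) -> int:
    """Single estimation path for the streaming and non-streaming branches."""
    content_tokens = token_counter.count_text_tokens(content, model) if content else 0
    tool_tokens = count_tool_calls_tokens(token_counter, tool_calls, model) if tool_calls else 0
    return content_tokens + tool_tokens
```

Both branches would then reduce to a single call such as estimate_completion_tokens(token_counter, completion_text, completion_tool_calls, body.model).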

estimated_prompt_tokens = prompt_tokens
estimated_total_tokens = estimated_prompt_tokens + estimated_completion_tokens
elapsed_time = time.time() - start_time
@@ -1239,9 +1306,10 @@ async def chat_completions(
async def stream_with_token_count():
completion_tokens = 0
completion_text = ""
completion_tool_calls = [] # Track streamed tool calls
done_received = False
upstream_usage_chunk = None # Store upstream usage chunk if any

async for chunk in stream_proxy_with_fc_transform(upstream_url, request_body_dict, headers, body.model, has_function_call, GLOBAL_TRIGGER_SIGNAL):
# Check if this is the [DONE] marker
if chunk.startswith(b"data: "):
@@ -1253,7 +1321,7 @@ async def stream_with_token_count():
break
elif line_data:
chunk_json = json.loads(line_data)

# Check if this chunk contains usage information
if "usage" in chunk_json:
upstream_usage_chunk = chunk_json
@@ -1262,21 +1330,44 @@
if not ("choices" in chunk_json and len(chunk_json["choices"]) > 0):
# Don't yield upstream usage-only chunk now; we'll handle usage later
continue

# Process regular content chunks
if "choices" in chunk_json and len(chunk_json["choices"]) > 0:
delta = chunk_json["choices"][0].get("delta", {})
content = delta.get("content", "")
if content:
completion_text += content

# Also check for tool_calls in the delta
if "tool_calls" in delta:
tool_calls_in_delta = delta.get("tool_calls", [])
if tool_calls_in_delta:
# Tool calls might be streamed incrementally
# We'll store them for token counting
completion_tool_calls.extend(tool_calls_in_delta)
logger.debug(f"🔧 Detected {len(tool_calls_in_delta)} tool calls in stream chunk")
except (json.JSONDecodeError, KeyError, UnicodeDecodeError) as e:
logger.debug(f"Failed to parse chunk for token counting: {e}")
pass

yield chunk

# Calculate our estimated tokens
estimated_completion_tokens = token_counter.count_text_tokens(completion_text, body.model) if completion_text else 0
# Count content tokens
content_tokens = token_counter.count_text_tokens(completion_text, body.model) if completion_text else 0

# Count tool_calls tokens if present
tool_calls_tokens = 0
if completion_tool_calls:
# Create a temporary message to count tool_calls tokens
temp_msg = {"role": "assistant", "tool_calls": completion_tool_calls}
# Count just the tool_calls part
tool_calls_tokens = token_counter.count_tokens([temp_msg], body.model)
# Subtract the message overhead to get just tool_calls tokens
tool_calls_tokens -= 3 # Remove assistant priming overhead
logger.debug(f"🔧 Counted {tool_calls_tokens} tokens for {len(completion_tool_calls)} tool calls in stream")
Comment on lines +1367 to +1368
Copilot AI, Oct 19, 2025

Same issue as non-stream path: only the reply priming (+3) is removed, but the per-message overhead isn't. This inflates tool_calls token counts by 3–4 tokens. Subtract both the per-message overhead (3 for most models, 4 for gpt-3.5-turbo/gpt-35-turbo) and the +3 reply priming.

Suggested change:
- tool_calls_tokens -= 3  # Remove assistant priming overhead
- logger.debug(f"🔧 Counted {tool_calls_tokens} tokens for {len(completion_tool_calls)} tool calls in stream")
+ # Subtract both the per-message overhead and the +3 reply priming
+ # (most models: 3 per message; gpt-3.5-turbo/gpt-35-turbo: 4)
+ per_message_overhead = 4 if body.model.startswith(("gpt-3.5-turbo", "gpt-35-turbo")) else 3
+ tool_calls_tokens -= (per_message_overhead + 3)
+ logger.debug(f"🔧 Counted {tool_calls_tokens} tokens for {len(completion_tool_calls)} tool calls in stream (subtracted {per_message_overhead}+3 overhead)")


estimated_completion_tokens = content_tokens + tool_calls_tokens
estimated_prompt_tokens = prompt_tokens
estimated_total_tokens = estimated_prompt_tokens + estimated_completion_tokens
elapsed_time = time.time() - start_time
@@ -1389,15 +1480,15 @@ def _build_tool_call_sse_chunks(parsed_tools: List[Dict[str, Any]], model_id: st

initial_chunk = {
"id": f"chatcmpl-{uuid.uuid4().hex}", "object": "chat.completion.chunk",
"created": int(os.path.getmtime(__file__)), "model": model_id,
"created": int(time.time()), "model": model_id,
"choices": [{"index": 0, "delta": {"role": "assistant", "content": None, "tool_calls": tool_calls}, "finish_reason": None}],
}
chunks.append(f"data: {json.dumps(initial_chunk)}\n\n")


final_chunk = {
"id": f"chatcmpl-{uuid.uuid4().hex}", "object": "chat.completion.chunk",
"created": int(os.path.getmtime(__file__)), "model": model_id,
"created": int(time.time()), "model": model_id,
"choices": [{"index": 0, "delta": {}, "finish_reason": "tool_calls"}],
}
chunks.append(f"data: {json.dumps(final_chunk)}\n\n")
@@ -1472,7 +1563,7 @@ def _build_tool_call_sse_chunks(parsed_tools: List[Dict[str, Any]], model_id: st
yield_chunk = {
"id": f"chatcmpl-passthrough-{uuid.uuid4().hex}",
"object": "chat.completion.chunk",
"created": int(os.path.getmtime(__file__)),
"created": int(time.time()),
"model": model,
"choices": [{"index": 0, "delta": {"content": content_to_yield}}]
}
@@ -1513,7 +1604,7 @@ def _build_tool_call_sse_chunks(parsed_tools: List[Dict[str, Any]], model_id: st
# If stream has ended but buffer still has remaining characters insufficient to form signal, output them
final_yield_chunk = {
"id": f"chatcmpl-finalflush-{uuid.uuid4().hex}", "object": "chat.completion.chunk",
"created": int(os.path.getmtime(__file__)), "model": model,
"created": int(time.time()), "model": model,
"choices": [{"index": 0, "delta": {"content": detector.content_buffer}}]
}
yield f"data: {json.dumps(final_yield_chunk)}\n\n".encode('utf-8')