
Commit d6271b4

Enhance RAG API to return token usage information and improve error handling in chat script (#9)
1 parent: d5dfdca

3 files changed (+70, -9 lines)

api/memoryalpha/ask.py

Lines changed: 4 additions & 4 deletions
@@ -24,14 +24,14 @@ def ask_endpoint_post(request: AskRequest):
     Accepts POST requests with JSON payload for cleaner API usage.
     """
     try:
-        answer = rag_instance.ask(
+        result = rag_instance.ask(
             request.question,
             max_tokens=request.max_tokens,
             top_k=request.top_k,
             top_p=request.top_p,
             temperature=request.temperature
         )
-        return JSONResponse(content={"response": answer})
+        return JSONResponse(content=result)
     except Exception as e:
         return JSONResponse(status_code=500, content={"error": str(e)})

@@ -48,13 +48,13 @@ def ask_endpoint(
     Now uses advanced tool-enabled RAG by default for better results.
     """
     try:
-        answer = rag_instance.ask(
+        result = rag_instance.ask(
             question,
             max_tokens=max_tokens,
             top_k=top_k,
             top_p=top_p,
             temperature=temperature
         )
-        return JSONResponse(content={"response": answer})
+        return JSONResponse(content=result)
     except Exception as e:
         return JSONResponse(status_code=500, content={"error": str(e)})
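
Both endpoints now return the full dictionary produced by rag_instance.ask() instead of a bare string wrapped in {"response": ...}. As a minimal client sketch for the POST endpoint, assuming the API is served at http://localhost:8000 and mounted at /memoryalpha/rag/ask (the host and exact route are assumptions here, not confirmed by this diff; the payload field names come from AskRequest as used above):

# Hypothetical client for the updated POST endpoint; host and route are assumed.
import requests

payload = {
    "question": "Who commanded the USS Defiant?",  # example question
    "max_tokens": 2048,
    "top_k": 10,
    "top_p": 0.8,
    "temperature": 0.3,
}

resp = requests.post("http://localhost:8000/memoryalpha/rag/ask", json=payload, timeout=120)
data = resp.json()

# New response shape: {"answer": ..., "token_usage": {...}} instead of {"response": ...}
print(data["answer"])
print(data.get("token_usage", {}))

Clients that still read the old "response" key will come up empty, which is exactly what the chat.sh change below accounts for.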

api/memoryalpha/rag.py

Lines changed: 45 additions & 4 deletions
@@ -224,15 +224,20 @@ def search_tool(self, query: str, top_k: int = 5) -> str:
         return formatted_result

     def ask(self, query: str, max_tokens: int = 2048, top_k: int = 10, top_p: float = 0.8, temperature: float = 0.3,
-            model: str = os.getenv("DEFAULT_MODEL")) -> str:
+            model: str = os.getenv("DEFAULT_MODEL")) -> Dict[str, Any]:
         """
         Ask a question using the advanced Memory Alpha RAG system with tool use.
+        Returns a dictionary with answer and token usage information.
         """

         if not model:
             raise ValueError("model must be provided or set in DEFAULT_MODEL environment variable.")

         logger.info(f"Starting tool-enabled RAG for query: {query}")
+
+        # Initialize token tracking
+        total_input_tokens = 0
+        total_output_tokens = 0

         # Define the search tool
         search_tool_definition = {

@@ -317,6 +322,20 @@ def ask(self, query: str, max_tokens: int = 2048, top_k: int = 10, top_p: float
                 logger.info(f"LLM response type: {type(response_message)}")
                 logger.debug(f"Response message content: {response_message.get('content', 'No content')[:200]}...")

+                # Estimate tokens based on content length
+                # Rough estimation: ~4 characters per token for English text
+                content = response_message.get('content', '')
+                estimated_output_tokens = len(content) // 4
+                total_output_tokens += estimated_output_tokens
+
+                # Estimate input tokens from current message content
+                input_text = ' '.join([msg.get('content', '') for msg in messages])
+                estimated_input_tokens = len(input_text) // 4
+                # Only add the increment from this iteration to avoid double counting
+                total_input_tokens = estimated_input_tokens
+
+                logger.info(f"Estimated tokens - Input: {estimated_input_tokens}, Output: {estimated_output_tokens}")
+
                 # Check if the model wants to use a tool
                 tool_calls = getattr(response_message, 'tool_calls', None) or response_message.get('tool_calls')
                 if tool_calls:

@@ -377,15 +396,37 @@ def ask(self, query: str, max_tokens: int = 2048, top_k: int = 10, top_p: float

                 self._update_history(query, final_response)
                 logger.info("Returning final answer")
-                return final_response
+
+                return {
+                    "answer": final_response,
+                    "token_usage": {
+                        "input_tokens": total_input_tokens,
+                        "output_tokens": total_output_tokens,
+                        "total_tokens": total_input_tokens + total_output_tokens
+                    }
+                }

         except Exception as e:
             logger.error(f"Chat failed: {e}")
-            return f"Error processing query: {str(e)}"
+            return {
+                "answer": f"Error processing query: {str(e)}",
+                "token_usage": {
+                    "input_tokens": total_input_tokens,
+                    "output_tokens": total_output_tokens,
+                    "total_tokens": total_input_tokens + total_output_tokens
+                }
+            }

         # Fallback if max iterations reached
         logger.warning(f"Max iterations reached for query: {query}")
-        return "Query processing exceeded maximum iterations. Please try a simpler question."
+        return {
+            "answer": "Query processing exceeded maximum iterations. Please try a simpler question.",
+            "token_usage": {
+                "input_tokens": total_input_tokens,
+                "output_tokens": total_output_tokens,
+                "total_tokens": total_input_tokens + total_output_tokens
+            }
+        }

     def _update_history(self, question: str, answer: str):
         """Update conversation history."""

chat.sh

Lines changed: 21 additions & 1 deletion
@@ -22,10 +22,30 @@ ask_question() {
     local response
     response=$(curl -s \
         "${BASE_URL}/memoryalpha/rag/ask?question=${encoded_question}&thinkingmode=${THINKING_MODE}&max_tokens=${MAX_TOKENS}&top_k=${TOP_K}&top_p=${TOP_P}&temperature=${TEMPERATURE}")
+
+    # Check if response is valid JSON
+    if ! echo "$response" | jq . >/dev/null 2>&1; then
+        printf "Error: Invalid response received.\n"
+        printf "Raw response: %s\n" "$response"
+        echo "----------------------------------------"
+        return
+    fi
+
     local answer
-    answer=$(echo "$response" | jq -r '.response // empty')
+    answer=$(echo "$response" | jq -r '.answer // empty')
     if [[ -n "$answer" ]]; then
         printf "%s\n" "$answer"
+
+        # Display token usage if available
+        local input_tokens output_tokens total_tokens
+        input_tokens=$(echo "$response" | jq -r '.token_usage.input_tokens // empty')
+        output_tokens=$(echo "$response" | jq -r '.token_usage.output_tokens // empty')
+        total_tokens=$(echo "$response" | jq -r '.token_usage.total_tokens // empty')
+
+        if [[ -n "$input_tokens" && -n "$output_tokens" && -n "$total_tokens" ]]; then
+            echo
+            printf "📊 Token Usage: Input: %s | Output: %s | Total: %s\n" "$input_tokens" "$output_tokens" "$total_tokens"
+        fi
     else
         local error
         error=$(echo "$response" | jq -r '.error // empty')
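
For comparison, a Python rendering of the handling the script now performs: validate the body as JSON, print the answer plus token usage on success, and fall back to the error field otherwise. The sample payload is made up; only the field names come from the API above.

# Python equivalent of the chat.sh response handling; the sample payload is illustrative.
import json

def handle_response(raw: str) -> None:
    # Mirror the `jq .` validity check: bail out if the body is not JSON.
    try:
        data = json.loads(raw)
    except json.JSONDecodeError:
        print("Error: Invalid response received.")
        print(f"Raw response: {raw}")
        return

    answer = data.get("answer")
    if answer:
        print(answer)
        usage = data.get("token_usage", {})
        if all(k in usage for k in ("input_tokens", "output_tokens", "total_tokens")):
            print(f"Token Usage: Input: {usage['input_tokens']} | "
                  f"Output: {usage['output_tokens']} | Total: {usage['total_tokens']}")
    else:
        print(data.get("error", "Unexpected response format"))

handle_response(json.dumps({
    "answer": "The USS Defiant was commanded by Benjamin Sisko.",
    "token_usage": {"input_tokens": 512, "output_tokens": 64, "total_tokens": 576},
}))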
