@@ -224,15 +224,20 @@ def search_tool(self, query: str, top_k: int = 5) -> str:
         return formatted_result
 
     def ask(self, query: str, max_tokens: int = 2048, top_k: int = 10, top_p: float = 0.8, temperature: float = 0.3,
-            model: str = os.getenv("DEFAULT_MODEL")) -> str:
+            model: str = os.getenv("DEFAULT_MODEL")) -> Dict[str, Any]:
         """
         Ask a question using the advanced Memory Alpha RAG system with tool use.
+        Returns a dictionary with answer and token usage information.
         """
 
         if not model:
             raise ValueError("model must be provided or set in DEFAULT_MODEL environment variable.")
 
         logger.info(f"Starting tool-enabled RAG for query: {query}")
+
+        # Initialize token tracking
+        total_input_tokens = 0
+        total_output_tokens = 0
 
         # Define the search tool
         search_tool_definition = {
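
A side note on the signature this hunk keeps: `os.getenv("DEFAULT_MODEL")` in the default is evaluated once, when the class body is first executed, so changes to the environment variable after import are not picked up; the runtime `ValueError` only catches the case where it was unset at import time. A minimal sketch of the call-time lookup alternative (`resolve_model` is a hypothetical helper, not part of this diff):

```python
import os
from typing import Optional

def resolve_model(model: Optional[str] = None) -> str:
    """Resolve the model name at call time instead of import time."""
    # Hypothetical helper: defaulting to None and reading the env var in the
    # body means DEFAULT_MODEL is re-read on every call.
    model = model or os.getenv("DEFAULT_MODEL")
    if not model:
        raise ValueError("model must be provided or set in DEFAULT_MODEL environment variable.")
    return model
```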
@@ -317,6 +322,20 @@ def ask(self, query: str, max_tokens: int = 2048, top_k: int = 10, top_p: float
                 logger.info(f"LLM response type: {type(response_message)}")
                 logger.debug(f"Response message content: {response_message.get('content', 'No content')[:200]}...")
 
+                # Estimate tokens based on content length
+                # Rough estimation: ~4 characters per token for English text
+                content = response_message.get('content', '')
+                estimated_output_tokens = len(content) // 4
+                total_output_tokens += estimated_output_tokens
+
+                # Estimate input tokens from the full message list
+                input_text = ' '.join([msg.get('content', '') for msg in messages])
+                estimated_input_tokens = len(input_text) // 4
+                # Assign rather than accumulate: messages already contains all prior turns, so overwriting avoids double counting
+                total_input_tokens = estimated_input_tokens
+
+                logger.info(f"Estimated tokens - Input: {estimated_input_tokens}, Output: {estimated_output_tokens}")
+
                 # Check if the model wants to use a tool
                 tool_calls = getattr(response_message, 'tool_calls', None) or response_message.get('tool_calls')
                 if tool_calls:
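
The `len(text) // 4` estimate above follows the common rule of thumb of roughly four characters per token for English text. A sketch of the same idea as a reusable helper; the optional exact path uses tiktoken, which is an assumption here (it matches OpenAI-style tokenizers, not necessarily the model this diff targets):

```python
def estimate_tokens(text: str, exact: bool = False) -> int:
    """Estimate the token count of a string.

    exact=True requires tiktoken (pip install tiktoken) and assumes an
    OpenAI-style tokenizer; the default is the same ~4-chars-per-token
    heuristic used in the diff above.
    """
    if exact:
        import tiktoken
        enc = tiktoken.get_encoding("cl100k_base")
        return len(enc.encode(text))
    return len(text) // 4
```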
@@ -377,15 +396,37 @@ def ask(self, query: str, max_tokens: int = 2048, top_k: int = 10, top_p: float
 
                     self._update_history(query, final_response)
                     logger.info("Returning final answer")
-                    return final_response
+
+                    return {
+                        "answer": final_response,
+                        "token_usage": {
+                            "input_tokens": total_input_tokens,
+                            "output_tokens": total_output_tokens,
+                            "total_tokens": total_input_tokens + total_output_tokens
+                        }
+                    }
 
             except Exception as e:
                 logger.error(f"Chat failed: {e}")
-                return f"Error processing query: {str(e)}"
+                return {
+                    "answer": f"Error processing query: {str(e)}",
+                    "token_usage": {
+                        "input_tokens": total_input_tokens,
+                        "output_tokens": total_output_tokens,
+                        "total_tokens": total_input_tokens + total_output_tokens
+                    }
+                }
 
         # Fallback if max iterations reached
         logger.warning(f"Max iterations reached for query: {query}")
-        return "Query processing exceeded maximum iterations. Please try a simpler question."
+        return {
+            "answer": "Query processing exceeded maximum iterations. Please try a simpler question.",
+            "token_usage": {
+                "input_tokens": total_input_tokens,
+                "output_tokens": total_output_tokens,
+                "total_tokens": total_input_tokens + total_output_tokens
+            }
+        }
 
     def _update_history(self, question: str, answer: str):
         """Update conversation history."""