from pathlib import Path
import importlib
import os
+ from .memory import Memory

# Configure logging
logging.basicConfig(level=logging.INFO)
@@ -48,6 +49,13 @@ class Agent(BaseModel):
    semantic_model: Optional[Any] = Field(None, description="SentenceTransformer model for semantic matching.")
    team: Optional[List['Agent']] = Field(None, description="List of assistants in the team.")
    auto_tool: bool = Field(False, description="Whether to automatically detect and call tools.")
+   memory: Memory = Field(default_factory=Memory)
+   memory_config: Dict = Field(
+       default_factory=lambda: {
+           "max_context_length": 4000,
+           "summarization_threshold": 3000
+       }
+   )

    # Allow arbitrary types
    model_config = ConfigDict(arbitrary_types_allowed=True)
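Note: a minimal sketch of how the new fields are meant to be used, assuming Agent is constructed directly with keyword arguments; memory_config and description are fields from this diff, the values are illustrative:

    # Override the default memory budget at construction time.
    agent = Agent(
        description="A concise research assistant.",
        memory_config={
            "max_context_length": 8000,        # larger window than the 4000 default
            "summarization_threshold": 6000,   # summarize history before it overflows
        },
    )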
@@ -56,6 +64,11 @@ def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Initialize the model and tools here if needed
        self._initialize_model()
+       # Initialize memory with config
+       self.memory = Memory(
+           max_context_length=self.memory_config.get("max_context_length", 4000),
+           summarization_threshold=self.memory_config.get("summarization_threshold", 3000)
+       )
        # Initialize tools as an empty list if not provided
        if self.tools is None:
            self.tools = []
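Note: memory.py is not part of this commit. A minimal sketch of the Memory interface implied by the call sites in this diff (add_message, get_context, storage.retrieve, and the two constructor arguments) — an assumption, not the actual implementation:

    from dataclasses import dataclass
    from typing import List, Optional

    @dataclass
    class MemoryEntry:
        role: str                       # "user" or "assistant"
        content: str
        metadata: Optional[dict] = None

    class _ListStorage:
        """In-memory list store; hypothetical name."""
        def __init__(self) -> None:
            self._entries: List[MemoryEntry] = []

        def add(self, entry: MemoryEntry) -> None:
            self._entries.append(entry)

        def retrieve(self) -> List[MemoryEntry]:
            return list(self._entries)

    class Memory:
        def __init__(self, max_context_length: int = 4000, summarization_threshold: int = 3000):
            self.max_context_length = max_context_length
            self.summarization_threshold = summarization_threshold
            self.storage = _ListStorage()

        def add_message(self, role: str, content: str) -> None:
            self.storage.add(MemoryEntry(role=role, content=content))

        def get_context(self, llm=None) -> str:
            # A real implementation would summarize with `llm` once the history
            # exceeds summarization_threshold; this sketch just concatenates.
            return "\n".join(f"{e.role}: {e.content}" for e in self.storage.retrieve())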
@@ -218,20 +231,31 @@ def print_response(
        markdown: bool = False,
        team: Optional[List['Agent']] = None,
        **kwargs,
-   ) -> Union[str, Dict]:  # Add return type hint
+   ) -> Union[str, Dict]:
        """Print the agent's response to the console and return it."""
+
+       # Store user message if provided
+       if message and isinstance(message, str):
+           self.memory.add_message(role="user", content=message)

        if stream:
            # Handle streaming response
            response = ""
            for chunk in self._stream_response(message, markdown=markdown, **kwargs):
-               print(chunk)
+               print(chunk, end="", flush=True)
                response += chunk
+           # Store agent response
+           if response:
+               self.memory.add_message(role="assistant", content=response)
+           print()  # New line after streaming
            return response
        else:
            # Generate and return the response
            response = self._generate_response(message, markdown=markdown, team=team, **kwargs)
            print(response)  # Print the response to the console
+           # Store agent response
+           if response:
+               self.memory.add_message(role="assistant", content=response)
            return response
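Note: with the storage calls above, successive print_response calls accumulate the conversation in agent.memory; a short usage sketch (method and attribute names taken from this diff, example strings illustrative):

    agent.print_response(message="What is LoRA?")
    agent.print_response(message="How much memory does it save?")  # "it" resolves via history

    # Inspect what was recorded:
    for entry in agent.memory.storage.retrieve():
        print(f"{entry.role}: {entry.content[:60]}")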
@@ -294,12 +318,10 @@ def _generate_response(self, message: str, markdown: bool = False, team: Optiona
        # Use the specified team if provided
        if team is not None:
            return self._generate_team_response(message, team, markdown=markdown, **kwargs)
-
        # Initialize tool_outputs as an empty dictionary
        tool_outputs = {}
        responses = []
        tool_calls = []
-
        # Use the LLM to analyze the query and dynamically select tools when auto_tool is enabled
        if self.auto_tool:
            tool_calls = self._analyze_query_and_select_tools(message)
@@ -347,13 +369,17 @@ def _generate_response(self, message: str, markdown: bool = False, team: Optiona
        try:
            # Prepare the context for the LLM
            context = {
+               "conversation_history": self.memory.get_context(self.llm_instance),
                "tool_outputs": tool_outputs,
                "rag_context": self.rag.retrieve(message) if self.rag else None,
-               "knowledge_base_context": self._find_all_relevant_keys(message, self._flatten_data(self.knowledge_base)) if self.knowledge_base else None,
+               "knowledge_base": self._get_knowledge_context(message) if self.knowledge_base else None,
            }
-
+           # Build a memory-aware prompt that folds in the conversation history
+           prompt = self._build_memory_prompt(message, context)
+           # Convert MemoryEntry objects to plain dicts, dropping metadata, for the LLM call
+           memory_entries = [{"role": e.role, "content": e.content} for e in self.memory.storage.retrieve()]
            # Generate a response using the LLM
-           llm_response = self.llm_instance.generate(prompt=message, context=context, **kwargs)
+           llm_response = self.llm_instance.generate(prompt=prompt, context=context, memory=memory_entries, **kwargs)
            responses.append(f"**Analysis:**\n\n{llm_response}")
        except Exception as e:
            logger.error(f"Failed to generate LLM response: {e}")
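Note: the memory= keyword passed to generate() is just a list of role/content dicts; after one exchange it might look like this (contents illustrative):

    memory_entries = [
        {"role": "user", "content": "What is LoRA?"},
        {"role": "assistant", "content": "LoRA is a parameter-efficient fine-tuning method..."},
    ]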
@@ -363,25 +389,30 @@ def _generate_response(self, message: str, markdown: bool = False, team: Optiona
        # Retrieve relevant context using RAG
        rag_context = self.rag.retrieve(message) if self.rag else None
        # Retrieve relevant context from the knowledge base (API result)
-       knowledge_base_context = None
-       if self.knowledge_base:
-           # Flatten the knowledge base
-           flattened_data = self._flatten_data(self.knowledge_base)
-           # Find all relevant key-value pairs in the knowledge base
-           relevant_values = self._find_all_relevant_keys(message, flattened_data)
-           if relevant_values:
-               knowledge_base_context = ", ".join(relevant_values)
+       # knowledge_base_context = None
+       # if self.knowledge_base:
+       #     # Flatten the knowledge base
+       #     flattened_data = self._flatten_data(self.knowledge_base)
+       #     # Find all relevant key-value pairs in the knowledge base
+       #     relevant_values = self._find_all_relevant_keys(message, flattened_data)
+       #     if relevant_values:
+       #         knowledge_base_context = ", ".join(relevant_values)

        # Combine both contexts (RAG and knowledge base)
        context = {
+           "conversation_history": self.memory.get_context(self.llm_instance),
            "rag_context": rag_context,
-           "knowledge_base_context": knowledge_base_context,
+           "knowledge_base": self._get_knowledge_context(message),
        }
        # Prepare the prompt with instructions, description, and context
-       prompt = self._build_prompt(message, context)
+       # Build a memory-aware prompt that folds in the conversation history
+       prompt = self._build_memory_prompt(message, context)
+       # Convert MemoryEntry objects to plain dicts, dropping metadata, for the LLM call
+       memory_entries = [{"role": e.role, "content": e.content} for e in self.memory.storage.retrieve()]

        # Generate the response using the LLM
-       response = self.llm_instance.generate(prompt=prompt, context=context, **kwargs)
+       response = self.llm_instance.generate(prompt=prompt, context=context, memory=memory_entries, **kwargs)
+

        # Format the response based on the json_output flag
        if self.json_output:
@@ -394,9 +425,37 @@ def _generate_response(self, message: str, markdown: bool = False, team: Optiona
        if markdown:
            return f"**Response:**\n\n{response}"
        return response
-       # Combine all responses into a single string
        return "\n\n".join(responses)

+   # Modified prompt construction with memory integration
+   def _build_memory_prompt(self, user_input: str, context: dict) -> str:
+       """Enhanced prompt builder with memory context."""
+       prompt_parts = []
+
+       if self.description:
+           prompt_parts.append(f"# ROLE\n{self.description}")
+
+       if self.instructions:
+           prompt_parts.append("# INSTRUCTIONS\n" + "\n".join(f"- {i}" for i in self.instructions))
+
+       if context['conversation_history']:
+           prompt_parts.append(f"# CONVERSATION HISTORY\n{context['conversation_history']}")
+
+       if context['knowledge_base']:
+           prompt_parts.append(f"# KNOWLEDGE BASE\n{context['knowledge_base']}")
+
+       prompt_parts.append(f"# USER INPUT\n{user_input}")
+
+       return "\n\n".join(prompt_parts)
+
+   def _get_knowledge_context(self, message: str) -> str:
+       """Retrieve and format knowledge base context."""
+       if not self.knowledge_base:
+           return ""
+
+       flattened = self._flatten_data(self.knowledge_base)
+       relevant = self._find_all_relevant_keys(message, flattened)
+       return "\n".join(f"- {item}" for item in relevant) if relevant else ""

    def _generate_team_response(self, message: str, team: List['Agent'], markdown: bool = False, **kwargs) -> str:
        """Generate a response using a team of assistants."""
        responses = []
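Note: given the section builders above, a populated agent would produce a prompt shaped roughly like this (all contents illustrative):

    # ROLE
    A concise research assistant.

    # INSTRUCTIONS
    - Answer briefly.

    # CONVERSATION HISTORY
    user: What is LoRA?
    assistant: LoRA is a parameter-efficient fine-tuning method...

    # KNOWLEDGE BASE
    - lora_rank: 8

    # USER INPUT
    How much memory does it save?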
@@ -543,17 +602,21 @@ def cli_app(
        """Run the agent in a CLI app."""
        from rich.prompt import Prompt

+       # Print initial message if provided
        if message:
            self.print_response(message=message, **kwargs)

        _exit_on = exit_on or ["exit", "quit", "bye"]
        while True:
-           message = Prompt.ask(f"[bold]{self.emoji} {self.user_name}[/bold]")
-           if message in _exit_on:
+           try:
+               message = Prompt.ask(f"[bold]{self.emoji} {self.user_name}[/bold]")
+               if message in _exit_on:
+                   break
+               self.print_response(message=message, **kwargs)
+           except KeyboardInterrupt:
+               print("\n\nSession ended. Goodbye!")
                break

-           self.print_response(message=message, **kwargs)
-
    def _generate_api(self):
        """Generate an API for the agent if api=True."""
        from .api.api_generator import APIGenerator
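Note: an end-to-end sketch of the CLI flow after this commit; description, instructions, and memory_config are fields used in this diff, while the values and entry-point layout are assumptions:

    if __name__ == "__main__":
        agent = Agent(
            description="A concise research assistant.",
            instructions=["Answer briefly.", "Use the knowledge base when relevant."],
            memory_config={"max_context_length": 4000, "summarization_threshold": 3000},
        )
        # Ctrl+C now exits cleanly via the KeyboardInterrupt handler above.
        agent.cli_app(message="Hello!")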