
Commit dcad510

bump version to 0.0.5 and implement memory storage classes
1 parent 0176382 commit dcad510

19 files changed, +417 -64 lines changed

build/lib/semantio/agent.py

Lines changed: 86 additions & 23 deletions
```diff
@@ -16,6 +16,7 @@
 from pathlib import Path
 import importlib
 import os
+from .memory import Memory
 
 # Configure logging
 logging.basicConfig(level=logging.INFO)
@@ -48,6 +49,13 @@ class Agent(BaseModel):
     semantic_model: Optional[Any] = Field(None, description="SentenceTransformer model for semantic matching.")
     team: Optional[List['Agent']] = Field(None, description="List of assistants in the team.")
     auto_tool: bool = Field(False, description="Whether to automatically detect and call tools.")
+    memory: Memory = Field(default_factory=Memory)
+    memory_config: Dict = Field(
+        default_factory=lambda: {
+            "max_context_length": 4000,
+            "summarization_threshold": 3000
+        }
+    )
 
     # Allow arbitrary types
     model_config = ConfigDict(arbitrary_types_allowed=True)
@@ -56,6 +64,11 @@ def __init__(self, **kwargs):
         super().__init__(**kwargs)
         # Initialize the model and tools here if needed
         self._initialize_model()
+        # Initialize memory with config
+        self.memory = Memory(
+            max_context_length=self.memory_config.get("max_context_length", 4000),
+            summarization_threshold=self.memory_config.get("summarization_threshold", 3000)
+        )
         # Initialize tools as an empty list if not provided
         if self.tools is None:
             self.tools = []
```
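The two new fields work as a pair: `memory_config` is a plain dict that callers can override, and `__init__` rebuilds `self.memory` from it, falling back to the defaults above for any missing key. A minimal sketch of that translation (the config values here are illustrative, and the import path is assumed from the build/lib layout):

```python
# Sketch: how memory_config values flow into the Memory built in Agent.__init__.
from semantio.memory import Memory

memory_config = {"max_context_length": 2000}  # summarization_threshold left unset

memory = Memory(
    max_context_length=memory_config.get("max_context_length", 4000),            # -> 2000
    summarization_threshold=memory_config.get("summarization_threshold", 3000),  # -> 3000 default
)
print(memory.max_context_length, memory.summarization_threshold)  # 2000 3000
```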
```diff
@@ -218,20 +231,31 @@ def print_response(
         markdown: bool = False,
         team: Optional[List['Agent']] = None,
         **kwargs,
-    ) -> Union[str, Dict]: # Add return type hint
+    ) -> Union[str, Dict]:
         """Print the agent's response to the console and return it."""
+
+        # Store user message if provided
+        if message and isinstance(message, str):
+            self.memory.add_message(role="user", content=message)
 
         if stream:
             # Handle streaming response
             response = ""
             for chunk in self._stream_response(message, markdown=markdown, **kwargs):
-                print(chunk)
+                print(chunk, end="", flush=True)
                 response += chunk
+            # Store agent response
+            if response:
+                self.memory.add_message(role="assistant", content=response)
+            print()  # New line after streaming
             return response
         else:
             # Generate and return the response
             response = self._generate_response(message, markdown=markdown, team=team, **kwargs)
             print(response)  # Print the response to the console
+            # Store agent response
+            if response:
+                self.memory.add_message(role="assistant", content=response)
             return response
```
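`print_response` now brackets every turn with memory writes: the user message is stored before generation, and the assistant reply is stored after it, in both the streaming and non-streaming branches. A short sketch of the transcript this produces, driving `Memory` directly (import path assumed from the build/lib layout):

```python
from semantio.memory import Memory

memory = Memory()
memory.add_message(role="user", content="What is Semantio?")
memory.add_message(role="assistant", content="An SDK for building AI agents.")

# get_context() returns the role-prefixed history that _manage_context() maintains
print(memory.get_context())
# user: What is Semantio?
# assistant: An SDK for building AI agents.
```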
```diff
@@ -294,12 +318,10 @@ def _generate_response(self, message: str, markdown: bool = False, team: Optiona
         # Use the specified team if provided
         if team is not None:
             return self._generate_team_response(message, team, markdown=markdown, **kwargs)
-
         # Initialize tool_outputs as an empty dictionary
         tool_outputs = {}
         responses = []
         tool_calls = []
-
         # Use the LLM to analyze the query and dynamically select tools when auto_tool is enabled
         if self.auto_tool:
             tool_calls = self._analyze_query_and_select_tools(message)
```
```diff
@@ -347,13 +369,17 @@ def _generate_response(self, message: str, markdown: bool = False, team: Optiona
             try:
                 # Prepare the context for the LLM
                 context = {
+                    "conversation_history": self.memory.get_context(self.llm_instance),
                     "tool_outputs": tool_outputs,
                     "rag_context": self.rag.retrieve(message) if self.rag else None,
-                    "knowledge_base_context": self._find_all_relevant_keys(message, self._flatten_data(self.knowledge_base)) if self.knowledge_base else None,
+                    "knowledge_base": self._get_knowledge_context(message) if self.knowledge_base else None,
                 }
-
+                # Build a memory-aware prompt
+                prompt = self._build_memory_prompt(message, context)
+                # Convert MemoryEntry objects to plain dicts (metadata dropped)
+                memory_entries = [{"role": e.role, "content": e.content} for e in self.memory.storage.retrieve()]
                 # Generate a response using the LLM
-                llm_response = self.llm_instance.generate(prompt=message, context=context, **kwargs)
+                llm_response = self.llm_instance.generate(prompt=prompt, context=context, memory=memory_entries, **kwargs)
                 responses.append(f"**Analysis:**\n\n{llm_response}")
             except Exception as e:
                 logger.error(f"Failed to generate LLM response: {e}")
```
```diff
@@ -363,25 +389,30 @@ def _generate_response(self, message: str, markdown: bool = False, team: Optiona
             # Retrieve relevant context using RAG
             rag_context = self.rag.retrieve(message) if self.rag else None
             # Retrieve relevant context from the knowledge base (API result)
-            knowledge_base_context = None
-            if self.knowledge_base:
-                # Flatten the knowledge base
-                flattened_data = self._flatten_data(self.knowledge_base)
-                # Find all relevant key-value pairs in the knowledge base
-                relevant_values = self._find_all_relevant_keys(message, flattened_data)
-                if relevant_values:
-                    knowledge_base_context = ", ".join(relevant_values)
+            # (Flattening and key matching now live in _get_knowledge_context.)
 
             # Combine both contexts (RAG and knowledge base)
             context = {
+                "conversation_history": self.memory.get_context(self.llm_instance),
                 "rag_context": rag_context,
-                "knowledge_base_context": knowledge_base_context,
+                "knowledge_base": self._get_knowledge_context(message),
             }
             # Prepare the prompt with instructions, description, and context
-            prompt = self._build_prompt(message, context)
+            # Build a memory-aware prompt
+            prompt = self._build_memory_prompt(message, context)
+            # Convert MemoryEntry objects to plain dicts (metadata dropped)
+            memory_entries = [{"role": e.role, "content": e.content} for e in self.memory.storage.retrieve()]
 
             # Generate the response using the LLM
-            response = self.llm_instance.generate(prompt=prompt, context=context, **kwargs)
+            response = self.llm_instance.generate(prompt=prompt, context=context, memory=memory_entries, **kwargs)
 
             # Format the response based on the json_output flag
             if self.json_output:
```
```diff
@@ -394,9 +425,37 @@ def _generate_response(self, message: str, markdown: bool = False, team: Optiona
             if markdown:
                 return f"**Response:**\n\n{response}"
             return response
-        # Combine all responses into a single string
         return "\n\n".join(responses)
 
+    # Modified prompt construction with memory integration
+    def _build_memory_prompt(self, user_input: str, context: dict) -> str:
+        """Enhanced prompt builder with memory context."""
+        prompt_parts = []
+
+        if self.description:
+            prompt_parts.append(f"# ROLE\n{self.description}")
+
+        if self.instructions:
+            prompt_parts.append("# INSTRUCTIONS\n" + "\n".join(f"- {i}" for i in self.instructions))
+
+        if context['conversation_history']:
+            prompt_parts.append(f"# CONVERSATION HISTORY\n{context['conversation_history']}")
+
+        if context['knowledge_base']:
+            prompt_parts.append(f"# KNOWLEDGE BASE\n{context['knowledge_base']}")
+
+        prompt_parts.append(f"# USER INPUT\n{user_input}")
+
+        return "\n\n".join(prompt_parts)
+
+    def _get_knowledge_context(self, message: str) -> str:
+        """Retrieve and format knowledge base context."""
+        if not self.knowledge_base:
+            return ""
+
+        flattened = self._flatten_data(self.knowledge_base)
+        relevant = self._find_all_relevant_keys(message, flattened)
+        return "\n".join(f"- {item}" for item in relevant) if relevant else ""
+
     def _generate_team_response(self, message: str, team: List['Agent'], markdown: bool = False, **kwargs) -> str:
         """Generate a response using a team of assistants."""
         responses = []
```
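`_build_memory_prompt` assembles the prompt from labeled sections and skips any that are empty. A standalone re-creation of its output for a hypothetical agent (the description, instructions, and context values below are invented):

```python
# Mirrors _build_memory_prompt's logic outside the Agent class.
context = {
    "conversation_history": "user: hi\nassistant: hello",
    "knowledge_base": "- Semantio: SDK for building AI agents",
}
parts = ["# ROLE\nHelpful travel planner"]
parts.append("# INSTRUCTIONS\n" + "\n".join(f"- {i}" for i in ["Be concise"]))
if context["conversation_history"]:
    parts.append(f"# CONVERSATION HISTORY\n{context['conversation_history']}")
if context["knowledge_base"]:
    parts.append(f"# KNOWLEDGE BASE\n{context['knowledge_base']}")
parts.append("# USER INPUT\nPlan one day in Goa")
print("\n\n".join(parts))
```

The sections always arrive in the same order (role, instructions, history, knowledge, input), which keeps the model's view of a conversation stable across turns.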
```diff
@@ -543,17 +602,21 @@ def cli_app(
         """Run the agent in a CLI app."""
         from rich.prompt import Prompt
 
+        # Print initial message if provided
         if message:
             self.print_response(message=message, **kwargs)
 
         _exit_on = exit_on or ["exit", "quit", "bye"]
         while True:
-            message = Prompt.ask(f"[bold] {self.emoji} {self.user_name} [/bold]")
-            if message in _exit_on:
+            try:
+                message = Prompt.ask(f"[bold] {self.emoji} {self.user_name} [/bold]")
+                if message in _exit_on:
+                    break
+                self.print_response(message=message, **kwargs)
+            except KeyboardInterrupt:
+                print("\n\nSession ended. Goodbye!")
                 break
 
-            self.print_response(message=message, **kwargs)
-
     def _generate_api(self):
         """Generate an API for the agent if api=True."""
         from .api.api_generator import APIGenerator
```

build/lib/semantio/memory.py

Lines changed: 51 additions & 8 deletions
```diff
@@ -1,11 +1,54 @@
-from typing import List, Dict
-
+from .models import MemoryEntry
+from .storage import BaseMemoryStorage, InMemoryStorage, FileStorage
+from typing import List, Dict, Optional
+from .llm.base_llm import BaseLLM
 class Memory:
-    def __init__(self):
-        self.history = []
+    def __init__(
+        self,
+        storage: Optional[BaseMemoryStorage] = None,
+        max_context_length: int = 4000,
+        summarization_threshold: int = 3000
+    ):
+        # Created per instance so separate Memory objects never share a default store
+        self.storage = storage or InMemoryStorage()
+        self.max_context_length = max_context_length
+        self.summarization_threshold = summarization_threshold
+        self._current_context = ""
+
+    def add_message(self, role: str, content: str, metadata: Optional[Dict] = None):
+        entry = MemoryEntry(
+            role=role,
+            content=content,
+            metadata=metadata or {}
+        )
+        self.storage.store(entry)
+        self._manage_context()
+
+    def get_context(self, llm: Optional[BaseLLM] = None) -> str:
+        if len(self._current_context) < self.summarization_threshold:
+            return self._current_context
+
+        # Automatic summarization when context grows too large
+        if llm:
+            return self.summarize(llm)
+        return self._current_context[:self.max_context_length]
+
+    def _manage_context(self):
+        # Include roles in the conversation history
+        full_history = "\n".join([f"{e.role}: {e.content}" for e in self.storage.retrieve()])
+        if len(full_history) > self.max_context_length:
+            self._current_context = full_history[-self.max_context_length:]
+        else:
+            self._current_context = full_history
 
-    def add_message(self, role: str, content: str):
-        self.history.append({"role": role, "content": content})
+    def summarize(self, llm: BaseLLM) -> str:
+        # Include roles in the history for summarization
+        history = "\n".join([f"{e.role}: {e.content}" for e in self.storage.retrieve()])
+        prompt = f"""
+        Summarize this conversation history maintaining key details and references:
+        {history[-self.summarization_threshold:]}
+        """
+        self._current_context = llm.generate(prompt)
+        return self._current_context
 
-    def get_history(self) -> List[Dict]:
-        return self.history
+    def clear(self):
+        self.storage = InMemoryStorage()
+        self._current_context = ""
```
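`Memory` keeps a rolling `_current_context` string and only falls back to LLM summarization once that string crosses `summarization_threshold`. A minimal sketch of the summarization path, with a stub standing in for a real `BaseLLM` (only a `generate(prompt)` method is assumed; import path assumed from the build/lib layout):

```python
from semantio.memory import Memory

class StubLLM:
    def generate(self, prompt: str) -> str:
        return "summary: five padded user messages"

memory = Memory(max_context_length=80, summarization_threshold=40)
for i in range(5):
    memory.add_message(role="user", content=f"message number {i} with some padding")

# _current_context is now past the 40-char threshold, so get_context()
# delegates to summarize() whenever an LLM is supplied.
print(memory.get_context(llm=StubLLM()))  # summary: five padded user messages
```

Without an LLM, `get_context()` instead truncates to `max_context_length`, so callers that never pass one never pay for a summarization call.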

build/lib/semantio/models.py

Lines changed: 9 additions & 0 deletions
```diff
@@ -0,0 +1,9 @@
+from pydantic import BaseModel, Field
+from datetime import datetime
+from typing import Dict
+
+class MemoryEntry(BaseModel):
+    role: str  # "user" or "assistant"
+    content: str
+    timestamp: datetime = Field(default_factory=datetime.now)
+    metadata: Dict = Field(default_factory=dict)
```
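Since the storage backends below serialize entries with pydantic's `.dict()`, a quick round-trip shows what actually lands on disk; the timestamp is stamped automatically (import path assumed):

```python
from semantio.models import MemoryEntry

entry = MemoryEntry(role="user", content="hello")
print(entry.dict())
# {'role': 'user', 'content': 'hello',
#  'timestamp': datetime.datetime(...), 'metadata': {}}
```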

build/lib/semantio/storage/__init__.py

Lines changed: 5 additions & 0 deletions
```diff
@@ -0,0 +1,5 @@
+from .base_storage import BaseMemoryStorage
+from .in_memory_storage import InMemoryStorage
+from .local_storage import FileStorage
+
+__all__ = ['BaseMemoryStorage', 'InMemoryStorage', 'FileStorage']
```

build/lib/semantio/storage/base_storage.py

Lines changed: 12 additions & 0 deletions
```diff
@@ -0,0 +1,12 @@
+from abc import ABC, abstractmethod
+from typing import List, Optional
+from ..models import MemoryEntry
+
+class BaseMemoryStorage(ABC):
+    @abstractmethod
+    def store(self, entry: MemoryEntry):
+        pass
+
+    @abstractmethod
+    def retrieve(self, query: Optional[str] = None, limit: int = 20) -> List[MemoryEntry]:
+        pass
```
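Because the ABC only requires `store` and `retrieve`, alternative backends are easy to slot in. As a sketch (not part of this commit), a fixed-capacity ring buffer; import paths are assumed from the build/lib layout:

```python
from collections import deque
from typing import List, Optional

from semantio.models import MemoryEntry
from semantio.storage import BaseMemoryStorage

class RingBufferStorage(BaseMemoryStorage):
    """Keeps only the newest `capacity` entries; older ones fall off."""

    def __init__(self, capacity: int = 100):
        self.history: deque = deque(maxlen=capacity)

    def store(self, entry: MemoryEntry):
        self.history.append(entry)  # deque evicts the oldest entry when full

    def retrieve(self, query: Optional[str] = None, limit: int = 20) -> List[MemoryEntry]:
        return list(self.history)[-limit:]
```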

build/lib/semantio/storage/in_memory_storage.py

Lines changed: 14 additions & 0 deletions
```diff
@@ -0,0 +1,14 @@
+# hashai/storage/in_memory_storage.py
+from typing import List, Optional
+from ..models import MemoryEntry
+from .base_storage import BaseMemoryStorage
+
+class InMemoryStorage(BaseMemoryStorage):
+    def __init__(self):
+        self.history: List[MemoryEntry] = []
+
+    def store(self, entry: MemoryEntry):
+        self.history.append(entry)
+
+    def retrieve(self, query: Optional[str] = None, limit: int = 10) -> List[MemoryEntry]:
+        return self.history[-limit:]
```

build/lib/semantio/storage/local_storage.py

Lines changed: 29 additions & 0 deletions
```diff
@@ -0,0 +1,29 @@
+import json
+from typing import List, Optional
+from ..models import MemoryEntry
+from .base_storage import BaseMemoryStorage
+
+class FileStorage(BaseMemoryStorage):
+    def __init__(self, file_path: str = "memory.json"):
+        self.file_path = file_path
+        self.history = self._load_from_file()
+
+    def _load_from_file(self) -> List[MemoryEntry]:
+        try:
+            with open(self.file_path, "r") as f:
+                data = json.load(f)
+                return [MemoryEntry(**entry) for entry in data]
+        except (FileNotFoundError, json.JSONDecodeError):
+            return []
+
+    def _save_to_file(self):
+        with open(self.file_path, "w") as f:
+            data = [entry.dict() for entry in self.history]
+            json.dump(data, f, default=str)
+
+    def store(self, entry: MemoryEntry):
+        self.history.append(entry)
+        self._save_to_file()
+
+    def retrieve(self, query: Optional[str] = None, limit: int = 20) -> List[MemoryEntry]:
+        return self.history[-limit:]
```
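`FileStorage` rewrites the whole JSON file on every `store`, so history survives process restarts. A round-trip sketch, assuming a fresh `memory.json` in the working directory (import paths assumed):

```python
from semantio.models import MemoryEntry
from semantio.storage import FileStorage

store = FileStorage(file_path="memory.json")
store.store(MemoryEntry(role="user", content="remember me"))

reloaded = FileStorage(file_path="memory.json")  # __init__ re-reads the file
print([e.content for e in reloaded.retrieve(limit=5)])  # ['remember me']
```

Rewriting the full file per message is O(n) per write, which is fine for a CLI session but worth revisiting if histories grow large.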

Binary file not shown (38.9 KB).

dist/semantio-0.0.5.tar.gz

Binary file not shown (28.7 KB).

semantio.egg-info/PKG-INFO

Lines changed: 1 addition & 1 deletion
```diff
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: semantio
-Version: 0.0.4
+Version: 0.0.5
 Summary: A powerful SDK for building AI agents
 Home-page: https://github.com/Syenah/semantio
 Author: Rakesh
```
