diff --git a/camel/agents/chat_agent.py b/camel/agents/chat_agent.py
index 4aeed31110..5b0fca62f5 100644
--- a/camel/agents/chat_agent.py
+++ b/camel/agents/chat_agent.py
@@ -20,7 +20,6 @@
import hashlib
import inspect
import json
-import math
import os
import random
import re
@@ -68,6 +67,7 @@
from camel.memories import (
AgentMemory,
ChatHistoryMemory,
+ ContextRecord,
MemoryRecord,
ScoreBasedContextCreator,
)
@@ -102,6 +102,19 @@
from camel.utils.commons import dependencies_required
from camel.utils.context_utils import ContextUtility
+TOKEN_LIMIT_ERROR_MARKERS = (
+ "context_length_exceeded",
+ "prompt is too long",
+ "exceeded your current quota",
+ "tokens must be reduced",
+ "context length",
+ "token count",
+ "context limit",
+)
+
+SUMMARY_MAX_DEPTH = 3
+SUMMARY_PROGRESS_RECORD_THRESHOLD = 2
+
if TYPE_CHECKING:
from camel.terminators import ResponseTerminator
@@ -354,9 +367,9 @@ class ChatAgent(BaseAgent):
message_window_size (int, optional): The maximum number of previous
messages to include in the context window. If `None`, no windowing
is performed. (default: :obj:`None`)
- token_limit (int, optional): The maximum number of tokens in a context.
- The context will be automatically pruned to fulfill the limitation.
- If `None`, it will be set according to the backend model.
+        summarize_threshold (int, optional): The percentage of the context
+            window at which summarization is triggered. If `None`,
+            summarization is only triggered when the context window is full.
(default: :obj:`None`)
output_language (str, optional): The language to be output by the
agent. (default: :obj:`None`)
@@ -436,6 +449,7 @@ def __init__(
] = None,
memory: Optional[AgentMemory] = None,
message_window_size: Optional[int] = None,
+ summarize_threshold: Optional[int] = None,
token_limit: Optional[int] = None,
output_language: Optional[str] = None,
tools: Optional[List[Union[FunctionTool, Callable]]] = None,
@@ -473,10 +487,17 @@ def __init__(
# Assign unique ID
self.agent_id = agent_id if agent_id else str(uuid.uuid4())
+ if token_limit is not None:
+ logger.warning(
+ "`token_limit` parameter is deprecated and will be ignored. "
+ "Configure the model's token limit in the backend settings "
+ "instead."
+ )
+
# Set up memory
context_creator = ScoreBasedContextCreator(
self.model_backend.token_counter,
- token_limit or self.model_backend.token_limit,
+ self.model_backend.token_limit,
)
self._memory: AgentMemory = memory or ChatHistoryMemory(
@@ -503,6 +524,16 @@ def __init__(
)
self.init_messages()
+ # Set up summarize threshold with validation
+ if summarize_threshold is not None:
+ if not (0 < summarize_threshold <= 100):
+ raise ValueError(
+ f"summarize_threshold must be between 0 and 100, "
+ f"got {summarize_threshold}"
+ )
+ self.summarize_threshold = summarize_threshold
+ self._reset_summary_state()
+
# Set up role name and role type
self.role_name: str = (
getattr(self.system_message, "role_name", None) or "assistant"
@@ -550,6 +581,9 @@ def __init__(
self._context_utility: Optional[ContextUtility] = None
self._context_summary_agent: Optional["ChatAgent"] = None
self.stream_accumulate = stream_accumulate
+ self._last_tool_call_record: Optional[ToolCallingRecord] = None
+ self._last_tool_call_signature: Optional[str] = None
+ self._last_token_limit_tool_signature: Optional[str] = None
def reset(self):
r"""Resets the :obj:`ChatAgent` to its initial state."""
@@ -761,6 +795,334 @@ def _get_full_tool_schemas(self) -> List[Dict[str, Any]]:
for func_tool in self._internal_tools.values()
]
+ @staticmethod
+ def _is_token_limit_error(error: Exception) -> bool:
+ r"""Return True when the exception message indicates a token limit."""
+ error_message = str(error).lower()
+ return any(
+ marker in error_message for marker in TOKEN_LIMIT_ERROR_MARKERS
+ )
+
+ @staticmethod
+ def _is_tool_related_record(record: MemoryRecord) -> bool:
+ r"""Determine whether the given memory record
+ belongs to a tool call."""
+ if record.role_at_backend in {
+ OpenAIBackendRole.TOOL,
+ OpenAIBackendRole.FUNCTION,
+ }:
+ return True
+
+ if (
+ record.role_at_backend == OpenAIBackendRole.ASSISTANT
+ and isinstance(record.message, FunctionCallingMessage)
+ ):
+ return True
+
+ return False
+
+ def _find_indices_to_remove_for_last_tool_pair(
+ self, recent_records: List[ContextRecord]
+ ) -> List[int]:
+ """Find indices of records that should be removed to clean up the most
+ recent incomplete tool interaction pair.
+
+ This method identifies tool call/result pairs by tool_call_id and
+ returns the exact indices to remove, allowing non-contiguous deletions.
+
+ Logic:
+ - If the last record is a tool result (TOOL/FUNCTION) with a
+ tool_call_id, find the matching assistant call anywhere in history
+ and return both indices.
+ - If the last record is an assistant tool call without a result yet,
+ return just that index.
+ - For normal messages (non tool-related): remove just the last one.
+ - Fallback: If no tool_call_id is available, use heuristic (last 2 if
+ tool-related, otherwise last 1).
+
+ Returns:
+ List[int]: Indices to remove (may be non-contiguous).
+ """
+ if not recent_records:
+ return []
+
+ last_idx = len(recent_records) - 1
+ last_record = recent_records[last_idx].memory_record
+
+ # Case A: Last is an ASSISTANT tool call with no result yet
+ if (
+ last_record.role_at_backend == OpenAIBackendRole.ASSISTANT
+ and isinstance(last_record.message, FunctionCallingMessage)
+ and last_record.message.result is None
+ ):
+ return [last_idx]
+
+ # Case B: Last is TOOL/FUNCTION result, try id-based pairing
+ if last_record.role_at_backend in {
+ OpenAIBackendRole.TOOL,
+ OpenAIBackendRole.FUNCTION,
+ }:
+ tool_id = None
+ if isinstance(last_record.message, FunctionCallingMessage):
+ tool_id = last_record.message.tool_call_id
+
+ if tool_id:
+ for idx in range(len(recent_records) - 2, -1, -1):
+ rec = recent_records[idx].memory_record
+ if rec.role_at_backend != OpenAIBackendRole.ASSISTANT:
+ continue
+
+ # Check if this assistant message contains the tool_call_id
+ matched = False
+
+ # Case 1: FunctionCallingMessage (single tool call)
+ if isinstance(rec.message, FunctionCallingMessage):
+ if rec.message.tool_call_id == tool_id:
+ matched = True
+
+ # Case 2: BaseMessage with multiple tool_calls in meta_dict
+ elif (
+ hasattr(rec.message, "meta_dict")
+ and rec.message.meta_dict
+ ):
+ tool_calls_list = rec.message.meta_dict.get(
+ "tool_calls", []
+ )
+ if isinstance(tool_calls_list, list):
+ for tc in tool_calls_list:
+ if (
+ isinstance(tc, dict)
+ and tc.get("id") == tool_id
+ ):
+ matched = True
+ break
+
+ if matched:
+ # Return both assistant call and tool result indices
+ return [idx, last_idx]
+
+ # Fallback: no tool_call_id, use heuristic
+ if self._is_tool_related_record(last_record):
+ # Remove last 2 (assume they are paired)
+ return [last_idx - 1, last_idx] if last_idx > 0 else [last_idx]
+ else:
+ return [last_idx]
+
+ # Default: non tool-related tail => remove last one
+ return [last_idx]
+
+ @staticmethod
+ def _serialize_tool_args(args: Dict[str, Any]) -> str:
+ try:
+ return json.dumps(args, ensure_ascii=False, sort_keys=True)
+ except TypeError:
+ return str(args)
+
+ @classmethod
+ def _build_tool_signature(
+ cls, func_name: str, args: Dict[str, Any]
+ ) -> str:
+ args_repr = cls._serialize_tool_args(args)
+ return f"{func_name}:{args_repr}"
+
+ def _describe_tool_call(
+ self, record: Optional[ToolCallingRecord]
+ ) -> Optional[str]:
+ if record is None:
+ return None
+ args_repr = self._serialize_tool_args(record.args)
+ return f"Tool `{record.tool_name}` invoked with arguments {args_repr}."
+
+ def _format_tool_limit_notice(self) -> Optional[str]:
+ description = self._describe_tool_call(self._last_tool_call_record)
+ if description is None:
+ return None
+ return "[Tool Call Causing Token Limit]\n" f"{description}"
+
+ def _reset_summary_state(self) -> None:
+ self._summary_state: Dict[str, Any] = {
+ "depth": 0,
+ "last_summary_log_length": 0,
+ "last_summary_user_signature": None,
+ "record_count_since_summary": 0,
+ "user_count_since_summary": 0,
+ "summary_token_count": 0, # Total tokens in summary messages
+ "last_summary_threshold": 0, # Last calculated threshold
+ "is_summary_message": set(), # Set of message that are summaries
+ }
+
+ def _calculate_next_summary_threshold(self) -> int:
+ r"""Calculate the next token threshold that should trigger
+ summarization.
+
+ The threshold calculation follows a progressive strategy:
+        - First time: token_limit * (summarize_threshold / 100)
+        - Subsequent times: (token_limit - summary_tokens) *
+          (summarize_threshold / 100) + summary_tokens
+
+ This ensures that as summaries accumulate, the threshold adapts
+ to maintain a reasonable balance between context and summaries.
+
+ Returns:
+ int: The token count threshold for next summarization.
+ """
+ if self.summarize_threshold is None:
+ return 0
+
+ token_limit = self.model_backend.token_limit
+ summary_token_count = self._summary_state.get("summary_token_count", 0)
+
+ # First summarization: use the percentage threshold
+ if summary_token_count == 0:
+ threshold = int(token_limit * self.summarize_threshold / 100)
+ else:
+ # Subsequent summarizations: adaptive threshold
+ threshold = int(
+ (token_limit - summary_token_count)
+ * self.summarize_threshold
+ / 100
+ + summary_token_count
+ )
+
+ return threshold
+
+ def _update_memory_with_summary(self, summary: Dict[str, Any]) -> None:
+ r"""Update memory with summary result.
+
+ This method handles memory clearing and restoration of summaries based
+ on whether it's a progressive or full compression.
+ """
+ if not summary or self.summarize_threshold is None:
+ return
+
+ summary_with_prefix = summary.get("summary_with_prefix")
+ if not summary_with_prefix:
+ return
+ summary_status = summary.get("status", "")
+ include_summaries = summary.get("include_summaries", False)
+
+ existing_summaries = []
+ if not include_summaries:
+ messages, _ = self.memory.get_context()
+ for msg in messages:
+ content = msg.get('content', '')
+ if isinstance(content, str) and content.startswith(
+ '[CAMEL_SUMMARY]'
+ ):
+ existing_summaries.append(msg)
+
+ # Clear memory
+ self.clear_memory()
+
+ # Restore old summaries (for progressive compression)
+ for old_summary in existing_summaries:
+ content = old_summary.get('content', '')
+ if not isinstance(content, str):
+ content = str(content)
+ summary_msg = BaseMessage.make_assistant_message(
+ role_name="Assistant", content=content
+ )
+ self.update_memory(summary_msg, OpenAIBackendRole.ASSISTANT)
+
+ # Add new summary
+ new_summary_msg = BaseMessage.make_assistant_message(
+ role_name="Assistant",
+ content=summary_with_prefix + " " + summary_status,
+ )
+ self.update_memory(new_summary_msg, OpenAIBackendRole.ASSISTANT)
+
+ # Update token count
+ try:
+ summary_tokens = (
+ self.model_backend.token_counter.count_tokens_from_messages(
+ [{"role": "assistant", "content": summary_with_prefix}]
+ )
+ )
+
+ if include_summaries: # Full compression - reset count
+ self._summary_state["summary_token_count"] = summary_tokens
+ logger.info(
+ f"Full compression: Summary with {summary_tokens} tokens. "
+ f"Total summary tokens reset to: {summary_tokens}"
+ )
+ else: # Progressive compression - accumulate
+ self._summary_state["summary_token_count"] += summary_tokens
+ logger.info(
+ f"Progressive compression: New summary "
+ f"with {summary_tokens} tokens. "
+ f"Total summary tokens: "
+ f"{self._summary_state['summary_token_count']}"
+ )
+ except Exception as e:
+ logger.warning(f"Failed to count summary tokens: {e}")
+
+ def _register_record_addition(self, role: OpenAIBackendRole) -> None:
+ state = getattr(self, "_summary_state", None)
+ if not state:
+ return
+
+ state["record_count_since_summary"] = (
+ state.get("record_count_since_summary", 0) + 1
+ )
+
+ if role == OpenAIBackendRole.USER:
+ state["user_count_since_summary"] = (
+ state.get("user_count_since_summary", 0) + 1
+ )
+
+ def _can_attempt_summary(self) -> Tuple[bool, Optional[str]]:
+ state = getattr(self, "_summary_state", None)
+ if not state:
+ return True, None
+
+ depth = state.get("depth", 0)
+
+ if depth >= SUMMARY_MAX_DEPTH:
+ return (
+ False,
+ "Maximum number of summaries reached for this session.",
+ )
+
+ if depth == 0:
+ return True, None
+
+ if (
+ state.get("record_count_since_summary", 0)
+ >= SUMMARY_PROGRESS_RECORD_THRESHOLD
+ ):
+ return True, None
+
+ return (
+ False,
+ "Context was summarized recently but the conversation did not add "
+ "enough new exchanges to summarize again.",
+ )
+
+ def _on_summary(self, records_before_summary: List[ContextRecord]) -> None:
+ state = getattr(self, "_summary_state", None)
+ if state is None:
+ return
+
+ state["depth"] = state.get("depth", 0) + 1
+ state["last_summary_log_length"] = len(records_before_summary)
+ state["record_count_since_summary"] = 0
+ state["user_count_since_summary"] = 0
+ state["last_summary_user_signature"] = (
+ self._extract_last_user_signature(records_before_summary)
+ )
+
+ @staticmethod
+ def _extract_last_user_signature(
+ records: List[ContextRecord],
+ ) -> Optional[str]:
+ for context_record in reversed(records):
+ memory_record = context_record.memory_record
+ if memory_record.role_at_backend == OpenAIBackendRole.USER:
+ content = getattr(memory_record.message, "content", None)
+ if isinstance(content, str):
+ return content
+ return None
+ return None
+
def _get_external_tool_names(self) -> Set[str]:
r"""Returns a set of external tool names."""
return set(self._external_tool_schemas.keys())
@@ -822,16 +1184,6 @@ def update_memory(
) -> None:
r"""Updates the agent memory with a new message.
- If the single *message* exceeds the model's context window, it will
- be **automatically split into multiple smaller chunks** before being
- written into memory. This prevents later failures in
- `ScoreBasedContextCreator` where an over-sized message cannot fit
- into the available token budget at all.
-
- This slicing logic handles both regular text messages (in the
- `content` field) and long tool call results (in the `result` field of
- a `FunctionCallingMessage`).
-
Args:
message (BaseMessage): The new message to add to the stored
messages.
@@ -841,151 +1193,16 @@ def update_memory(
(default: :obj:`None`)
(default: obj:`None`)
"""
-
- # 1. Helper to write a record to memory
- def _write_single_record(
- message: BaseMessage, role: OpenAIBackendRole, timestamp: float
- ):
- self.memory.write_record(
- MemoryRecord(
- message=message,
- role_at_backend=role,
- timestamp=timestamp,
- agent_id=self.agent_id,
- )
- )
-
- base_ts = (
- timestamp
+ record = MemoryRecord(
+ message=message,
+ role_at_backend=role,
+ timestamp=timestamp
if timestamp is not None
- else time.time_ns() / 1_000_000_000
- )
-
- # 2. Get token handling utilities, fallback if unavailable
- try:
- context_creator = self.memory.get_context_creator()
- token_counter = context_creator.token_counter
- token_limit = context_creator.token_limit
- except AttributeError:
- _write_single_record(message, role, base_ts)
- return
-
- # 3. Check if slicing is necessary
- try:
- current_tokens = token_counter.count_tokens_from_messages(
- [message.to_openai_message(role)]
- )
-
- _, ctx_tokens = self.memory.get_context()
-
- remaining_budget = max(0, token_limit - ctx_tokens)
-
- if current_tokens <= remaining_budget:
- _write_single_record(message, role, base_ts)
- return
- except Exception as e:
- logger.warning(
- f"Token calculation failed before chunking, "
- f"writing message as-is. Error: {e}"
- )
- _write_single_record(message, role, base_ts)
- return
-
- # 4. Perform slicing
- logger.warning(
- f"Message with {current_tokens} tokens exceeds remaining budget "
- f"of {remaining_budget}. Slicing into smaller chunks."
+            else time.time_ns() / 1_000_000_000,  # ns resolution, in seconds
+ agent_id=self.agent_id,
)
-
- text_to_chunk: Optional[str] = None
- is_function_result = False
-
- if isinstance(message, FunctionCallingMessage) and isinstance(
- message.result, str
- ):
- text_to_chunk = message.result
- is_function_result = True
- elif isinstance(message.content, str):
- text_to_chunk = message.content
-
- if not text_to_chunk or not text_to_chunk.strip():
- _write_single_record(message, role, base_ts)
- return
- # Encode the entire text to get a list of all token IDs
- try:
- all_token_ids = token_counter.encode(text_to_chunk)
- except Exception as e:
- logger.error(f"Failed to encode text for chunking: {e}")
- _write_single_record(message, role, base_ts) # Fallback
- return
-
- if not all_token_ids:
- _write_single_record(message, role, base_ts) # Nothing to chunk
- return
-
- # 1. Base chunk size: one-tenth of the smaller of (a) total token
- # limit and (b) current remaining budget. This prevents us from
- # creating chunks that are guaranteed to overflow the
- # immediate context window.
- base_chunk_size = max(1, remaining_budget) // 10
-
- # 2. Each chunk gets a textual prefix such as:
- # "[chunk 3/12 of a long message]\n"
- # The prefix itself consumes tokens, so if we do not subtract its
- # length the *total* tokens of the outgoing message (prefix + body)
- # can exceed the intended bound. We estimate the prefix length
- # with a representative example that is safely long enough for the
- # vast majority of cases (three-digit indices).
- sample_prefix = "[chunk 1/1000 of a long message]\n"
- prefix_token_len = len(token_counter.encode(sample_prefix))
-
- # 3. The real capacity for the message body is therefore the base
- # chunk size minus the prefix length. Fallback to at least one
- # token to avoid zero or negative sizes.
- chunk_body_limit = max(1, base_chunk_size - prefix_token_len)
-
- # 4. Calculate how many chunks we will need with this body size.
- num_chunks = math.ceil(len(all_token_ids) / chunk_body_limit)
- group_id = str(uuid.uuid4())
-
- for i in range(num_chunks):
- start_idx = i * chunk_body_limit
- end_idx = start_idx + chunk_body_limit
- chunk_token_ids = all_token_ids[start_idx:end_idx]
-
- chunk_body = token_counter.decode(chunk_token_ids)
-
- prefix = f"[chunk {i + 1}/{num_chunks} of a long message]\n"
- new_body = prefix + chunk_body
-
- if is_function_result and isinstance(
- message, FunctionCallingMessage
- ):
- new_msg: BaseMessage = FunctionCallingMessage(
- role_name=message.role_name,
- role_type=message.role_type,
- meta_dict=message.meta_dict,
- content=message.content,
- func_name=message.func_name,
- args=message.args,
- result=new_body,
- tool_call_id=message.tool_call_id,
- )
- else:
- new_msg = message.create_new_instance(new_body)
-
- meta = (new_msg.meta_dict or {}).copy()
- meta.update(
- {
- "chunk_idx": i + 1,
- "chunk_total": num_chunks,
- "chunk_group_id": group_id,
- }
- )
- new_msg.meta_dict = meta
-
- # Increment timestamp slightly to maintain order
- _write_single_record(new_msg, role, base_ts + i * 1e-6)
+ self.memory.write_record(record)
+ self._register_record_addition(role)
def load_memory(self, memory: AgentMemory) -> None:
r"""Load the provided memory into the agent.
@@ -1070,6 +1287,7 @@ def summarize(
summary_prompt: Optional[str] = None,
response_format: Optional[Type[BaseModel]] = None,
working_directory: Optional[Union[str, Path]] = None,
+ include_summaries: bool = False,
) -> Dict[str, Any]:
r"""Summarize the agent's current conversation context and persist it
to a markdown file.
@@ -1089,6 +1307,11 @@ def summarize(
defining the expected structure of the response. If provided,
the summary will be generated as structured output and included
in the result.
+ include_summaries (bool): Whether to include previously generated
+ summaries in the content to be summarized. If False (default),
+ only non-summary messages will be summarized. If True, all
+ messages including previous summaries will be summarized
+ (full compression). (default: :obj:`False`)
working_directory (Optional[str|Path]): Optional directory to save
the markdown summary file. If provided, overrides the default
directory used by ContextUtility.
@@ -1142,6 +1365,12 @@ def summarize(
role = message.get('role', 'unknown')
content = message.get('content', '')
+ # Skip summary messages if include_summaries is False
+ if not include_summaries and isinstance(content, str):
+ # Check if this is a summary message by looking for marker
+ if content.startswith('[CAMEL_SUMMARY]'):
+ continue
+
# Handle tool call messages (assistant calling tools)
tool_calls = message.get('tool_calls')
if tool_calls and isinstance(tool_calls, (list, tuple)):
@@ -1321,6 +1550,13 @@ def summarize(
}
result.update(result_dict)
+
+ # Add prefix to summary content in result
+ if self.summarize_threshold is not None:
+ summary_with_prefix = f"[CAMEL_SUMMARY] {summary_content}"
+ result["summary_with_prefix"] = summary_with_prefix
+ result["include_summaries"] = include_summaries
+
logger.info("Conversation summary saved to %s", file_path)
return result
@@ -1336,6 +1572,7 @@ async def asummarize(
summary_prompt: Optional[str] = None,
response_format: Optional[Type[BaseModel]] = None,
working_directory: Optional[Union[str, Path]] = None,
+ include_summaries: bool = False,
) -> Dict[str, Any]:
r"""Asynchronously summarize the agent's current conversation context
and persist it to a markdown file.
@@ -1358,6 +1595,11 @@ async def asummarize(
working_directory (Optional[str|Path]): Optional directory to save
the markdown summary file. If provided, overrides the default
directory used by ContextUtility.
+ include_summaries (bool): Whether to include previously generated
+ summaries in the content to be summarized. If False (default),
+ only non-summary messages will be summarized. If True, all
+ messages including previous summaries will be summarized
+ (full compression). (default: :obj:`False`)
Returns:
Dict[str, Any]: A dictionary containing the summary text, file
@@ -1398,6 +1640,12 @@ async def asummarize(
role = message.get('role', 'unknown')
content = message.get('content', '')
+ # Skip summary messages if include_summaries is False
+ if not include_summaries and isinstance(content, str):
+ # Check if this is a summary message by looking for marker
+ if content.startswith('[CAMEL_SUMMARY]'):
+ continue
+
# Handle tool call messages (assistant calling tools)
tool_calls = message.get('tool_calls')
if tool_calls and isinstance(tool_calls, (list, tuple)):
@@ -1586,6 +1834,13 @@ async def asummarize(
}
result.update(result_dict)
+
+ # Add prefix to summary content in result
+ if self.summarize_threshold is not None:
+ summary_with_prefix = f"[CAMEL_SUMMARY] {summary_content}"
+ result["summary_with_prefix"] = summary_with_prefix
+ result["include_summaries"] = include_summaries
+
logger.info("Conversation summary saved to %s", file_path)
return result
@@ -1604,7 +1859,14 @@ def clear_memory(self) -> None:
self.memory.clear()
if self.system_message is not None:
- self.update_memory(self.system_message, OpenAIBackendRole.SYSTEM)
+ self.memory.write_record(
+ MemoryRecord(
+ message=self.system_message,
+ role_at_backend=OpenAIBackendRole.SYSTEM,
+ timestamp=time.time_ns() / 1_000_000_000,
+ agent_id=self.agent_id,
+ )
+ )
def _generate_system_message_for_output_language(
self,
@@ -1638,17 +1900,8 @@ def init_messages(self) -> None:
r"""Initializes the stored messages list with the current system
message.
"""
- self.memory.clear()
- # Write system message to memory if provided
- if self.system_message is not None:
- self.memory.write_record(
- MemoryRecord(
- message=self.system_message,
- role_at_backend=OpenAIBackendRole.SYSTEM,
- timestamp=time.time_ns() / 1_000_000_000,
- agent_id=self.agent_id,
- )
- )
+ self._reset_summary_state()
+ self.clear_memory()
def reset_to_original_system_message(self) -> None:
r"""Reset system message to original, removing any appended context.
@@ -2072,22 +2325,136 @@ def _step_impl(
try:
openai_messages, num_tokens = self.memory.get_context()
+ if self.summarize_threshold is not None:
+ threshold = self._calculate_next_summary_threshold()
+ summary_token_count = self._summary_state.get(
+ "summary_token_count", 0
+ )
+ token_limit = self.model_backend.token_limit
+
+ if num_tokens <= token_limit:
+ # Check if summary tokens exceed 60% of limit
+ # If so, perform full compression (including summaries)
+ if summary_token_count > token_limit * 0.6:
+ logger.info(
+ f"Summary tokens ({summary_token_count}) "
+ f"exceed 60% of limit ({token_limit * 0.6}). "
+ f"Performing full compression."
+ )
+ # Summarize everything (including summaries)
+ summary = self.summarize(include_summaries=True)
+ self._update_memory_with_summary(summary)
+ elif num_tokens > threshold:
+ logger.info(
+ f"Token count ({num_tokens}) exceed threshold "
+ f"({threshold}). Triggering summarization."
+ )
+ # Only summarize non-summary content
+ summary = self.summarize(include_summaries=False)
+ self._update_memory_with_summary(summary)
accumulated_context_tokens += num_tokens
except RuntimeError as e:
return self._step_terminate(
e.args[1], tool_call_records, "max_tokens_exceeded"
)
- # Get response from model backend
- response = self._get_model_response(
- openai_messages,
- num_tokens=num_tokens,
- current_iteration=iteration_count,
- response_format=response_format,
- tool_schemas=[]
- if disable_tools
- else self._get_full_tool_schemas(),
- prev_num_openai_messages=prev_num_openai_messages,
- )
+ # Get response from model backend with token limit error handling
+ try:
+ response = self._get_model_response(
+ openai_messages,
+ num_tokens=num_tokens,
+ current_iteration=iteration_count,
+ response_format=response_format,
+ tool_schemas=[]
+ if disable_tools
+ else self._get_full_tool_schemas(),
+ prev_num_openai_messages=prev_num_openai_messages,
+ )
+ except Exception as exc:
+ logger.exception("Model error: %s", exc)
+
+ if self._is_token_limit_error(exc):
+ tool_signature = self._last_tool_call_signature
+ if (
+ tool_signature is not None
+ and tool_signature
+ == self._last_token_limit_tool_signature
+ ):
+ description = self._describe_tool_call(
+ self._last_tool_call_record
+ )
+ repeated_msg = (
+ "Context exceeded again by the same tool call."
+ )
+ if description:
+ repeated_msg += f" {description}"
+ raise RuntimeError(repeated_msg) from exc
+
+ allowed, reason = self._can_attempt_summary()
+ if not allowed:
+ error_message = (
+ "Context limit exceeded and further summarization "
+ "is not possible."
+ )
+ if reason:
+ error_message += f" {reason}"
+ raise RuntimeError(error_message) from exc
+
+ logger.warning(
+ "Token limit exceeded error detected. "
+ "Summarizing context."
+ )
+
+ recent_records: List[ContextRecord]
+ try:
+ recent_records = self.memory.retrieve()
+ except Exception: # pragma: no cover - defensive guard
+ recent_records = []
+
+ indices_to_remove = (
+ self._find_indices_to_remove_for_last_tool_pair(
+ recent_records
+ )
+ )
+ self.memory.remove_records_by_indices(indices_to_remove)
+
+ summary = self.summarize()
+ if isinstance(input_message, BaseMessage):
+ input_message = input_message.content
+
+ tool_notice = self._format_tool_limit_notice()
+                    summary_messages = (
+                        "Because the token limit was exceeded in this "
+                        "request, the conversation has been summarized. "
+                        "The original request follows [Input Message] and "
+                        "the condensed history follows [Context Summary]."
+                        + "\n\n[Input Message]\n\n"
+                        + str(input_message)
+                        + "\n\n[Context Summary]\n\n"
+                        + summary.get("summary", "")
+                        + "\n\n[Summary status]\n\n"
+                        + str(summary.get("status", ""))
+                        + "\n\nPlease continue from the user's original "
+                        "input using the information above."
+                    )
+ if tool_notice:
+ summary_messages += "\n\n" + tool_notice
+ self.clear_memory()
+ summary_messages = BaseMessage.make_assistant_message(
+ role_name="Assistant", content=summary_messages
+ )
+ self.update_memory(
+ summary_messages, OpenAIBackendRole.ASSISTANT
+ )
+ self._on_summary(recent_records)
+
+ self._last_token_limit_tool_signature = tool_signature
+
+ return self._step_impl(input_message, response_format)
+
+ raise
+
prev_num_openai_messages = len(openai_messages)
iteration_count += 1
@@ -2282,6 +2649,7 @@ async def _astep_non_streaming_task(
step_token_usage = self._create_token_usage_tracker()
iteration_count: int = 0
prev_num_openai_messages: int = 0
+
while True:
if self.pause_event is not None and not self.pause_event.is_set():
if isinstance(self.pause_event, asyncio.Event):
@@ -2292,21 +2660,140 @@ async def _astep_non_streaming_task(
await loop.run_in_executor(None, self.pause_event.wait)
try:
openai_messages, num_tokens = self.memory.get_context()
+ if self.summarize_threshold is not None:
+ threshold = self._calculate_next_summary_threshold()
+ summary_token_count = self._summary_state.get(
+ "summary_token_count", 0
+ )
+ token_limit = self.model_backend.token_limit
+
+ if num_tokens <= token_limit:
+ # Check if summary tokens exceed 60% of limit
+ # If so, perform full compression (including summaries)
+ if summary_token_count > token_limit * 0.6:
+ logger.info(
+ f"Summary tokens ({summary_token_count}) "
+ f"exceed 60% of limit ({token_limit * 0.6}). "
+ f"Full compression."
+ )
+ # Summarize everything (including summaries)
+ summary = await self.asummarize(
+ include_summaries=True
+ )
+ self._update_memory_with_summary(summary)
+ elif num_tokens > threshold:
+ logger.info(
+ f"Token count ({num_tokens}) exceed threshold "
+ "({threshold}). Triggering summarization."
+ )
+ # Only summarize non-summary content
+ summary = await self.asummarize(
+ include_summaries=False
+ )
+ self._update_memory_with_summary(summary)
accumulated_context_tokens += num_tokens
except RuntimeError as e:
return self._step_terminate(
e.args[1], tool_call_records, "max_tokens_exceeded"
)
- response = await self._aget_model_response(
- openai_messages,
- num_tokens=num_tokens,
- current_iteration=iteration_count,
- response_format=response_format,
- tool_schemas=[]
- if disable_tools
- else self._get_full_tool_schemas(),
- prev_num_openai_messages=prev_num_openai_messages,
- )
+ # Get response from model backend with token limit error handling
+ try:
+ response = await self._aget_model_response(
+ openai_messages,
+ num_tokens=num_tokens,
+ current_iteration=iteration_count,
+ response_format=response_format,
+ tool_schemas=[]
+ if disable_tools
+ else self._get_full_tool_schemas(),
+ prev_num_openai_messages=prev_num_openai_messages,
+ )
+ except Exception as exc:
+ if self._is_token_limit_error(exc):
+ tool_signature = self._last_tool_call_signature
+ if (
+ tool_signature is not None
+ and tool_signature
+ == self._last_token_limit_tool_signature
+ ):
+ description = self._describe_tool_call(
+ self._last_tool_call_record
+ )
+ repeated_msg = (
+ "Context exceeded again by the same tool call."
+ )
+ if description:
+ repeated_msg += f" {description}"
+ raise RuntimeError(repeated_msg) from exc
+
+ allowed, reason = self._can_attempt_summary()
+ if not allowed:
+ error_message = (
+ "Context limit exceeded and further summarization "
+ "is not possible."
+ )
+ if reason:
+ error_message += f" {reason}"
+ raise RuntimeError(error_message) from exc
+
+ logger.warning(
+ "Token limit exceeded error detected. "
+ "Summarizing context."
+ )
+
+ recent_records: List[ContextRecord]
+ try:
+ recent_records = self.memory.retrieve()
+ except Exception: # pragma: no cover - defensive guard
+ recent_records = []
+
+ indices_to_remove = (
+ self._find_indices_to_remove_for_last_tool_pair(
+ recent_records
+ )
+ )
+ self.memory.remove_records_by_indices(indices_to_remove)
+
+ summary = await self.asummarize()
+ if isinstance(input_message, BaseMessage):
+ input_message = input_message.content
+
+ tool_notice = self._format_tool_limit_notice()
+                    summary_messages = (
+                        "Because the token limit was exceeded in this "
+                        "request, the conversation has been summarized. "
+                        "The original request follows [Input Message] and "
+                        "the condensed history follows [Context Summary]."
+                        + "\n\n[Input Message]\n\n"
+                        + str(input_message)
+                        + "\n\n[Context Summary]\n\n"
+                        + summary.get("summary", "")
+                        + "\n\n[Summary status]\n\n"
+                        + str(summary.get("status", ""))
+                        + "\n\nPlease continue from the user's original "
+                        "input using the information above."
+                    )
+
+ if tool_notice:
+ summary_messages += "\n\n" + tool_notice
+ self.clear_memory()
+ summary_messages = BaseMessage.make_assistant_message(
+ role_name="Assistant", content=summary_messages
+ )
+ self.update_memory(
+ summary_messages, OpenAIBackendRole.ASSISTANT
+ )
+ self._on_summary(recent_records)
+
+ self._last_token_limit_tool_signature = tool_signature
+
+ return await self._astep_non_streaming_task(
+ input_message, response_format
+ )
+
+ raise
+
prev_num_openai_messages = len(openai_messages)
iteration_count += 1
@@ -2479,6 +2966,8 @@ def _get_model_response(
if response:
break
except RateLimitError as e:
+ if self._is_token_limit_error(e):
+ raise
last_error = e
if attempt < self.retry_attempts - 1:
delay = min(self.retry_delay * (2**attempt), 60.0)
@@ -2496,7 +2985,6 @@ def _get_model_response(
except Exception:
logger.error(
f"Model error: {self.model_backend.model_type}",
- exc_info=True,
)
raise
else:
@@ -2543,6 +3031,8 @@ async def _aget_model_response(
if response:
break
except RateLimitError as e:
+ if self._is_token_limit_error(e):
+ raise
last_error = e
if attempt < self.retry_attempts - 1:
delay = min(self.retry_delay * (2**attempt), 60.0)
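
A standalone sketch of the arithmetic behind `_calculate_next_summary_threshold` above, written as a plain function so the progressive behavior is easy to check in isolation. The model limit (128k tokens), the 80% threshold, and the 10k-token summary are illustrative assumptions, not values taken from this patch.

```python
def next_summary_threshold(
    token_limit: int, summarize_threshold: int, summary_token_count: int
) -> int:
    """Token count at which the next summarization should trigger."""
    if summary_token_count == 0:
        # First summarization: a plain percentage of the model's token limit.
        return int(token_limit * summarize_threshold / 100)
    # Later summarizations: apply the percentage to the budget that is not
    # already occupied by summaries, then add the summary tokens back on top.
    return int(
        (token_limit - summary_token_count) * summarize_threshold / 100
        + summary_token_count
    )


# Assumed 128_000-token model with summarize_threshold=80:
print(next_summary_threshold(128_000, 80, 0))       # 102400
# After a 10_000-token summary has been written back into memory:
print(next_summary_threshold(128_000, 80, 10_000))  # 104400
```

As summaries accumulate the trigger point moves up while the head-room left for fresh conversation shrinks, which is why `_step_impl` also falls back to a full compression once summaries alone pass 60% of the limit.
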
diff --git a/camel/memories/agent_memories.py b/camel/memories/agent_memories.py
index 01abee68a1..1789cec469 100644
--- a/camel/memories/agent_memories.py
+++ b/camel/memories/agent_memories.py
@@ -129,6 +129,16 @@ def clean_tool_calls(self) -> None:
# Save the modified records back to storage
self._chat_history_block.storage.save(record_dicts)
+ def pop_records(self, count: int) -> List[MemoryRecord]:
+ r"""Removes the most recent records from chat history memory."""
+ return self._chat_history_block.pop_records(count)
+
+ def remove_records_by_indices(
+ self, indices: List[int]
+ ) -> List[MemoryRecord]:
+ r"""Removes records at specified indices from chat history memory."""
+ return self._chat_history_block.remove_records_by_indices(indices)
+
class VectorDBMemory(AgentMemory):
r"""An agent memory wrapper of :obj:`VectorDBBlock`. This memory queries
@@ -193,6 +203,20 @@ def clear(self) -> None:
r"""Removes all records from the vector database memory."""
self._vectordb_block.clear()
+ def pop_records(self, count: int) -> List[MemoryRecord]:
+ r"""Rolling back is unsupported for vector database memory."""
+ raise NotImplementedError(
+ "VectorDBMemory does not support removing historical records."
+ )
+
+ def remove_records_by_indices(
+ self, indices: List[int]
+ ) -> List[MemoryRecord]:
+ r"""Removing by indices is unsupported for vector database memory."""
+ raise NotImplementedError(
+ "VectorDBMemory does not support removing records by indices."
+ )
+
class LongtermAgentMemory(AgentMemory):
r"""An implementation of the :obj:`AgentMemory` abstract base class for
@@ -277,3 +301,13 @@ def clear(self) -> None:
r"""Removes all records from the memory."""
self.chat_history_block.clear()
self.vector_db_block.clear()
+
+ def pop_records(self, count: int) -> List[MemoryRecord]:
+ r"""Removes recent chat history records while leaving vector memory."""
+ return self.chat_history_block.pop_records(count)
+
+ def remove_records_by_indices(
+ self, indices: List[int]
+ ) -> List[MemoryRecord]:
+ r"""Removes records at specified indices from chat history."""
+ return self.chat_history_block.remove_records_by_indices(indices)
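
The wrappers above either delegate the new removal APIs to the chat history block or, for `VectorDBMemory`, raise `NotImplementedError`. A minimal sketch of how a caller could guard against backends that do not support rollback; `_UnsupportedMemory` is a stand-in invented for the example, not a class from this patch.

```python
from typing import List


class _UnsupportedMemory:
    """Stand-in for a backend without index-based removal (e.g. vector DB)."""

    def remove_records_by_indices(self, indices: List[int]):
        raise NotImplementedError


def remove_if_supported(memory, indices: List[int]):
    """Attempt removal, treating unsupported backends as a no-op."""
    try:
        return memory.remove_records_by_indices(indices)
    except NotImplementedError:
        # VectorDBMemory raises NotImplementedError by design.
        return []


print(remove_if_supported(_UnsupportedMemory(), [1, 2]))  # []
```
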
diff --git a/camel/memories/base.py b/camel/memories/base.py
index f9d4a0ad83..4cf2ccce15 100644
--- a/camel/memories/base.py
+++ b/camel/memories/base.py
@@ -45,6 +45,32 @@ def write_record(self, record: MemoryRecord) -> None:
"""
self.write_records([record])
+ def pop_records(self, count: int) -> List[MemoryRecord]:
+ r"""Removes records from the memory and returns the removed records.
+
+ Args:
+ count (int): Number of records to remove.
+
+ Returns:
+ List[MemoryRecord]: The records that were removed from the memory
+ in their original order.
+ """
+ raise NotImplementedError
+
+ def remove_records_by_indices(
+ self, indices: List[int]
+ ) -> List[MemoryRecord]:
+ r"""Removes records at specified indices from the memory.
+
+ Args:
+ indices (List[int]): List of indices to remove. Indices should be
+ valid positions in the current record list.
+
+ Returns:
+ List[MemoryRecord]: The removed records in their original order.
+ """
+ raise NotImplementedError
+
@abstractmethod
def clear(self) -> None:
r"""Clears all messages from the memory."""
diff --git a/camel/memories/blocks/chat_history_block.py b/camel/memories/blocks/chat_history_block.py
index 8beaa909e4..1f311f131a 100644
--- a/camel/memories/blocks/chat_history_block.py
+++ b/camel/memories/blocks/chat_history_block.py
@@ -167,3 +167,118 @@ def write_records(self, records: List[MemoryRecord]) -> None:
def clear(self) -> None:
r"""Clears all chat messages from the memory."""
self.storage.clear()
+
+ def pop_records(self, count: int) -> List[MemoryRecord]:
+ r"""Removes the most recent records from the memory.
+
+ Args:
+ count (int): Number of records to remove from the end of the
+ conversation history. A value of 0 results in no changes.
+
+ Returns:
+ List[MemoryRecord]: The removed records in chronological order.
+ """
+ if not isinstance(count, int):
+ raise TypeError("`count` must be an integer.")
+ if count < 0:
+ raise ValueError("`count` must be non-negative.")
+ if count == 0:
+ return []
+
+ record_dicts = self.storage.load()
+ if not record_dicts:
+ return []
+
+ # Preserve initial system/developer instruction if present.
+ protected_prefix = (
+ 1
+ if (
+ record_dicts
+ and record_dicts[0]['role_at_backend']
+ in {
+ OpenAIBackendRole.SYSTEM.value,
+ OpenAIBackendRole.DEVELOPER.value,
+ }
+ )
+ else 0
+ )
+
+ removable_count = max(len(record_dicts) - protected_prefix, 0)
+ if removable_count == 0:
+ return []
+
+ pop_count = min(count, removable_count)
+ split_index = len(record_dicts) - pop_count
+
+ popped_dicts = record_dicts[split_index:]
+ remaining_dicts = record_dicts[:split_index]
+
+ self.storage.clear()
+ if remaining_dicts:
+ self.storage.save(remaining_dicts)
+
+ return [MemoryRecord.from_dict(record) for record in popped_dicts]
+
+ def remove_records_by_indices(
+ self, indices: List[int]
+ ) -> List[MemoryRecord]:
+ r"""Removes records at specified indices from the memory.
+
+ Args:
+ indices (List[int]): List of indices to remove. Indices are
+ positions in the current record list (0-based).
+ System/developer messages at index 0 are protected and will
+ not be removed.
+
+ Returns:
+ List[MemoryRecord]: The removed records in their original order.
+ """
+ if not indices:
+ return []
+
+ record_dicts = self.storage.load()
+ if not record_dicts:
+ return []
+
+ # Preserve initial system/developer instruction if present.
+ protected_prefix = (
+ 1
+ if (
+ record_dicts
+ and record_dicts[0]['role_at_backend']
+ in {
+ OpenAIBackendRole.SYSTEM.value,
+ OpenAIBackendRole.DEVELOPER.value,
+ }
+ )
+ else 0
+ )
+
+ # Filter out protected indices and invalid ones
+ valid_indices = sorted(
+ {
+ idx
+ for idx in indices
+ if idx >= protected_prefix and idx < len(record_dicts)
+ }
+ )
+
+ if not valid_indices:
+ return []
+
+ # Extract records to remove (in original order)
+ removed_records = [record_dicts[idx] for idx in valid_indices]
+
+ # Build remaining records by excluding removed indices
+ remaining_dicts = [
+ record
+ for idx, record in enumerate(record_dicts)
+ if idx not in valid_indices
+ ]
+
+ # Save back to storage
+ self.storage.clear()
+ if remaining_dicts:
+ self.storage.save(remaining_dicts)
+
+ return [MemoryRecord.from_dict(record) for record in removed_records]
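
The index filtering performed by `remove_records_by_indices` can be stated independently of storage: a leading system/developer record is protected, duplicate and out-of-range indices are dropped, and removed records come back in their original order. A minimal sketch over plain dictionaries, assuming simplified role strings:

```python
from typing import Any, Dict, List, Tuple

SYSTEM_ROLES = {"system", "developer"}


def filter_removal_indices(
    record_dicts: List[Dict[str, Any]], indices: List[int]
) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
    """Return (removed, remaining) following the same rules as the block."""
    if not indices or not record_dicts:
        return [], list(record_dicts)

    # A leading system/developer record is protected from removal.
    protected_prefix = (
        1 if record_dicts[0]["role_at_backend"] in SYSTEM_ROLES else 0
    )

    # Drop duplicates, protected indices, and out-of-range indices.
    valid_indices = sorted(
        {
            idx
            for idx in indices
            if protected_prefix <= idx < len(record_dicts)
        }
    )

    removed = [record_dicts[idx] for idx in valid_indices]
    remaining = [
        record
        for idx, record in enumerate(record_dicts)
        if idx not in valid_indices
    ]
    return removed, remaining


records = [
    {"role_at_backend": "system", "content": "instructions"},
    {"role_at_backend": "user", "content": "question"},
    {"role_at_backend": "assistant", "content": "answer"},
]
removed, remaining = filter_removal_indices(records, [0, 2, 2, 9])
print([r["content"] for r in removed])    # ['answer']  (index 0 is protected)
print([r["content"] for r in remaining])  # ['instructions', 'question']
```
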
diff --git a/camel/memories/context_creators/score_based.py b/camel/memories/context_creators/score_based.py
index 6d2b9ea349..6733a38f8e 100644
--- a/camel/memories/context_creators/score_based.py
+++ b/camel/memories/context_creators/score_based.py
@@ -11,41 +11,24 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
-from collections import defaultdict
-from typing import Dict, List, Optional, Tuple
-from pydantic import BaseModel
+from typing import List, Optional, Tuple
-from camel.logger import get_logger
from camel.memories.base import BaseContextCreator
from camel.memories.records import ContextRecord
-from camel.messages import FunctionCallingMessage, OpenAIMessage
+from camel.messages import OpenAIMessage
from camel.types.enums import OpenAIBackendRole
from camel.utils import BaseTokenCounter
-logger = get_logger(__name__)
-
-
-class _ContextUnit(BaseModel):
- idx: int
- record: ContextRecord
- num_tokens: int
-
class ScoreBasedContextCreator(BaseContextCreator):
- r"""A default implementation of context creation strategy, which inherits
- from :obj:`BaseContextCreator`.
-
- This class provides a strategy to generate a conversational context from
- a list of chat history records while ensuring the total token count of
- the context does not exceed a specified limit. It prunes messages based
- on their score if the total token count exceeds the limit.
+ r"""A context creation strategy that orders records chronologically.
Args:
- token_counter (BaseTokenCounter): An instance responsible for counting
- tokens in a message.
- token_limit (int): The maximum number of tokens allowed in the
- generated context.
+ token_counter (BaseTokenCounter): Token counter instance used to
+ compute the combined token count of the returned messages.
+ token_limit (int): Retained for API compatibility. No longer used to
+ filter records.
"""
def __init__(
@@ -66,376 +49,34 @@ def create_context(
self,
records: List[ContextRecord],
) -> Tuple[List[OpenAIMessage], int]:
- r"""Constructs conversation context from chat history while respecting
- token limits.
-
- Key strategies:
- 1. System message is always prioritized and preserved
- 2. Truncation removes low-score messages first
- 3. Final output maintains chronological order and in history memory,
- the score of each message decreases according to keep_rate. The
- newer the message, the higher the score.
- 4. Tool calls and their responses are kept together to maintain
- API compatibility
-
- Args:
- records (List[ContextRecord]): List of context records with scores
- and timestamps.
-
- Returns:
- Tuple[List[OpenAIMessage], int]:
- - Ordered list of OpenAI messages
- - Total token count of the final context
-
- Raises:
- RuntimeError: If system message alone exceeds token limit
- """
- # ======================
- # 1. System Message Handling
- # ======================
- system_unit, regular_units = self._extract_system_message(records)
- system_tokens = system_unit.num_tokens if system_unit else 0
+ """Returns messages sorted by timestamp and their total token count."""
- # Check early if system message alone exceeds token limit
- if system_tokens > self.token_limit:
- raise RuntimeError(
- f"System message alone exceeds token limit"
- f": {system_tokens} > {self.token_limit}",
- system_tokens,
- )
+ system_record: Optional[ContextRecord] = None
+ remaining_records: List[ContextRecord] = []
- # ======================
- # 2. Deduplication & Initial Processing
- # ======================
- seen_uuids = set()
- if system_unit:
- seen_uuids.add(system_unit.record.memory_record.uuid)
-
- # Process non-system messages with deduplication
- for idx, record in enumerate(records):
+ for record in records:
if (
- record.memory_record.role_at_backend
+ system_record is None
+ and record.memory_record.role_at_backend
== OpenAIBackendRole.SYSTEM
):
+ system_record = record
continue
- if record.memory_record.uuid in seen_uuids:
- continue
- seen_uuids.add(record.memory_record.uuid)
-
- token_count = self.token_counter.count_tokens_from_messages(
- [record.memory_record.to_openai_message()]
- )
- regular_units.append(
- _ContextUnit(
- idx=idx,
- record=record,
- num_tokens=token_count,
- )
- )
-
- # ======================
- # 3. Tool Call Relationship Mapping
- # ======================
- tool_call_groups = self._group_tool_calls_and_responses(regular_units)
-
- # ======================
- # 4. Token Calculation
- # ======================
- total_tokens = system_tokens + sum(u.num_tokens for u in regular_units)
-
- # ======================
- # 5. Early Return if Within Limit
- # ======================
- if total_tokens <= self.token_limit:
- sorted_units = sorted(
- regular_units, key=self._conversation_sort_key
- )
- return self._assemble_output(sorted_units, system_unit)
-
- # ======================
- # 6. Truncation Logic with Tool Call Awareness
- # ======================
- remaining_units = self._truncate_with_tool_call_awareness(
- regular_units, tool_call_groups, system_tokens
- )
-
- # Log only after truncation is actually performed so that both
- # the original and the final token counts are visible.
- tokens_after = system_tokens + sum(
- u.num_tokens for u in remaining_units
- )
- logger.warning(
- "Context truncation performed: "
- f"before={total_tokens}, after={tokens_after}, "
- f"limit={self.token_limit}"
- )
-
- # ======================
- # 7. Output Assembly
- # ======================
-
- # In case system message is the only message in memory when sorted
- # units are empty, raise an error
- if system_unit and len(remaining_units) == 0 and len(records) > 1:
- raise RuntimeError(
- "System message and current message exceeds token limit ",
- total_tokens,
- )
-
- # Sort remaining units chronologically
- final_units = sorted(remaining_units, key=self._conversation_sort_key)
- return self._assemble_output(final_units, system_unit)
-
- def _group_tool_calls_and_responses(
- self, units: List[_ContextUnit]
- ) -> Dict[str, List[_ContextUnit]]:
- r"""Groups tool calls with their corresponding responses based on
- `tool_call_id`.
-
- This improved logic robustly gathers all messages (assistant requests
- and tool responses, including chunks) that share a `tool_call_id`.
-
- Args:
- units (List[_ContextUnit]): List of context units to analyze.
-
- Returns:
- Dict[str, List[_ContextUnit]]: Mapping from `tool_call_id` to a
- list of related units.
- """
- tool_call_groups: Dict[str, List[_ContextUnit]] = defaultdict(list)
-
- for unit in units:
- # FunctionCallingMessage stores tool_call_id.
- message = unit.record.memory_record.message
- tool_call_id = getattr(message, 'tool_call_id', None)
-
- if tool_call_id:
- tool_call_groups[tool_call_id].append(unit)
-
- # Filter out empty or incomplete groups if necessary,
- # though defaultdict and getattr handle this gracefully.
- return dict(tool_call_groups)
-
- def _truncate_with_tool_call_awareness(
- self,
- regular_units: List[_ContextUnit],
- tool_call_groups: Dict[str, List[_ContextUnit]],
- system_tokens: int,
- ) -> List[_ContextUnit]:
- r"""Truncates messages while preserving tool call-response pairs.
- This method implements a more sophisticated truncation strategy:
- 1. It treats tool call groups (request + responses) and standalone
- messages as individual items to be included.
- 2. It sorts all items by score and greedily adds them to the context.
- 3. **Partial Truncation**: If a complete tool group is too large to
- fit,it attempts to add the request message and as many of the most
- recent response chunks as the token budget allows.
-
- Args:
- regular_units (List[_ContextUnit]): All regular message units.
- tool_call_groups (Dict[str, List[_ContextUnit]]): Grouped tool
- calls.
- system_tokens (int): Tokens used by the system message.
-
- Returns:
- List[_ContextUnit]: A list of units that fit within the token
- limit.
- """
-
- # Create a set for quick lookup of units belonging to any tool call
- tool_call_unit_ids = {
- unit.record.memory_record.uuid
- for group in tool_call_groups.values()
- for unit in group
- }
-
- # Separate standalone units from tool call groups
- standalone_units = [
- u
- for u in regular_units
- if u.record.memory_record.uuid not in tool_call_unit_ids
- ]
-
- # Prepare all items (standalone units and groups) for sorting
- all_potential_items: List[Dict] = []
- for unit in standalone_units:
- all_potential_items.append(
- {
- "type": "standalone",
- "score": unit.record.score,
- "timestamp": unit.record.timestamp,
- "tokens": unit.num_tokens,
- "item": unit,
- }
- )
- for group in tool_call_groups.values():
- all_potential_items.append(
- {
- "type": "group",
- "score": max(u.record.score for u in group),
- "timestamp": max(u.record.timestamp for u in group),
- "tokens": sum(u.num_tokens for u in group),
- "item": group,
- }
- )
-
- # Sort all potential items by score (high to low), then timestamp
- all_potential_items.sort(key=lambda x: (-x["score"], -x["timestamp"]))
-
- remaining_units: List[_ContextUnit] = []
- current_tokens = system_tokens
-
- for item_dict in all_potential_items:
- item_type = item_dict["type"]
- item = item_dict["item"]
- item_tokens = item_dict["tokens"]
-
- if current_tokens + item_tokens <= self.token_limit:
- # The whole item (standalone or group) fits, so add it
- if item_type == "standalone":
- remaining_units.append(item)
- else: # item_type == "group"
- remaining_units.extend(item)
- current_tokens += item_tokens
-
- elif item_type == "group":
- # The group does not fit completely; try partial inclusion.
- request_unit: Optional[_ContextUnit] = None
- response_units: List[_ContextUnit] = []
-
- for unit in item:
- # Assistant msg with `args` is the request
- if (
- isinstance(
- unit.record.memory_record.message,
- FunctionCallingMessage,
- )
- and unit.record.memory_record.message.args is not None
- ):
- request_unit = unit
- else:
- response_units.append(unit)
-
- # A group must have a request to be considered for inclusion.
- if request_unit is None:
- continue
-
- # Check if we can at least fit the request.
- if (
- current_tokens + request_unit.num_tokens
- <= self.token_limit
- ):
- units_to_add = [request_unit]
- tokens_to_add = request_unit.num_tokens
-
- # Sort responses by timestamp to add newest chunks first
- response_units.sort(
- key=lambda u: u.record.timestamp, reverse=True
- )
+ remaining_records.append(record)
- for resp_unit in response_units:
- if (
- current_tokens
- + tokens_to_add
- + resp_unit.num_tokens
- <= self.token_limit
- ):
- units_to_add.append(resp_unit)
- tokens_to_add += resp_unit.num_tokens
+ remaining_records.sort(key=lambda record: record.timestamp)
- # A request must be followed by at least one response
- if len(units_to_add) > 1:
- remaining_units.extend(units_to_add)
- current_tokens += tokens_to_add
+ messages: List[OpenAIMessage] = []
+ if system_record is not None:
+ messages.append(system_record.memory_record.to_openai_message())
- return remaining_units
-
- def _extract_system_message(
- self, records: List[ContextRecord]
- ) -> Tuple[Optional[_ContextUnit], List[_ContextUnit]]:
- r"""Extracts the system message from records and validates it.
-
- Args:
- records (List[ContextRecord]): List of context records
- representing conversation history.
-
- Returns:
- Tuple[Optional[_ContextUnit], List[_ContextUnit]]: containing:
- - The system message as a `_ContextUnit`, if valid; otherwise,
- `None`.
- - An empty list, serving as the initial container for regular
- messages.
- """
- if not records:
- return None, []
-
- first_record = records[0]
- if (
- first_record.memory_record.role_at_backend
- != OpenAIBackendRole.SYSTEM
- ):
- return None, []
-
- message = first_record.memory_record.to_openai_message()
- tokens = self.token_counter.count_tokens_from_messages([message])
- system_message_unit = _ContextUnit(
- idx=0,
- record=first_record,
- num_tokens=tokens,
+ messages.extend(
+ record.memory_record.to_openai_message()
+ for record in remaining_records
)
- return system_message_unit, []
-
- def _conversation_sort_key(
- self, unit: _ContextUnit
- ) -> Tuple[float, float]:
- r"""Defines the sorting key for assembling the final output.
-
- Sorting priority:
- - Primary: Sort by timestamp in ascending order (chronological order).
- - Secondary: Sort by score in descending order (higher scores first
- when timestamps are equal).
-
- Args:
- unit (_ContextUnit): A `_ContextUnit` representing a conversation
- record.
-
- Returns:
- Tuple[float, float]:
- - Timestamp for chronological sorting.
- - Negative score for descending order sorting.
- """
- return (unit.record.timestamp, -unit.record.score)
-
- def _assemble_output(
- self,
- context_units: List[_ContextUnit],
- system_unit: Optional[_ContextUnit],
- ) -> Tuple[List[OpenAIMessage], int]:
- r"""Assembles final message list with proper ordering and token count.
-
- Args:
- context_units (List[_ContextUnit]): Sorted list of regular message
- units.
- system_unit (Optional[_ContextUnit]): System message unit (if
- present).
-
- Returns:
- Tuple[List[OpenAIMessage], int]: Tuple of (ordered messages, total
- tokens)
- """
- messages = []
- total_tokens = 0
-
- # Add system message first if present
- if system_unit:
- messages.append(
- system_unit.record.memory_record.to_openai_message()
- )
- total_tokens += system_unit.num_tokens
- # Add sorted regular messages
- for unit in context_units:
- messages.append(unit.record.memory_record.to_openai_message())
- total_tokens += unit.num_tokens
+ if not messages:
+ return [], 0
+ total_tokens = self.token_counter.count_tokens_from_messages(messages)
return messages, total_tokens
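
With scoring and truncation removed, `create_context` now only promotes the first SYSTEM record, orders the remaining records by timestamp, and counts tokens over the final list. A standalone sketch of that ordering rule, using simplified dict records rather than `ContextRecord` objects:

```python
from typing import Any, Dict, List, Optional


def order_context(records: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Promote the first system record, then order the rest by timestamp."""
    system_record: Optional[Dict[str, Any]] = None
    remaining: List[Dict[str, Any]] = []
    for record in records:
        if system_record is None and record["role"] == "system":
            system_record = record
            continue
        remaining.append(record)
    remaining.sort(key=lambda record: record["timestamp"])
    return ([system_record] if system_record else []) + remaining


records = [
    {"role": "user", "timestamp": 2.0, "content": "second turn"},
    {"role": "system", "timestamp": 0.0, "content": "rules"},
    {"role": "user", "timestamp": 1.0, "content": "first turn"},
]
print([r["content"] for r in order_context(records)])
# ['rules', 'first turn', 'second turn']
```
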
diff --git a/camel/storages/vectordb_storages/oceanbase.py b/camel/storages/vectordb_storages/oceanbase.py
index 7cfe4b5c16..c250e87b0b 100644
--- a/camel/storages/vectordb_storages/oceanbase.py
+++ b/camel/storages/vectordb_storages/oceanbase.py
@@ -121,10 +121,11 @@ def __init__(
)
# Get the first index parameter
- first_index_param = next(iter(index_params))
- self._client.create_vidx_with_vec_index_param(
- table_name=self.table_name, vidx_param=first_index_param
- )
+ first_index_param = next(iter(index_params), None)
+ if first_index_param is not None:
+ self._client.create_vidx_with_vec_index_param(
+ table_name=self.table_name, vidx_param=first_index_param
+ )
logger.info(f"Created table {self.table_name} with vector index")
else:
diff --git a/test/agents/test_chat_agent.py b/test/agents/test_chat_agent.py
index ffc8a06da5..2e6f7538eb 100644
--- a/test/agents/test_chat_agent.py
+++ b/test/agents/test_chat_agent.py
@@ -560,6 +560,21 @@ def test_chat_agent_step_exceed_token_number(step_call_count=3):
system_message=system_msg,
token_limit=1,
)
+
+ original_get_context = assistant.memory.get_context
+
+ def mock_get_context():
+        original_get_context()
+ # Raise RuntimeError as if context size exceeded limit
+ raise RuntimeError(
+ "Context size exceeded",
+ {
+ "status": "error",
+ "message": "The context has exceeded the maximum token limit.",
+ },
+ )
+
+ assistant.memory.get_context = mock_get_context
assistant.model_backend.run = MagicMock(
return_value=model_backend_rsp_base
)
diff --git a/test/memories/test_chat_history_memory.py b/test/memories/test_chat_history_memory.py
index 22daecfe84..3eee644402 100644
--- a/test/memories/test_chat_history_memory.py
+++ b/test/memories/test_chat_history_memory.py
@@ -91,3 +91,69 @@ def test_chat_history_memory(memory: ChatHistoryMemory):
assert output_messages[0] == system_msg.to_openai_system_message()
assert output_messages[1] == user_msg.to_openai_user_message()
assert output_messages[2] == assistant_msg.to_openai_assistant_message()
+
+
+@pytest.mark.parametrize("memory", ["in-memory", "json"], indirect=True)
+def test_chat_history_memory_pop_records(memory: ChatHistoryMemory):
+ system_msg = BaseMessage(
+ "system",
+ role_type=RoleType.DEFAULT,
+ meta_dict=None,
+ content="System instructions",
+ )
+ user_msgs = [
+ BaseMessage(
+ "AI user",
+ role_type=RoleType.USER,
+ meta_dict=None,
+ content=f"Message {idx}",
+ )
+ for idx in range(3)
+ ]
+
+ records = [
+ MemoryRecord(
+ message=system_msg,
+ role_at_backend=OpenAIBackendRole.SYSTEM,
+ timestamp=datetime.now().timestamp(),
+ agent_id="system",
+ ),
+ *[
+ MemoryRecord(
+ message=msg,
+ role_at_backend=OpenAIBackendRole.USER,
+ timestamp=datetime.now().timestamp(),
+ agent_id="user",
+ )
+ for msg in user_msgs
+ ],
+ ]
+
+ memory.write_records(records)
+
+ popped = memory.pop_records(2)
+ assert [record.message.content for record in popped] == [
+ "Message 1",
+ "Message 2",
+ ]
+
+ remaining_messages, _ = memory.get_context()
+ assert [msg['content'] for msg in remaining_messages] == [
+ "System instructions",
+ "Message 0",
+ ]
+
+ # Attempting to pop more than available should leave system message intact.
+ popped = memory.pop_records(5)
+ assert [record.message.content for record in popped] == ["Message 0"]
+
+ remaining_messages, _ = memory.get_context()
+ assert [msg['content'] for msg in remaining_messages] == [
+ "System instructions",
+ ]
+
+ # Zero pop should be a no-op.
+ assert memory.pop_records(0) == []
+
+ with pytest.raises(ValueError):
+ memory.pop_records(-1)
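
A possible companion test for `remove_records_by_indices`, sketched in the same style as `test_chat_history_memory_pop_records` above. It assumes the same parametrized `memory` fixture and imports already present in this module and is not part of the patch itself.

```python
@pytest.mark.parametrize("memory", ["in-memory", "json"], indirect=True)
def test_chat_history_memory_remove_records_by_indices(memory):
    system_msg = BaseMessage(
        "system",
        role_type=RoleType.DEFAULT,
        meta_dict=None,
        content="System instructions",
    )
    user_msgs = [
        BaseMessage(
            "AI user",
            role_type=RoleType.USER,
            meta_dict=None,
            content=f"Message {idx}",
        )
        for idx in range(3)
    ]
    records = [
        MemoryRecord(
            message=system_msg,
            role_at_backend=OpenAIBackendRole.SYSTEM,
            timestamp=datetime.now().timestamp(),
            agent_id="system",
        ),
        *[
            MemoryRecord(
                message=msg,
                role_at_backend=OpenAIBackendRole.USER,
                timestamp=datetime.now().timestamp(),
                agent_id="user",
            )
            for msg in user_msgs
        ],
    ]
    memory.write_records(records)

    # Index 0 holds the system message and is protected; only index 2 goes.
    removed = memory.remove_records_by_indices([0, 2])
    assert [record.message.content for record in removed] == ["Message 1"]

    remaining_messages, _ = memory.get_context()
    assert [msg['content'] for msg in remaining_messages] == [
        "System instructions",
        "Message 0",
        "Message 2",
    ]

    # An empty index list is a no-op.
    assert memory.remove_records_by_indices([]) == []
```
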