diff --git a/camel/agents/chat_agent.py b/camel/agents/chat_agent.py index 4aeed31110..5b0fca62f5 100644 --- a/camel/agents/chat_agent.py +++ b/camel/agents/chat_agent.py @@ -20,7 +20,6 @@ import hashlib import inspect import json -import math import os import random import re @@ -68,6 +67,7 @@ from camel.memories import ( AgentMemory, ChatHistoryMemory, + ContextRecord, MemoryRecord, ScoreBasedContextCreator, ) @@ -102,6 +102,19 @@ from camel.utils.commons import dependencies_required from camel.utils.context_utils import ContextUtility +TOKEN_LIMIT_ERROR_MARKERS = ( + "context_length_exceeded", + "prompt is too long", + "exceeded your current quota", + "tokens must be reduced", + "context length", + "token count", + "context limit", +) + +SUMMARY_MAX_DEPTH = 3 +SUMMARY_PROGRESS_RECORD_THRESHOLD = 2 + if TYPE_CHECKING: from camel.terminators import ResponseTerminator @@ -354,9 +367,9 @@ class ChatAgent(BaseAgent): message_window_size (int, optional): The maximum number of previous messages to include in the context window. If `None`, no windowing is performed. (default: :obj:`None`) - token_limit (int, optional): The maximum number of tokens in a context. - The context will be automatically pruned to fulfill the limitation. - If `None`, it will be set according to the backend model. + summarize_threshold (int, optional): The percentage of the context + window that triggers summarization. If `None`, will trigger + summarization when the context window is full. (default: :obj:`None`) output_language (str, optional): The language to be output by the agent. (default: :obj:`None`) @@ -436,6 +449,7 @@ def __init__( ] = None, memory: Optional[AgentMemory] = None, message_window_size: Optional[int] = None, + summarize_threshold: Optional[int] = None, token_limit: Optional[int] = None, output_language: Optional[str] = None, tools: Optional[List[Union[FunctionTool, Callable]]] = None, @@ -473,10 +487,17 @@ def __init__( # Assign unique ID self.agent_id = agent_id if agent_id else str(uuid.uuid4()) + if token_limit is not None: + logger.warning( + "`token_limit` parameter is deprecated and will be ignored. " + "Configure the model's token limit in the backend settings " + "instead." + ) + # Set up memory context_creator = ScoreBasedContextCreator( self.model_backend.token_counter, - token_limit or self.model_backend.token_limit, + self.model_backend.token_limit, ) self._memory: AgentMemory = memory or ChatHistoryMemory( @@ -503,6 +524,16 @@ def __init__( ) self.init_messages() + # Set up summarize threshold with validation + if summarize_threshold is not None: + if not (0 < summarize_threshold <= 100): + raise ValueError( + f"summarize_threshold must be between 0 and 100, " + f"got {summarize_threshold}" + ) + self.summarize_threshold = summarize_threshold + self._reset_summary_state() + # Set up role name and role type self.role_name: str = ( getattr(self.system_message, "role_name", None) or "assistant" @@ -550,6 +581,9 @@ def __init__( self._context_utility: Optional[ContextUtility] = None self._context_summary_agent: Optional["ChatAgent"] = None self.stream_accumulate = stream_accumulate + self._last_tool_call_record: Optional[ToolCallingRecord] = None + self._last_tool_call_signature: Optional[str] = None + self._last_token_limit_tool_signature: Optional[str] = None def reset(self): r"""Resets the :obj:`ChatAgent` to its initial state.""" @@ -761,6 +795,334 @@ def _get_full_tool_schemas(self) -> List[Dict[str, Any]]: for func_tool in self._internal_tools.values() ] + @staticmethod + def _is_token_limit_error(error: Exception) -> bool: + r"""Return True when the exception message indicates a token limit.""" + error_message = str(error).lower() + return any( + marker in error_message for marker in TOKEN_LIMIT_ERROR_MARKERS + ) + + @staticmethod + def _is_tool_related_record(record: MemoryRecord) -> bool: + r"""Determine whether the given memory record + belongs to a tool call.""" + if record.role_at_backend in { + OpenAIBackendRole.TOOL, + OpenAIBackendRole.FUNCTION, + }: + return True + + if ( + record.role_at_backend == OpenAIBackendRole.ASSISTANT + and isinstance(record.message, FunctionCallingMessage) + ): + return True + + return False + + def _find_indices_to_remove_for_last_tool_pair( + self, recent_records: List[ContextRecord] + ) -> List[int]: + """Find indices of records that should be removed to clean up the most + recent incomplete tool interaction pair. + + This method identifies tool call/result pairs by tool_call_id and + returns the exact indices to remove, allowing non-contiguous deletions. + + Logic: + - If the last record is a tool result (TOOL/FUNCTION) with a + tool_call_id, find the matching assistant call anywhere in history + and return both indices. + - If the last record is an assistant tool call without a result yet, + return just that index. + - For normal messages (non tool-related): remove just the last one. + - Fallback: If no tool_call_id is available, use heuristic (last 2 if + tool-related, otherwise last 1). + + Returns: + List[int]: Indices to remove (may be non-contiguous). + """ + if not recent_records: + return [] + + last_idx = len(recent_records) - 1 + last_record = recent_records[last_idx].memory_record + + # Case A: Last is an ASSISTANT tool call with no result yet + if ( + last_record.role_at_backend == OpenAIBackendRole.ASSISTANT + and isinstance(last_record.message, FunctionCallingMessage) + and last_record.message.result is None + ): + return [last_idx] + + # Case B: Last is TOOL/FUNCTION result, try id-based pairing + if last_record.role_at_backend in { + OpenAIBackendRole.TOOL, + OpenAIBackendRole.FUNCTION, + }: + tool_id = None + if isinstance(last_record.message, FunctionCallingMessage): + tool_id = last_record.message.tool_call_id + + if tool_id: + for idx in range(len(recent_records) - 2, -1, -1): + rec = recent_records[idx].memory_record + if rec.role_at_backend != OpenAIBackendRole.ASSISTANT: + continue + + # Check if this assistant message contains the tool_call_id + matched = False + + # Case 1: FunctionCallingMessage (single tool call) + if isinstance(rec.message, FunctionCallingMessage): + if rec.message.tool_call_id == tool_id: + matched = True + + # Case 2: BaseMessage with multiple tool_calls in meta_dict + elif ( + hasattr(rec.message, "meta_dict") + and rec.message.meta_dict + ): + tool_calls_list = rec.message.meta_dict.get( + "tool_calls", [] + ) + if isinstance(tool_calls_list, list): + for tc in tool_calls_list: + if ( + isinstance(tc, dict) + and tc.get("id") == tool_id + ): + matched = True + break + + if matched: + # Return both assistant call and tool result indices + return [idx, last_idx] + + # Fallback: no tool_call_id, use heuristic + if self._is_tool_related_record(last_record): + # Remove last 2 (assume they are paired) + return [last_idx - 1, last_idx] if last_idx > 0 else [last_idx] + else: + return [last_idx] + + # Default: non tool-related tail => remove last one + return [last_idx] + + @staticmethod + def _serialize_tool_args(args: Dict[str, Any]) -> str: + try: + return json.dumps(args, ensure_ascii=False, sort_keys=True) + except TypeError: + return str(args) + + @classmethod + def _build_tool_signature( + cls, func_name: str, args: Dict[str, Any] + ) -> str: + args_repr = cls._serialize_tool_args(args) + return f"{func_name}:{args_repr}" + + def _describe_tool_call( + self, record: Optional[ToolCallingRecord] + ) -> Optional[str]: + if record is None: + return None + args_repr = self._serialize_tool_args(record.args) + return f"Tool `{record.tool_name}` invoked with arguments {args_repr}." + + def _format_tool_limit_notice(self) -> Optional[str]: + description = self._describe_tool_call(self._last_tool_call_record) + if description is None: + return None + return "[Tool Call Causing Token Limit]\n" f"{description}" + + def _reset_summary_state(self) -> None: + self._summary_state: Dict[str, Any] = { + "depth": 0, + "last_summary_log_length": 0, + "last_summary_user_signature": None, + "record_count_since_summary": 0, + "user_count_since_summary": 0, + "summary_token_count": 0, # Total tokens in summary messages + "last_summary_threshold": 0, # Last calculated threshold + "is_summary_message": set(), # Set of message that are summaries + } + + def _calculate_next_summary_threshold(self) -> int: + r"""Calculate the next token threshold that should trigger + summarization. + + The threshold calculation follows a progressive strategy: + - First time: token_limit * (summarize_threshold / 100) + - Subsequent times: (limit - summary_token) / 2 + summary_token + + This ensures that as summaries accumulate, the threshold adapts + to maintain a reasonable balance between context and summaries. + + Returns: + int: The token count threshold for next summarization. + """ + if self.summarize_threshold is None: + return 0 + + token_limit = self.model_backend.token_limit + summary_token_count = self._summary_state.get("summary_token_count", 0) + + # First summarization: use the percentage threshold + if summary_token_count == 0: + threshold = int(token_limit * self.summarize_threshold / 100) + else: + # Subsequent summarizations: adaptive threshold + threshold = int( + (token_limit - summary_token_count) + * self.summarize_threshold + / 100 + + summary_token_count + ) + + return threshold + + def _update_memory_with_summary(self, summary: Dict[str, Any]) -> None: + r"""Update memory with summary result. + + This method handles memory clearing and restoration of summaries based + on whether it's a progressive or full compression. + """ + if not summary or self.summarize_threshold is None: + return + + summary_with_prefix = summary.get("summary_with_prefix") + if not summary_with_prefix: + return + summary_status = summary.get("status", "") + include_summaries = summary.get("include_summaries", False) + + existing_summaries = [] + if not include_summaries: + messages, _ = self.memory.get_context() + for msg in messages: + content = msg.get('content', '') + if isinstance(content, str) and content.startswith( + '[CAMEL_SUMMARY]' + ): + existing_summaries.append(msg) + + # Clear memory + self.clear_memory() + + # Restore old summaries (for progressive compression) + for old_summary in existing_summaries: + content = old_summary.get('content', '') + if not isinstance(content, str): + content = str(content) + summary_msg = BaseMessage.make_assistant_message( + role_name="Assistant", content=content + ) + self.update_memory(summary_msg, OpenAIBackendRole.ASSISTANT) + + # Add new summary + new_summary_msg = BaseMessage.make_assistant_message( + role_name="Assistant", + content=summary_with_prefix + " " + summary_status, + ) + self.update_memory(new_summary_msg, OpenAIBackendRole.ASSISTANT) + + # Update token count + try: + summary_tokens = ( + self.model_backend.token_counter.count_tokens_from_messages( + [{"role": "assistant", "content": summary_with_prefix}] + ) + ) + + if include_summaries: # Full compression - reset count + self._summary_state["summary_token_count"] = summary_tokens + logger.info( + f"Full compression: Summary with {summary_tokens} tokens. " + f"Total summary tokens reset to: {summary_tokens}" + ) + else: # Progressive compression - accumulate + self._summary_state["summary_token_count"] += summary_tokens + logger.info( + f"Progressive compression: New summary " + f"with {summary_tokens} tokens. " + f"Total summary tokens: " + f"{self._summary_state['summary_token_count']}" + ) + except Exception as e: + logger.warning(f"Failed to count summary tokens: {e}") + + def _register_record_addition(self, role: OpenAIBackendRole) -> None: + state = getattr(self, "_summary_state", None) + if not state: + return + + state["record_count_since_summary"] = ( + state.get("record_count_since_summary", 0) + 1 + ) + + if role == OpenAIBackendRole.USER: + state["user_count_since_summary"] = ( + state.get("user_count_since_summary", 0) + 1 + ) + + def _can_attempt_summary(self) -> Tuple[bool, Optional[str]]: + state = getattr(self, "_summary_state", None) + if not state: + return True, None + + depth = state.get("depth", 0) + + if depth >= SUMMARY_MAX_DEPTH: + return ( + False, + "Maximum number of summaries reached for this session.", + ) + + if depth == 0: + return True, None + + if ( + state.get("record_count_since_summary", 0) + >= SUMMARY_PROGRESS_RECORD_THRESHOLD + ): + return True, None + + return ( + False, + "Context was summarized recently but the conversation did not add " + "enough new exchanges to summarize again.", + ) + + def _on_summary(self, records_before_summary: List[ContextRecord]) -> None: + state = getattr(self, "_summary_state", None) + if state is None: + return + + state["depth"] = state.get("depth", 0) + 1 + state["last_summary_log_length"] = len(records_before_summary) + state["record_count_since_summary"] = 0 + state["user_count_since_summary"] = 0 + state["last_summary_user_signature"] = ( + self._extract_last_user_signature(records_before_summary) + ) + + @staticmethod + def _extract_last_user_signature( + records: List[ContextRecord], + ) -> Optional[str]: + for context_record in reversed(records): + memory_record = context_record.memory_record + if memory_record.role_at_backend == OpenAIBackendRole.USER: + content = getattr(memory_record.message, "content", None) + if isinstance(content, str): + return content + return None + return None + def _get_external_tool_names(self) -> Set[str]: r"""Returns a set of external tool names.""" return set(self._external_tool_schemas.keys()) @@ -822,16 +1184,6 @@ def update_memory( ) -> None: r"""Updates the agent memory with a new message. - If the single *message* exceeds the model's context window, it will - be **automatically split into multiple smaller chunks** before being - written into memory. This prevents later failures in - `ScoreBasedContextCreator` where an over-sized message cannot fit - into the available token budget at all. - - This slicing logic handles both regular text messages (in the - `content` field) and long tool call results (in the `result` field of - a `FunctionCallingMessage`). - Args: message (BaseMessage): The new message to add to the stored messages. @@ -841,151 +1193,16 @@ def update_memory( (default: :obj:`None`) (default: obj:`None`) """ - - # 1. Helper to write a record to memory - def _write_single_record( - message: BaseMessage, role: OpenAIBackendRole, timestamp: float - ): - self.memory.write_record( - MemoryRecord( - message=message, - role_at_backend=role, - timestamp=timestamp, - agent_id=self.agent_id, - ) - ) - - base_ts = ( - timestamp + record = MemoryRecord( + message=message, + role_at_backend=role, + timestamp=timestamp if timestamp is not None - else time.time_ns() / 1_000_000_000 - ) - - # 2. Get token handling utilities, fallback if unavailable - try: - context_creator = self.memory.get_context_creator() - token_counter = context_creator.token_counter - token_limit = context_creator.token_limit - except AttributeError: - _write_single_record(message, role, base_ts) - return - - # 3. Check if slicing is necessary - try: - current_tokens = token_counter.count_tokens_from_messages( - [message.to_openai_message(role)] - ) - - _, ctx_tokens = self.memory.get_context() - - remaining_budget = max(0, token_limit - ctx_tokens) - - if current_tokens <= remaining_budget: - _write_single_record(message, role, base_ts) - return - except Exception as e: - logger.warning( - f"Token calculation failed before chunking, " - f"writing message as-is. Error: {e}" - ) - _write_single_record(message, role, base_ts) - return - - # 4. Perform slicing - logger.warning( - f"Message with {current_tokens} tokens exceeds remaining budget " - f"of {remaining_budget}. Slicing into smaller chunks." + else time.time_ns() / 1_000_000_000, # Nanosecond precision + agent_id=self.agent_id, ) - - text_to_chunk: Optional[str] = None - is_function_result = False - - if isinstance(message, FunctionCallingMessage) and isinstance( - message.result, str - ): - text_to_chunk = message.result - is_function_result = True - elif isinstance(message.content, str): - text_to_chunk = message.content - - if not text_to_chunk or not text_to_chunk.strip(): - _write_single_record(message, role, base_ts) - return - # Encode the entire text to get a list of all token IDs - try: - all_token_ids = token_counter.encode(text_to_chunk) - except Exception as e: - logger.error(f"Failed to encode text for chunking: {e}") - _write_single_record(message, role, base_ts) # Fallback - return - - if not all_token_ids: - _write_single_record(message, role, base_ts) # Nothing to chunk - return - - # 1. Base chunk size: one-tenth of the smaller of (a) total token - # limit and (b) current remaining budget. This prevents us from - # creating chunks that are guaranteed to overflow the - # immediate context window. - base_chunk_size = max(1, remaining_budget) // 10 - - # 2. Each chunk gets a textual prefix such as: - # "[chunk 3/12 of a long message]\n" - # The prefix itself consumes tokens, so if we do not subtract its - # length the *total* tokens of the outgoing message (prefix + body) - # can exceed the intended bound. We estimate the prefix length - # with a representative example that is safely long enough for the - # vast majority of cases (three-digit indices). - sample_prefix = "[chunk 1/1000 of a long message]\n" - prefix_token_len = len(token_counter.encode(sample_prefix)) - - # 3. The real capacity for the message body is therefore the base - # chunk size minus the prefix length. Fallback to at least one - # token to avoid zero or negative sizes. - chunk_body_limit = max(1, base_chunk_size - prefix_token_len) - - # 4. Calculate how many chunks we will need with this body size. - num_chunks = math.ceil(len(all_token_ids) / chunk_body_limit) - group_id = str(uuid.uuid4()) - - for i in range(num_chunks): - start_idx = i * chunk_body_limit - end_idx = start_idx + chunk_body_limit - chunk_token_ids = all_token_ids[start_idx:end_idx] - - chunk_body = token_counter.decode(chunk_token_ids) - - prefix = f"[chunk {i + 1}/{num_chunks} of a long message]\n" - new_body = prefix + chunk_body - - if is_function_result and isinstance( - message, FunctionCallingMessage - ): - new_msg: BaseMessage = FunctionCallingMessage( - role_name=message.role_name, - role_type=message.role_type, - meta_dict=message.meta_dict, - content=message.content, - func_name=message.func_name, - args=message.args, - result=new_body, - tool_call_id=message.tool_call_id, - ) - else: - new_msg = message.create_new_instance(new_body) - - meta = (new_msg.meta_dict or {}).copy() - meta.update( - { - "chunk_idx": i + 1, - "chunk_total": num_chunks, - "chunk_group_id": group_id, - } - ) - new_msg.meta_dict = meta - - # Increment timestamp slightly to maintain order - _write_single_record(new_msg, role, base_ts + i * 1e-6) + self.memory.write_record(record) + self._register_record_addition(role) def load_memory(self, memory: AgentMemory) -> None: r"""Load the provided memory into the agent. @@ -1070,6 +1287,7 @@ def summarize( summary_prompt: Optional[str] = None, response_format: Optional[Type[BaseModel]] = None, working_directory: Optional[Union[str, Path]] = None, + include_summaries: bool = False, ) -> Dict[str, Any]: r"""Summarize the agent's current conversation context and persist it to a markdown file. @@ -1089,6 +1307,11 @@ def summarize( defining the expected structure of the response. If provided, the summary will be generated as structured output and included in the result. + include_summaries (bool): Whether to include previously generated + summaries in the content to be summarized. If False (default), + only non-summary messages will be summarized. If True, all + messages including previous summaries will be summarized + (full compression). (default: :obj:`False`) working_directory (Optional[str|Path]): Optional directory to save the markdown summary file. If provided, overrides the default directory used by ContextUtility. @@ -1142,6 +1365,12 @@ def summarize( role = message.get('role', 'unknown') content = message.get('content', '') + # Skip summary messages if include_summaries is False + if not include_summaries and isinstance(content, str): + # Check if this is a summary message by looking for marker + if content.startswith('[CAMEL_SUMMARY]'): + continue + # Handle tool call messages (assistant calling tools) tool_calls = message.get('tool_calls') if tool_calls and isinstance(tool_calls, (list, tuple)): @@ -1321,6 +1550,13 @@ def summarize( } result.update(result_dict) + + # Add prefix to summary content in result + if self.summarize_threshold is not None: + summary_with_prefix = f"[CAMEL_SUMMARY] {summary_content}" + result["summary_with_prefix"] = summary_with_prefix + result["include_summaries"] = include_summaries + logger.info("Conversation summary saved to %s", file_path) return result @@ -1336,6 +1572,7 @@ async def asummarize( summary_prompt: Optional[str] = None, response_format: Optional[Type[BaseModel]] = None, working_directory: Optional[Union[str, Path]] = None, + include_summaries: bool = False, ) -> Dict[str, Any]: r"""Asynchronously summarize the agent's current conversation context and persist it to a markdown file. @@ -1358,6 +1595,11 @@ async def asummarize( working_directory (Optional[str|Path]): Optional directory to save the markdown summary file. If provided, overrides the default directory used by ContextUtility. + include_summaries (bool): Whether to include previously generated + summaries in the content to be summarized. If False (default), + only non-summary messages will be summarized. If True, all + messages including previous summaries will be summarized + (full compression). (default: :obj:`False`) Returns: Dict[str, Any]: A dictionary containing the summary text, file @@ -1398,6 +1640,12 @@ async def asummarize( role = message.get('role', 'unknown') content = message.get('content', '') + # Skip summary messages if include_summaries is False + if not include_summaries and isinstance(content, str): + # Check if this is a summary message by looking for marker + if content.startswith('[CAMEL_SUMMARY]'): + continue + # Handle tool call messages (assistant calling tools) tool_calls = message.get('tool_calls') if tool_calls and isinstance(tool_calls, (list, tuple)): @@ -1586,6 +1834,13 @@ async def asummarize( } result.update(result_dict) + + # Add prefix to summary content in result + if self.summarize_threshold is not None: + summary_with_prefix = f"[CAMEL_SUMMARY] {summary_content}" + result["summary_with_prefix"] = summary_with_prefix + result["include_summaries"] = include_summaries + logger.info("Conversation summary saved to %s", file_path) return result @@ -1604,7 +1859,14 @@ def clear_memory(self) -> None: self.memory.clear() if self.system_message is not None: - self.update_memory(self.system_message, OpenAIBackendRole.SYSTEM) + self.memory.write_record( + MemoryRecord( + message=self.system_message, + role_at_backend=OpenAIBackendRole.SYSTEM, + timestamp=time.time_ns() / 1_000_000_000, + agent_id=self.agent_id, + ) + ) def _generate_system_message_for_output_language( self, @@ -1638,17 +1900,8 @@ def init_messages(self) -> None: r"""Initializes the stored messages list with the current system message. """ - self.memory.clear() - # Write system message to memory if provided - if self.system_message is not None: - self.memory.write_record( - MemoryRecord( - message=self.system_message, - role_at_backend=OpenAIBackendRole.SYSTEM, - timestamp=time.time_ns() / 1_000_000_000, - agent_id=self.agent_id, - ) - ) + self._reset_summary_state() + self.clear_memory() def reset_to_original_system_message(self) -> None: r"""Reset system message to original, removing any appended context. @@ -2072,22 +2325,136 @@ def _step_impl( try: openai_messages, num_tokens = self.memory.get_context() + if self.summarize_threshold is not None: + threshold = self._calculate_next_summary_threshold() + summary_token_count = self._summary_state.get( + "summary_token_count", 0 + ) + token_limit = self.model_backend.token_limit + + if num_tokens <= token_limit: + # Check if summary tokens exceed 60% of limit + # If so, perform full compression (including summaries) + if summary_token_count > token_limit * 0.6: + logger.info( + f"Summary tokens ({summary_token_count}) " + f"exceed 60% of limit ({token_limit * 0.6}). " + f"Performing full compression." + ) + # Summarize everything (including summaries) + summary = self.summarize(include_summaries=True) + self._update_memory_with_summary(summary) + elif num_tokens > threshold: + logger.info( + f"Token count ({num_tokens}) exceed threshold " + f"({threshold}). Triggering summarization." + ) + # Only summarize non-summary content + summary = self.summarize(include_summaries=False) + self._update_memory_with_summary(summary) accumulated_context_tokens += num_tokens except RuntimeError as e: return self._step_terminate( e.args[1], tool_call_records, "max_tokens_exceeded" ) - # Get response from model backend - response = self._get_model_response( - openai_messages, - num_tokens=num_tokens, - current_iteration=iteration_count, - response_format=response_format, - tool_schemas=[] - if disable_tools - else self._get_full_tool_schemas(), - prev_num_openai_messages=prev_num_openai_messages, - ) + # Get response from model backend with token limit error handling + try: + response = self._get_model_response( + openai_messages, + num_tokens=num_tokens, + current_iteration=iteration_count, + response_format=response_format, + tool_schemas=[] + if disable_tools + else self._get_full_tool_schemas(), + prev_num_openai_messages=prev_num_openai_messages, + ) + except Exception as exc: + logger.exception("Model error: %s", exc) + + if self._is_token_limit_error(exc): + tool_signature = self._last_tool_call_signature + if ( + tool_signature is not None + and tool_signature + == self._last_token_limit_tool_signature + ): + description = self._describe_tool_call( + self._last_tool_call_record + ) + repeated_msg = ( + "Context exceeded again by the same tool call." + ) + if description: + repeated_msg += f" {description}" + raise RuntimeError(repeated_msg) from exc + + allowed, reason = self._can_attempt_summary() + if not allowed: + error_message = ( + "Context limit exceeded and further summarization " + "is not possible." + ) + if reason: + error_message += f" {reason}" + raise RuntimeError(error_message) from exc + + logger.warning( + "Token limit exceeded error detected. " + "Summarizing context." + ) + + recent_records: List[ContextRecord] + try: + recent_records = self.memory.retrieve() + except Exception: # pragma: no cover - defensive guard + recent_records = [] + + indices_to_remove = ( + self._find_indices_to_remove_for_last_tool_pair( + recent_records + ) + ) + self.memory.remove_records_by_indices(indices_to_remove) + + summary = self.summarize() + if isinstance(input_message, BaseMessage): + input_message = input_message.content + + tool_notice = self._format_tool_limit_notice() + summary_messages = ( + "Due to token limit being exceeded in this request, " + "the conversation has been summarized." + "the content after [Input Message] is the summary " + "of [Input Message]" + + "" + + str(input_message) + + "" + + "\n\n" + "[Context Summary]\n\n" + + summary.get("summary", "") + + "\n\n[Summary status]\n\n" + + str(summary.get("status", "")) + + "\n\nThe following is the user's original input. " + "Please continue based on existing information." + ) + if tool_notice: + summary_messages += "\n\n" + tool_notice + self.clear_memory() + summary_messages = BaseMessage.make_assistant_message( + role_name="Assistant", content=summary_messages + ) + self.update_memory( + summary_messages, OpenAIBackendRole.ASSISTANT + ) + self._on_summary(recent_records) + + self._last_token_limit_tool_signature = tool_signature + + return self._step_impl(input_message, response_format) + + raise + prev_num_openai_messages = len(openai_messages) iteration_count += 1 @@ -2282,6 +2649,7 @@ async def _astep_non_streaming_task( step_token_usage = self._create_token_usage_tracker() iteration_count: int = 0 prev_num_openai_messages: int = 0 + while True: if self.pause_event is not None and not self.pause_event.is_set(): if isinstance(self.pause_event, asyncio.Event): @@ -2292,21 +2660,140 @@ async def _astep_non_streaming_task( await loop.run_in_executor(None, self.pause_event.wait) try: openai_messages, num_tokens = self.memory.get_context() + if self.summarize_threshold is not None: + threshold = self._calculate_next_summary_threshold() + summary_token_count = self._summary_state.get( + "summary_token_count", 0 + ) + token_limit = self.model_backend.token_limit + + if num_tokens <= token_limit: + # Check if summary tokens exceed 60% of limit + # If so, perform full compression (including summaries) + if summary_token_count > token_limit * 0.6: + logger.info( + f"Summary tokens ({summary_token_count}) " + f"exceed 60% of limit ({token_limit * 0.6}). " + f"Full compression." + ) + # Summarize everything (including summaries) + summary = await self.asummarize( + include_summaries=True + ) + self._update_memory_with_summary(summary) + elif num_tokens > threshold: + logger.info( + f"Token count ({num_tokens}) exceed threshold " + "({threshold}). Triggering summarization." + ) + # Only summarize non-summary content + summary = await self.asummarize( + include_summaries=False + ) + self._update_memory_with_summary(summary) accumulated_context_tokens += num_tokens except RuntimeError as e: return self._step_terminate( e.args[1], tool_call_records, "max_tokens_exceeded" ) - response = await self._aget_model_response( - openai_messages, - num_tokens=num_tokens, - current_iteration=iteration_count, - response_format=response_format, - tool_schemas=[] - if disable_tools - else self._get_full_tool_schemas(), - prev_num_openai_messages=prev_num_openai_messages, - ) + # Get response from model backend with token limit error handling + try: + response = await self._aget_model_response( + openai_messages, + num_tokens=num_tokens, + current_iteration=iteration_count, + response_format=response_format, + tool_schemas=[] + if disable_tools + else self._get_full_tool_schemas(), + prev_num_openai_messages=prev_num_openai_messages, + ) + except Exception as exc: + if self._is_token_limit_error(exc): + tool_signature = self._last_tool_call_signature + if ( + tool_signature is not None + and tool_signature + == self._last_token_limit_tool_signature + ): + description = self._describe_tool_call( + self._last_tool_call_record + ) + repeated_msg = ( + "Context exceeded again by the same tool call." + ) + if description: + repeated_msg += f" {description}" + raise RuntimeError(repeated_msg) from exc + + allowed, reason = self._can_attempt_summary() + if not allowed: + error_message = ( + "Context limit exceeded and further summarization " + "is not possible." + ) + if reason: + error_message += f" {reason}" + raise RuntimeError(error_message) from exc + + logger.warning( + "Token limit exceeded error detected. " + "Summarizing context." + ) + + recent_records: List[ContextRecord] + try: + recent_records = self.memory.retrieve() + except Exception: # pragma: no cover - defensive guard + recent_records = [] + + indices_to_remove = ( + self._find_indices_to_remove_for_last_tool_pair( + recent_records + ) + ) + self.memory.remove_records_by_indices(indices_to_remove) + + summary = await self.asummarize() + if isinstance(input_message, BaseMessage): + input_message = input_message.content + + tool_notice = self._format_tool_limit_notice() + summary_messages = ( + "Due to token limit being exceeded in this request, " + "the conversation has been summarized." + "the content after [Input Message] is the summary" + + "" + + str(input_message) + + "" + + "\n\n" + "[Context Summary]\n\n" + + summary.get("summary", "") + + "\n\n[Summary status]\n\n" + + str(summary.get("status", "")) + + "\n\nThe following is the user's original input. " + "Please continue based on existing information." + ) + + if tool_notice: + summary_messages += "\n\n" + tool_notice + self.clear_memory() + summary_messages = BaseMessage.make_assistant_message( + role_name="Assistant", content=summary_messages + ) + self.update_memory( + summary_messages, OpenAIBackendRole.ASSISTANT + ) + self._on_summary(recent_records) + + self._last_token_limit_tool_signature = tool_signature + + return await self._astep_non_streaming_task( + input_message, response_format + ) + + raise + prev_num_openai_messages = len(openai_messages) iteration_count += 1 @@ -2479,6 +2966,8 @@ def _get_model_response( if response: break except RateLimitError as e: + if self._is_token_limit_error(e): + raise last_error = e if attempt < self.retry_attempts - 1: delay = min(self.retry_delay * (2**attempt), 60.0) @@ -2496,7 +2985,6 @@ def _get_model_response( except Exception: logger.error( f"Model error: {self.model_backend.model_type}", - exc_info=True, ) raise else: @@ -2543,6 +3031,8 @@ async def _aget_model_response( if response: break except RateLimitError as e: + if self._is_token_limit_error(e): + raise last_error = e if attempt < self.retry_attempts - 1: delay = min(self.retry_delay * (2**attempt), 60.0) diff --git a/camel/memories/agent_memories.py b/camel/memories/agent_memories.py index 01abee68a1..1789cec469 100644 --- a/camel/memories/agent_memories.py +++ b/camel/memories/agent_memories.py @@ -129,6 +129,16 @@ def clean_tool_calls(self) -> None: # Save the modified records back to storage self._chat_history_block.storage.save(record_dicts) + def pop_records(self, count: int) -> List[MemoryRecord]: + r"""Removes the most recent records from chat history memory.""" + return self._chat_history_block.pop_records(count) + + def remove_records_by_indices( + self, indices: List[int] + ) -> List[MemoryRecord]: + r"""Removes records at specified indices from chat history memory.""" + return self._chat_history_block.remove_records_by_indices(indices) + class VectorDBMemory(AgentMemory): r"""An agent memory wrapper of :obj:`VectorDBBlock`. This memory queries @@ -193,6 +203,20 @@ def clear(self) -> None: r"""Removes all records from the vector database memory.""" self._vectordb_block.clear() + def pop_records(self, count: int) -> List[MemoryRecord]: + r"""Rolling back is unsupported for vector database memory.""" + raise NotImplementedError( + "VectorDBMemory does not support removing historical records." + ) + + def remove_records_by_indices( + self, indices: List[int] + ) -> List[MemoryRecord]: + r"""Removing by indices is unsupported for vector database memory.""" + raise NotImplementedError( + "VectorDBMemory does not support removing records by indices." + ) + class LongtermAgentMemory(AgentMemory): r"""An implementation of the :obj:`AgentMemory` abstract base class for @@ -277,3 +301,13 @@ def clear(self) -> None: r"""Removes all records from the memory.""" self.chat_history_block.clear() self.vector_db_block.clear() + + def pop_records(self, count: int) -> List[MemoryRecord]: + r"""Removes recent chat history records while leaving vector memory.""" + return self.chat_history_block.pop_records(count) + + def remove_records_by_indices( + self, indices: List[int] + ) -> List[MemoryRecord]: + r"""Removes records at specified indices from chat history.""" + return self.chat_history_block.remove_records_by_indices(indices) diff --git a/camel/memories/base.py b/camel/memories/base.py index f9d4a0ad83..4cf2ccce15 100644 --- a/camel/memories/base.py +++ b/camel/memories/base.py @@ -45,6 +45,32 @@ def write_record(self, record: MemoryRecord) -> None: """ self.write_records([record]) + def pop_records(self, count: int) -> List[MemoryRecord]: + r"""Removes records from the memory and returns the removed records. + + Args: + count (int): Number of records to remove. + + Returns: + List[MemoryRecord]: The records that were removed from the memory + in their original order. + """ + raise NotImplementedError + + def remove_records_by_indices( + self, indices: List[int] + ) -> List[MemoryRecord]: + r"""Removes records at specified indices from the memory. + + Args: + indices (List[int]): List of indices to remove. Indices should be + valid positions in the current record list. + + Returns: + List[MemoryRecord]: The removed records in their original order. + """ + raise NotImplementedError + @abstractmethod def clear(self) -> None: r"""Clears all messages from the memory.""" diff --git a/camel/memories/blocks/chat_history_block.py b/camel/memories/blocks/chat_history_block.py index 8beaa909e4..1f311f131a 100644 --- a/camel/memories/blocks/chat_history_block.py +++ b/camel/memories/blocks/chat_history_block.py @@ -167,3 +167,118 @@ def write_records(self, records: List[MemoryRecord]) -> None: def clear(self) -> None: r"""Clears all chat messages from the memory.""" self.storage.clear() + + def pop_records(self, count: int) -> List[MemoryRecord]: + r"""Removes the most recent records from the memory. + + Args: + count (int): Number of records to remove from the end of the + conversation history. A value of 0 results in no changes. + + Returns: + List[MemoryRecord]: The removed records in chronological order. + """ + if not isinstance(count, int): + raise TypeError("`count` must be an integer.") + if count < 0: + raise ValueError("`count` must be non-negative.") + if count == 0: + return [] + + record_dicts = self.storage.load() + if not record_dicts: + return [] + + # Preserve initial system/developer instruction if present. + protected_prefix = ( + 1 + if ( + record_dicts + and record_dicts[0]['role_at_backend'] + in { + OpenAIBackendRole.SYSTEM.value, + OpenAIBackendRole.DEVELOPER.value, + } + ) + else 0 + ) + + removable_count = max(len(record_dicts) - protected_prefix, 0) + if removable_count == 0: + return [] + + pop_count = min(count, removable_count) + split_index = len(record_dicts) - pop_count + + popped_dicts = record_dicts[split_index:] + remaining_dicts = record_dicts[:split_index] + + self.storage.clear() + if remaining_dicts: + self.storage.save(remaining_dicts) + + return [MemoryRecord.from_dict(record) for record in popped_dicts] + + def remove_records_by_indices( + self, indices: List[int] + ) -> List[MemoryRecord]: + r"""Removes records at specified indices from the memory. + + Args: + indices (List[int]): List of indices to remove. Indices are + positions in the current record list (0-based). + System/developer messages at index 0 are protected and will + not be removed. + + Returns: + List[MemoryRecord]: The removed records in their original order. + """ + if not indices: + return [] + + record_dicts = self.storage.load() + if not record_dicts: + return [] + + # Preserve initial system/developer instruction if present. + protected_prefix = ( + 1 + if ( + record_dicts + and record_dicts[0]['role_at_backend'] + in { + OpenAIBackendRole.SYSTEM.value, + OpenAIBackendRole.DEVELOPER.value, + } + ) + else 0 + ) + + # Filter out protected indices and invalid ones + valid_indices = sorted( + { + idx + for idx in indices + if idx >= protected_prefix and idx < len(record_dicts) + } + ) + + if not valid_indices: + return [] + + # Extract records to remove (in original order) + removed_records = [record_dicts[idx] for idx in valid_indices] + + # Build remaining records by excluding removed indices + remaining_dicts = [ + record + for idx, record in enumerate(record_dicts) + if idx not in valid_indices + ] + + # Save back to storage + self.storage.clear() + if remaining_dicts: + self.storage.save(remaining_dicts) + + return [MemoryRecord.from_dict(record) for record in removed_records] diff --git a/camel/memories/context_creators/score_based.py b/camel/memories/context_creators/score_based.py index 6d2b9ea349..6733a38f8e 100644 --- a/camel/memories/context_creators/score_based.py +++ b/camel/memories/context_creators/score_based.py @@ -11,41 +11,24 @@ # See the License for the specific language governing permissions and # limitations under the License. # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. ========= -from collections import defaultdict -from typing import Dict, List, Optional, Tuple -from pydantic import BaseModel +from typing import List, Optional, Tuple -from camel.logger import get_logger from camel.memories.base import BaseContextCreator from camel.memories.records import ContextRecord -from camel.messages import FunctionCallingMessage, OpenAIMessage +from camel.messages import OpenAIMessage from camel.types.enums import OpenAIBackendRole from camel.utils import BaseTokenCounter -logger = get_logger(__name__) - - -class _ContextUnit(BaseModel): - idx: int - record: ContextRecord - num_tokens: int - class ScoreBasedContextCreator(BaseContextCreator): - r"""A default implementation of context creation strategy, which inherits - from :obj:`BaseContextCreator`. - - This class provides a strategy to generate a conversational context from - a list of chat history records while ensuring the total token count of - the context does not exceed a specified limit. It prunes messages based - on their score if the total token count exceeds the limit. + r"""A context creation strategy that orders records chronologically. Args: - token_counter (BaseTokenCounter): An instance responsible for counting - tokens in a message. - token_limit (int): The maximum number of tokens allowed in the - generated context. + token_counter (BaseTokenCounter): Token counter instance used to + compute the combined token count of the returned messages. + token_limit (int): Retained for API compatibility. No longer used to + filter records. """ def __init__( @@ -66,376 +49,34 @@ def create_context( self, records: List[ContextRecord], ) -> Tuple[List[OpenAIMessage], int]: - r"""Constructs conversation context from chat history while respecting - token limits. - - Key strategies: - 1. System message is always prioritized and preserved - 2. Truncation removes low-score messages first - 3. Final output maintains chronological order and in history memory, - the score of each message decreases according to keep_rate. The - newer the message, the higher the score. - 4. Tool calls and their responses are kept together to maintain - API compatibility - - Args: - records (List[ContextRecord]): List of context records with scores - and timestamps. - - Returns: - Tuple[List[OpenAIMessage], int]: - - Ordered list of OpenAI messages - - Total token count of the final context - - Raises: - RuntimeError: If system message alone exceeds token limit - """ - # ====================== - # 1. System Message Handling - # ====================== - system_unit, regular_units = self._extract_system_message(records) - system_tokens = system_unit.num_tokens if system_unit else 0 + """Returns messages sorted by timestamp and their total token count.""" - # Check early if system message alone exceeds token limit - if system_tokens > self.token_limit: - raise RuntimeError( - f"System message alone exceeds token limit" - f": {system_tokens} > {self.token_limit}", - system_tokens, - ) + system_record: Optional[ContextRecord] = None + remaining_records: List[ContextRecord] = [] - # ====================== - # 2. Deduplication & Initial Processing - # ====================== - seen_uuids = set() - if system_unit: - seen_uuids.add(system_unit.record.memory_record.uuid) - - # Process non-system messages with deduplication - for idx, record in enumerate(records): + for record in records: if ( - record.memory_record.role_at_backend + system_record is None + and record.memory_record.role_at_backend == OpenAIBackendRole.SYSTEM ): + system_record = record continue - if record.memory_record.uuid in seen_uuids: - continue - seen_uuids.add(record.memory_record.uuid) - - token_count = self.token_counter.count_tokens_from_messages( - [record.memory_record.to_openai_message()] - ) - regular_units.append( - _ContextUnit( - idx=idx, - record=record, - num_tokens=token_count, - ) - ) - - # ====================== - # 3. Tool Call Relationship Mapping - # ====================== - tool_call_groups = self._group_tool_calls_and_responses(regular_units) - - # ====================== - # 4. Token Calculation - # ====================== - total_tokens = system_tokens + sum(u.num_tokens for u in regular_units) - - # ====================== - # 5. Early Return if Within Limit - # ====================== - if total_tokens <= self.token_limit: - sorted_units = sorted( - regular_units, key=self._conversation_sort_key - ) - return self._assemble_output(sorted_units, system_unit) - - # ====================== - # 6. Truncation Logic with Tool Call Awareness - # ====================== - remaining_units = self._truncate_with_tool_call_awareness( - regular_units, tool_call_groups, system_tokens - ) - - # Log only after truncation is actually performed so that both - # the original and the final token counts are visible. - tokens_after = system_tokens + sum( - u.num_tokens for u in remaining_units - ) - logger.warning( - "Context truncation performed: " - f"before={total_tokens}, after={tokens_after}, " - f"limit={self.token_limit}" - ) - - # ====================== - # 7. Output Assembly - # ====================== - - # In case system message is the only message in memory when sorted - # units are empty, raise an error - if system_unit and len(remaining_units) == 0 and len(records) > 1: - raise RuntimeError( - "System message and current message exceeds token limit ", - total_tokens, - ) - - # Sort remaining units chronologically - final_units = sorted(remaining_units, key=self._conversation_sort_key) - return self._assemble_output(final_units, system_unit) - - def _group_tool_calls_and_responses( - self, units: List[_ContextUnit] - ) -> Dict[str, List[_ContextUnit]]: - r"""Groups tool calls with their corresponding responses based on - `tool_call_id`. - - This improved logic robustly gathers all messages (assistant requests - and tool responses, including chunks) that share a `tool_call_id`. - - Args: - units (List[_ContextUnit]): List of context units to analyze. - - Returns: - Dict[str, List[_ContextUnit]]: Mapping from `tool_call_id` to a - list of related units. - """ - tool_call_groups: Dict[str, List[_ContextUnit]] = defaultdict(list) - - for unit in units: - # FunctionCallingMessage stores tool_call_id. - message = unit.record.memory_record.message - tool_call_id = getattr(message, 'tool_call_id', None) - - if tool_call_id: - tool_call_groups[tool_call_id].append(unit) - - # Filter out empty or incomplete groups if necessary, - # though defaultdict and getattr handle this gracefully. - return dict(tool_call_groups) - - def _truncate_with_tool_call_awareness( - self, - regular_units: List[_ContextUnit], - tool_call_groups: Dict[str, List[_ContextUnit]], - system_tokens: int, - ) -> List[_ContextUnit]: - r"""Truncates messages while preserving tool call-response pairs. - This method implements a more sophisticated truncation strategy: - 1. It treats tool call groups (request + responses) and standalone - messages as individual items to be included. - 2. It sorts all items by score and greedily adds them to the context. - 3. **Partial Truncation**: If a complete tool group is too large to - fit,it attempts to add the request message and as many of the most - recent response chunks as the token budget allows. - - Args: - regular_units (List[_ContextUnit]): All regular message units. - tool_call_groups (Dict[str, List[_ContextUnit]]): Grouped tool - calls. - system_tokens (int): Tokens used by the system message. - - Returns: - List[_ContextUnit]: A list of units that fit within the token - limit. - """ - - # Create a set for quick lookup of units belonging to any tool call - tool_call_unit_ids = { - unit.record.memory_record.uuid - for group in tool_call_groups.values() - for unit in group - } - - # Separate standalone units from tool call groups - standalone_units = [ - u - for u in regular_units - if u.record.memory_record.uuid not in tool_call_unit_ids - ] - - # Prepare all items (standalone units and groups) for sorting - all_potential_items: List[Dict] = [] - for unit in standalone_units: - all_potential_items.append( - { - "type": "standalone", - "score": unit.record.score, - "timestamp": unit.record.timestamp, - "tokens": unit.num_tokens, - "item": unit, - } - ) - for group in tool_call_groups.values(): - all_potential_items.append( - { - "type": "group", - "score": max(u.record.score for u in group), - "timestamp": max(u.record.timestamp for u in group), - "tokens": sum(u.num_tokens for u in group), - "item": group, - } - ) - - # Sort all potential items by score (high to low), then timestamp - all_potential_items.sort(key=lambda x: (-x["score"], -x["timestamp"])) - - remaining_units: List[_ContextUnit] = [] - current_tokens = system_tokens - - for item_dict in all_potential_items: - item_type = item_dict["type"] - item = item_dict["item"] - item_tokens = item_dict["tokens"] - - if current_tokens + item_tokens <= self.token_limit: - # The whole item (standalone or group) fits, so add it - if item_type == "standalone": - remaining_units.append(item) - else: # item_type == "group" - remaining_units.extend(item) - current_tokens += item_tokens - - elif item_type == "group": - # The group does not fit completely; try partial inclusion. - request_unit: Optional[_ContextUnit] = None - response_units: List[_ContextUnit] = [] - - for unit in item: - # Assistant msg with `args` is the request - if ( - isinstance( - unit.record.memory_record.message, - FunctionCallingMessage, - ) - and unit.record.memory_record.message.args is not None - ): - request_unit = unit - else: - response_units.append(unit) - - # A group must have a request to be considered for inclusion. - if request_unit is None: - continue - - # Check if we can at least fit the request. - if ( - current_tokens + request_unit.num_tokens - <= self.token_limit - ): - units_to_add = [request_unit] - tokens_to_add = request_unit.num_tokens - - # Sort responses by timestamp to add newest chunks first - response_units.sort( - key=lambda u: u.record.timestamp, reverse=True - ) + remaining_records.append(record) - for resp_unit in response_units: - if ( - current_tokens - + tokens_to_add - + resp_unit.num_tokens - <= self.token_limit - ): - units_to_add.append(resp_unit) - tokens_to_add += resp_unit.num_tokens + remaining_records.sort(key=lambda record: record.timestamp) - # A request must be followed by at least one response - if len(units_to_add) > 1: - remaining_units.extend(units_to_add) - current_tokens += tokens_to_add + messages: List[OpenAIMessage] = [] + if system_record is not None: + messages.append(system_record.memory_record.to_openai_message()) - return remaining_units - - def _extract_system_message( - self, records: List[ContextRecord] - ) -> Tuple[Optional[_ContextUnit], List[_ContextUnit]]: - r"""Extracts the system message from records and validates it. - - Args: - records (List[ContextRecord]): List of context records - representing conversation history. - - Returns: - Tuple[Optional[_ContextUnit], List[_ContextUnit]]: containing: - - The system message as a `_ContextUnit`, if valid; otherwise, - `None`. - - An empty list, serving as the initial container for regular - messages. - """ - if not records: - return None, [] - - first_record = records[0] - if ( - first_record.memory_record.role_at_backend - != OpenAIBackendRole.SYSTEM - ): - return None, [] - - message = first_record.memory_record.to_openai_message() - tokens = self.token_counter.count_tokens_from_messages([message]) - system_message_unit = _ContextUnit( - idx=0, - record=first_record, - num_tokens=tokens, + messages.extend( + record.memory_record.to_openai_message() + for record in remaining_records ) - return system_message_unit, [] - - def _conversation_sort_key( - self, unit: _ContextUnit - ) -> Tuple[float, float]: - r"""Defines the sorting key for assembling the final output. - - Sorting priority: - - Primary: Sort by timestamp in ascending order (chronological order). - - Secondary: Sort by score in descending order (higher scores first - when timestamps are equal). - - Args: - unit (_ContextUnit): A `_ContextUnit` representing a conversation - record. - - Returns: - Tuple[float, float]: - - Timestamp for chronological sorting. - - Negative score for descending order sorting. - """ - return (unit.record.timestamp, -unit.record.score) - - def _assemble_output( - self, - context_units: List[_ContextUnit], - system_unit: Optional[_ContextUnit], - ) -> Tuple[List[OpenAIMessage], int]: - r"""Assembles final message list with proper ordering and token count. - - Args: - context_units (List[_ContextUnit]): Sorted list of regular message - units. - system_unit (Optional[_ContextUnit]): System message unit (if - present). - - Returns: - Tuple[List[OpenAIMessage], int]: Tuple of (ordered messages, total - tokens) - """ - messages = [] - total_tokens = 0 - - # Add system message first if present - if system_unit: - messages.append( - system_unit.record.memory_record.to_openai_message() - ) - total_tokens += system_unit.num_tokens - # Add sorted regular messages - for unit in context_units: - messages.append(unit.record.memory_record.to_openai_message()) - total_tokens += unit.num_tokens + if not messages: + return [], 0 + total_tokens = self.token_counter.count_tokens_from_messages(messages) return messages, total_tokens diff --git a/camel/storages/vectordb_storages/oceanbase.py b/camel/storages/vectordb_storages/oceanbase.py index 7cfe4b5c16..c250e87b0b 100644 --- a/camel/storages/vectordb_storages/oceanbase.py +++ b/camel/storages/vectordb_storages/oceanbase.py @@ -121,10 +121,11 @@ def __init__( ) # Get the first index parameter - first_index_param = next(iter(index_params)) - self._client.create_vidx_with_vec_index_param( - table_name=self.table_name, vidx_param=first_index_param - ) + first_index_param = next(iter(index_params), None) + if first_index_param is not None: + self._client.create_vidx_with_vec_index_param( + table_name=self.table_name, vidx_param=first_index_param + ) logger.info(f"Created table {self.table_name} with vector index") else: diff --git a/test/agents/test_chat_agent.py b/test/agents/test_chat_agent.py index ffc8a06da5..2e6f7538eb 100644 --- a/test/agents/test_chat_agent.py +++ b/test/agents/test_chat_agent.py @@ -560,6 +560,21 @@ def test_chat_agent_step_exceed_token_number(step_call_count=3): system_message=system_msg, token_limit=1, ) + + original_get_context = assistant.memory.get_context + + def mock_get_context(): + messages, _ = original_get_context() + # Raise RuntimeError as if context size exceeded limit + raise RuntimeError( + "Context size exceeded", + { + "status": "error", + "message": "The context has exceeded the maximum token limit.", + }, + ) + + assistant.memory.get_context = mock_get_context assistant.model_backend.run = MagicMock( return_value=model_backend_rsp_base ) diff --git a/test/memories/test_chat_history_memory.py b/test/memories/test_chat_history_memory.py index 22daecfe84..3eee644402 100644 --- a/test/memories/test_chat_history_memory.py +++ b/test/memories/test_chat_history_memory.py @@ -91,3 +91,69 @@ def test_chat_history_memory(memory: ChatHistoryMemory): assert output_messages[0] == system_msg.to_openai_system_message() assert output_messages[1] == user_msg.to_openai_user_message() assert output_messages[2] == assistant_msg.to_openai_assistant_message() + + +@pytest.mark.parametrize("memory", ["in-memory", "json"], indirect=True) +def test_chat_history_memory_pop_records(memory: ChatHistoryMemory): + system_msg = BaseMessage( + "system", + role_type=RoleType.DEFAULT, + meta_dict=None, + content="System instructions", + ) + user_msgs = [ + BaseMessage( + "AI user", + role_type=RoleType.USER, + meta_dict=None, + content=f"Message {idx}", + ) + for idx in range(3) + ] + + records = [ + MemoryRecord( + message=system_msg, + role_at_backend=OpenAIBackendRole.SYSTEM, + timestamp=datetime.now().timestamp(), + agent_id="system", + ), + *[ + MemoryRecord( + message=msg, + role_at_backend=OpenAIBackendRole.USER, + timestamp=datetime.now().timestamp(), + agent_id="user", + ) + for msg in user_msgs + ], + ] + + memory.write_records(records) + + popped = memory.pop_records(2) + assert [record.message.content for record in popped] == [ + "Message 1", + "Message 2", + ] + + remaining_messages, _ = memory.get_context() + assert [msg['content'] for msg in remaining_messages] == [ + "System instructions", + "Message 0", + ] + + # Attempting to pop more than available should leave system message intact. + popped = memory.pop_records(5) + assert [record.message.content for record in popped] == ["Message 0"] + + remaining_messages, _ = memory.get_context() + assert [msg['content'] for msg in remaining_messages] == [ + "System instructions", + ] + + # Zero pop should be a no-op. + assert memory.pop_records(0) == [] + + with pytest.raises(ValueError): + memory.pop_records(-1)