Skip to content

Commit

Permalink
Merge branch 'main' into lpinheiro/feat/add-vectordb-chroma-store
Browse files Browse the repository at this point in the history
  • Loading branch information
lspinheiro authored Oct 25, 2024
2 parents 4391961 + f31ff66 commit 8bb13f0
Show file tree
Hide file tree
Showing 24 changed files with 409 additions and 339 deletions.
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
from ._base_chat_agent import BaseChatAgent, BaseToolUseChatAgent
from ._base_chat_agent import BaseChatAgent
from ._code_executor_agent import CodeExecutorAgent
from ._coding_assistant_agent import CodingAssistantAgent
from ._tool_use_assistant_agent import ToolUseAssistantAgent

__all__ = [
"BaseChatAgent",
"BaseToolUseChatAgent",
"CodeExecutorAgent",
"CodingAssistantAgent",
"ToolUseAssistantAgent",
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
from abc import ABC, abstractmethod
from typing import List, Sequence
from typing import Sequence

from autogen_core.base import CancellationToken
from autogen_core.components.tools import Tool

from ..base import ChatAgent, TaskResult, TerminationCondition, ToolUseChatAgent
from ..base import ChatAgent, TaskResult, TerminationCondition
from ..messages import ChatMessage
from ..teams import RoundRobinGroupChat

Expand Down Expand Up @@ -51,21 +50,3 @@ async def run(
termination_condition=termination_condition,
)
return result


class BaseToolUseChatAgent(BaseChatAgent, ToolUseChatAgent):
"""Base class for a chat agent that can use tools.
Subclass this base class to create an agent class that uses tools by returning
ToolCallMessage message from the :meth:`on_messages` method and receiving
ToolCallResultMessage message from the input to the :meth:`on_messages` method.
"""

def __init__(self, name: str, description: str, registered_tools: List[Tool]) -> None:
super().__init__(name, description)
self._registered_tools = registered_tools

@property
def registered_tools(self) -> List[Tool]:
"""The list of tools that the agent can use."""
return self._registered_tools
Original file line number Diff line number Diff line change
@@ -1,29 +1,52 @@
import asyncio
import json
import logging
from typing import Any, Awaitable, Callable, List, Sequence

from autogen_core.base import CancellationToken
from autogen_core.components import FunctionCall
from autogen_core.components.models import (
AssistantMessage,
ChatCompletionClient,
FunctionExecutionResult,
FunctionExecutionResultMessage,
LLMMessage,
SystemMessage,
UserMessage,
)
from autogen_core.components.tools import FunctionTool, Tool
from pydantic import BaseModel, ConfigDict

from .. import EVENT_LOGGER_NAME
from ..messages import (
ChatMessage,
MultiModalMessage,
StopMessage,
TextMessage,
ToolCallMessage,
ToolCallResultMessage,
)
from ._base_chat_agent import BaseToolUseChatAgent
from ._base_chat_agent import BaseChatAgent

event_logger = logging.getLogger(EVENT_LOGGER_NAME)

class ToolUseAssistantAgent(BaseToolUseChatAgent):

class ToolCallEvent(BaseModel):
    """A tool call event.

    Logged (at DEBUG level, via the EVENT_LOGGER_NAME logger) just before the
    requested tool calls are executed by the agent.
    """

    # The function calls requested by the model in a single response.
    tool_calls: List[FunctionCall]
    """The tool call message."""

    # NOTE(review): arbitrary_types_allowed appears needed because FunctionCall
    # is not a pydantic model — confirm against autogen_core.
    model_config = ConfigDict(arbitrary_types_allowed=True)


class ToolCallResultEvent(BaseModel):
    """A tool call result event.

    Logged (at DEBUG level, via the EVENT_LOGGER_NAME logger) after the
    tool calls have been executed, with one result per call.
    """

    # Results of the executed tool calls, in the same order as the calls.
    tool_call_results: List[FunctionExecutionResult]
    """The tool call result message."""

    # NOTE(review): arbitrary_types_allowed appears needed because
    # FunctionExecutionResult is not a pydantic model — confirm.
    model_config = ConfigDict(arbitrary_types_allowed=True)


class ToolUseAssistantAgent(BaseChatAgent):
"""An agent that provides assistance with tool use.
It responds with a StopMessage when 'terminate' is detected in the response.
Expand All @@ -45,46 +68,50 @@ def __init__(
description: str = "An agent that provides assistance with ability to use tools.",
system_message: str = "You are a helpful AI assistant. Solve tasks using your tools. Reply with 'TERMINATE' when the task has been completed.",
):
tools: List[Tool] = []
super().__init__(name=name, description=description)
self._model_client = model_client
self._system_messages = [SystemMessage(content=system_message)]
self._tools: List[Tool] = []
for tool in registered_tools:
if isinstance(tool, Tool):
tools.append(tool)
self._tools.append(tool)
elif callable(tool):
if hasattr(tool, "__doc__") and tool.__doc__ is not None:
description = tool.__doc__
else:
description = ""
tools.append(FunctionTool(tool, description=description))
self._tools.append(FunctionTool(tool, description=description))
else:
raise ValueError(f"Unsupported tool type: {type(tool)}")
super().__init__(name=name, description=description, registered_tools=tools)
self._model_client = model_client
self._system_messages = [SystemMessage(content=system_message)]
self._tool_schema = [tool.schema for tool in tools]
self._model_context: List[LLMMessage] = []

async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> ChatMessage:
# Add messages to the model context.
for msg in messages:
if isinstance(msg, ToolCallResultMessage):
self._model_context.append(FunctionExecutionResultMessage(content=msg.content))
elif not isinstance(msg, TextMessage | MultiModalMessage | StopMessage):
raise ValueError(f"Unsupported message type: {type(msg)}")
else:
self._model_context.append(UserMessage(content=msg.content, source=msg.source))
# TODO: add special handling for handoff messages
self._model_context.append(UserMessage(content=msg.content, source=msg.source))

# Generate an inference result based on the current model context.
llm_messages = self._system_messages + self._model_context
result = await self._model_client.create(
llm_messages, tools=self._tool_schema, cancellation_token=cancellation_token
)
result = await self._model_client.create(llm_messages, tools=self._tools, cancellation_token=cancellation_token)

# Add the response to the model context.
self._model_context.append(AssistantMessage(content=result.content, source=self.name))

# Detect tool calls.
if isinstance(result.content, list) and all(isinstance(item, FunctionCall) for item in result.content):
return ToolCallMessage(content=result.content, source=self.name)
# Run tool calls until the model produces a string response.
while isinstance(result.content, list) and all(isinstance(item, FunctionCall) for item in result.content):
event_logger.debug(ToolCallEvent(tool_calls=result.content))
# Execute the tool calls.
results = await asyncio.gather(
*[self._execute_tool_call(call, cancellation_token) for call in result.content]
)
event_logger.debug(ToolCallResultEvent(tool_call_results=results))
self._model_context.append(FunctionExecutionResultMessage(content=results))
# Generate an inference result based on the current model context.
result = await self._model_client.create(
self._model_context, tools=self._tools, cancellation_token=cancellation_token
)
self._model_context.append(AssistantMessage(content=result.content, source=self.name))

assert isinstance(result.content, str)
# Detect stop request.
Expand All @@ -93,3 +120,20 @@ async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token:
return StopMessage(content=result.content, source=self.name)

return TextMessage(content=result.content, source=self.name)

async def _execute_tool_call(
self, tool_call: FunctionCall, cancellation_token: CancellationToken
) -> FunctionExecutionResult:
"""Execute a tool call and return the result."""
try:
if not self._tools:
raise ValueError("No tools are available.")
tool = next((t for t in self._tools if t.name == tool_call.name), None)
if tool is None:
raise ValueError(f"The tool '{tool_call.name}' is not available.")
arguments = json.loads(tool_call.arguments)
result = await tool.run_json(arguments, cancellation_token)
result_as_str = tool.return_value_as_string(result)
return FunctionExecutionResult(content=result_as_str, call_id=tool_call.id)
except Exception as e:
return FunctionExecutionResult(content=f"Error: {e}", call_id=tool_call.id)
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
from ._chat_agent import ChatAgent, ToolUseChatAgent
from ._chat_agent import ChatAgent
from ._task import TaskResult, TaskRunner
from ._team import Team
from ._termination import TerminatedException, TerminationCondition

__all__ = [
"ChatAgent",
"ToolUseChatAgent",
"Team",
"TerminatedException",
"TerminationCondition",
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from typing import List, Protocol, Sequence, runtime_checkable
from typing import Protocol, Sequence, runtime_checkable

from autogen_core.base import CancellationToken
from autogen_core.components.tools import Tool

from ..messages import ChatMessage
from ._task import TaskResult, TaskRunner
Expand Down Expand Up @@ -38,13 +37,3 @@ async def run(
) -> TaskResult:
"""Run the agent with the given task and return the result."""
...


@runtime_checkable
class ToolUseChatAgent(ChatAgent, Protocol):
"""Protocol for a chat agent that can use tools."""

@property
def registered_tools(self) -> List[Tool]:
"""The list of tools that the agent can use."""
...
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,12 @@
import sys
from datetime import datetime

from ..agents._tool_use_assistant_agent import ToolCallEvent, ToolCallResultEvent
from ..messages import ChatMessage, StopMessage, TextMessage
from ..teams._events import (
ContentPublishEvent,
SelectSpeakerEvent,
GroupChatPublishEvent,
GroupChatSelectSpeakerEvent,
TerminationEvent,
ToolCallEvent,
ToolCallResultEvent,
)


Expand All @@ -25,7 +24,7 @@ def serialize_chat_message(message: ChatMessage) -> str:

def emit(self, record: logging.LogRecord) -> None:
ts = datetime.fromtimestamp(record.created).isoformat()
if isinstance(record.msg, ContentPublishEvent):
if isinstance(record.msg, GroupChatPublishEvent):
if record.msg.source is None:
sys.stdout.write(
f"\n{'-'*75} \n"
Expand All @@ -41,19 +40,15 @@ def emit(self, record: logging.LogRecord) -> None:
sys.stdout.flush()
elif isinstance(record.msg, ToolCallEvent):
sys.stdout.write(
f"\n{'-'*75} \n"
f"\033[91m[{ts}], Tool Call:\033[0m\n"
f"\n{self.serialize_chat_message(record.msg.agent_message)}"
f"\n{'-'*75} \n" f"\033[91m[{ts}], Tool Call:\033[0m\n" f"\n{str(record.msg.model_dump())}"
)
sys.stdout.flush()
elif isinstance(record.msg, ToolCallResultEvent):
sys.stdout.write(
f"\n{'-'*75} \n"
f"\033[91m[{ts}], Tool Call Result:\033[0m\n"
f"\n{self.serialize_chat_message(record.msg.agent_message)}"
f"\n{'-'*75} \n" f"\033[91m[{ts}], Tool Call Result:\033[0m\n" f"\n{str(record.msg.model_dump())}"
)
sys.stdout.flush()
elif isinstance(record.msg, SelectSpeakerEvent):
elif isinstance(record.msg, GroupChatSelectSpeakerEvent):
sys.stdout.write(
f"\n{'-'*75} \n" f"\033[91m[{ts}], Selected Next Speaker:\033[0m\n" f"\n{record.msg.selected_speaker}"
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,11 @@
from datetime import datetime
from typing import Any

from ..agents._tool_use_assistant_agent import ToolCallEvent, ToolCallResultEvent
from ..teams._events import (
ContentPublishEvent,
SelectSpeakerEvent,
GroupChatPublishEvent,
GroupChatSelectSpeakerEvent,
TerminationEvent,
ToolCallEvent,
ToolCallResultEvent,
)


Expand All @@ -21,7 +20,7 @@ def __init__(self, filename: str) -> None:

def emit(self, record: logging.LogRecord) -> None:
ts = datetime.fromtimestamp(record.created).isoformat()
if isinstance(record.msg, ContentPublishEvent | ToolCallEvent | ToolCallResultEvent | TerminationEvent):
if isinstance(record.msg, GroupChatPublishEvent | TerminationEvent):
log_entry = json.dumps(
{
"timestamp": ts,
Expand All @@ -31,7 +30,7 @@ def emit(self, record: logging.LogRecord) -> None:
},
default=self.json_serializer,
)
elif isinstance(record.msg, SelectSpeakerEvent):
elif isinstance(record.msg, GroupChatSelectSpeakerEvent):
log_entry = json.dumps(
{
"timestamp": ts,
Expand All @@ -41,6 +40,24 @@ def emit(self, record: logging.LogRecord) -> None:
},
default=self.json_serializer,
)
elif isinstance(record.msg, ToolCallEvent):
log_entry = json.dumps(
{
"timestamp": ts,
"tool_calls": record.msg.model_dump(),
"type": "ToolCallEvent",
},
default=self.json_serializer,
)
elif isinstance(record.msg, ToolCallResultEvent):
log_entry = json.dumps(
{
"timestamp": ts,
"tool_call_results": record.msg.model_dump(),
"type": "ToolCallResultEvent",
},
default=self.json_serializer,
)
else:
raise ValueError(f"Unexpected log record: {record.msg}")
file_record = logging.LogRecord(
Expand Down
29 changes: 10 additions & 19 deletions python/packages/autogen-agentchat/src/autogen_agentchat/messages.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from typing import List

from autogen_core.components import FunctionCall, Image
from autogen_core.components.models import FunctionExecutionResult
from autogen_core.components import Image
from pydantic import BaseModel


Expand All @@ -26,37 +25,29 @@ class MultiModalMessage(BaseMessage):
"""The content of the message."""


class ToolCallMessage(BaseMessage):
    """A message containing a list of function calls.

    NOTE(review): emitted by a tool-using agent when the model requests tool
    execution; the caller is expected to respond with a ToolCallResultMessage.
    """

    # Raw FunctionCall objects as produced by the model client.
    content: List[FunctionCall]
    """The list of function calls."""


class ToolCallResultMessage(BaseMessage):
    """A message containing the results of function calls.

    NOTE(review): the counterpart to ToolCallMessage — one result per
    requested function call.
    """

    # One FunctionExecutionResult per executed call.
    content: List[FunctionExecutionResult]
    """The list of function execution results."""


class StopMessage(BaseMessage):
    """A message requesting stop of a conversation."""

    # Free-form text carrying or explaining the stop request.
    content: str
    """The content for the stop message."""


ChatMessage = TextMessage | MultiModalMessage | StopMessage | ToolCallMessage | ToolCallResultMessage
class HandoffMessage(BaseMessage):
    """A message requesting handoff of a conversation to another agent."""

    # The target agent's name — not a reason or description of the handoff.
    content: str
    """The agent name to handoff the conversation to."""


ChatMessage = TextMessage | MultiModalMessage | StopMessage | HandoffMessage
"""A message used by agents in a team."""


__all__ = [
"BaseMessage",
"TextMessage",
"MultiModalMessage",
"ToolCallMessage",
"ToolCallResultMessage",
"StopMessage",
"HandoffMessage",
"ChatMessage",
]
Loading

0 comments on commit 8bb13f0

Please sign in to comment.