Skip to content

Commit d941a0a

Browse files
authored
fix some bugs (#14)
1 parent 52f6f79 commit d941a0a

File tree

4 files changed

+31
-17
lines changed

4 files changed

+31
-17
lines changed

examples/patterns.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,16 @@
33

44
import openai
55
from agnext.agent_components.models_clients.openai_client import OpenAI
6-
from agnext.application_components.single_threaded_agent_runtime import SingleThreadedAgentRuntime
6+
from agnext.application_components.single_threaded_agent_runtime import (
7+
SingleThreadedAgentRuntime,
8+
)
79
from agnext.chat.agents.oai_assistant import OpenAIAssistantAgent
810
from agnext.chat.messages import ChatMessage
911
from agnext.chat.patterns.group_chat import GroupChat
1012
from agnext.chat.patterns.orchestrator import Orchestrator
1113

1214

13-
async def group_chat() -> None:
15+
async def group_chat(message: str) -> None:
1416
runtime = SingleThreadedAgentRuntime()
1517

1618
joe_oai_assistant = openai.beta.assistants.create(
@@ -44,22 +46,22 @@ async def group_chat() -> None:
4446
)
4547

4648
chat = GroupChat(
47-
"chat_room",
49+
"Host",
4850
"A round-robin chat room.",
4951
runtime,
5052
[joe, cathy],
5153
num_rounds=5,
5254
)
5355

54-
response = runtime.send_message(ChatMessage(body="Run a show!", sender="external"), chat)
56+
response = runtime.send_message(ChatMessage(body=message, sender="host"), chat)
5557

5658
while not response.done():
5759
await runtime.process_next()
5860

5961
print((await response).body) # type: ignore
6062

6163

62-
async def orchestrator() -> None:
64+
async def orchestrator(message: str) -> None:
6365
runtime = SingleThreadedAgentRuntime()
6466

6567
developer_oai_assistant = openai.beta.assistants.create(
@@ -93,16 +95,16 @@ async def orchestrator() -> None:
9395
)
9496

9597
chat = Orchestrator(
96-
"Team",
97-
"A software development team.",
98+
"Manager",
99+
"A software development team manager.",
98100
runtime,
99101
[developer, product_manager],
100102
model_client=OpenAI(model="gpt-3.5-turbo"),
101103
)
102104

103105
response = runtime.send_message(
104106
ChatMessage(
105-
body="Write a simple FastAPI webapp for showing the current time.",
107+
body=message,
106108
sender="customer",
107109
),
108110
chat,
@@ -122,11 +124,12 @@ async def orchestrator() -> None:
122124
choices=chocies,
123125
help="The pattern to demo.",
124126
)
127+
parser.add_argument("--message", help="The message to send.")
125128
args = parser.parse_args()
126129

127130
if args.pattern == "group_chat":
128-
asyncio.run(group_chat())
131+
asyncio.run(group_chat(args.message))
129132
elif args.pattern == "orchestrator":
130-
asyncio.run(orchestrator())
133+
asyncio.run(orchestrator(args.message))
131134
else:
132135
raise ValueError(f"Invalid pattern: {args.pattern}")

src/agnext/chat/patterns/group_chat.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
from typing import List, Sequence
22

3+
from ...agent_components.type_routed_agent import message_handler
34
from ...core.agent_runtime import AgentRuntime
5+
from ...core.cancellation_token import CancellationToken
46
from ..agents.base import BaseChatAgent
57
from ..messages import ChatMessage
68

@@ -19,7 +21,13 @@ def __init__(
1921
self._num_rounds = num_rounds
2022
self._history: List[ChatMessage] = []
2123

22-
async def on_chat_message(self, message: ChatMessage) -> ChatMessage:
24+
@message_handler(ChatMessage)
25+
async def on_chat_message(
26+
self,
27+
message: ChatMessage,
28+
require_response: bool,
29+
cancellation_token: CancellationToken,
30+
) -> ChatMessage | None:
2331
if message.reset:
2432
# Reset the history.
2533
self._history = []

src/agnext/chat/patterns/orchestrator.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,11 @@
11
import json
22
from typing import Any, List, Sequence, Tuple
33

4-
from agnext.core.agent_runtime import AgentRuntime
5-
from agnext.core.cancellation_token import CancellationToken
6-
74
from ...agent_components.model_client import ModelClient
85
from ...agent_components.type_routed_agent import message_handler
96
from ...agent_components.types import AssistantMessage, LLMMessage, UserMessage
7+
from ...core.agent_runtime import AgentRuntime
8+
from ...core.cancellation_token import CancellationToken
109
from ..agents.base import BaseChatAgent
1110
from ..messages import ChatMessage
1211

@@ -33,8 +32,11 @@ def __init__(
3332

3433
@message_handler(ChatMessage)
3534
async def on_chat_message(
36-
self, message: ChatMessage, require_response: bool, cancellation_token: CancellationToken
37-
) -> ChatMessage:
35+
self,
36+
message: ChatMessage,
37+
require_response: bool,
38+
cancellation_token: CancellationToken,
39+
) -> ChatMessage | None:
3840
# A task is received.
3941
task = message.body
4042

src/agnext/chat/utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
from typing import List, Optional, Union
22

3+
from typing_extensions import Literal
4+
35
from agnext.agent_components.types import AssistantMessage, LLMMessage, UserMessage
46
from agnext.chat.types import FunctionCallMessage, Message, MultiModalMessage, TextMessage
5-
from typing_extensions import Literal
67

78

89
def convert_content_message_to_assistant_message(

0 commit comments

Comments
 (0)