Skip to content

Commit 6e91868

Browse files
authored
feat: Add support for multi-agent communication (#4)
## Overview This PR enhances the agent communication system by adding support for multi-agent messaging, allowing an agent to send messages to multiple recipients simultaneously. It also includes several improvements to error handling and message formatting. ### Key Changes - Added support for sending messages to multiple agents at once via `talk_to` method - Enhanced error handling for agent communication with better error messages - Improved message formatting for both single and multi-agent conversations - Added reminder in prompt template to close XML tags - Fixed return type annotations for `ask` and `talk_to` methods - Bumped version from 0.3.0 to 0.3.1 ### Implementation Details - Modified `Interaction` class to store receivers as a list of agents - Updated `AgentInteractionManager` to handle multiple receivers - Added support for "all" keyword to broadcast messages to all agents - Enhanced error messages for invalid agent IDs - Improved string representation of interactions for better readability ### Testing - Added new test cases for multi-agent communication - Updated existing tests to accommodate the new receiver list format - Added tests for manual and automatic action execution ### Breaking Changes None. All changes are backward compatible as single-agent communication continues to work as before.
1 parent 66c8f4a commit 6e91868

7 files changed

+111
-36
lines changed

agentarium/Agent.py

+14-4
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,8 @@ class Agent:
7979
<ACTION>
8080
[{{One of the following actions: {list_of_actions}}}]
8181
</ACTION>
82+
83+
Don't forget to close each tag that you open.
8284
"""
8385

8486
def __init__(self, **kwargs):
@@ -343,7 +345,12 @@ def act(self) -> str:
343345
temperature=0.4,
344346
)
345347

346-
regex_result = re.search(r"<ACTION>(.*?)</ACTION>", response.choices[0].message.content, re.DOTALL).group(1).strip()
348+
try:
349+
regex_result = re.search(r"<ACTION>(.*?)</ACTION>", response.choices[0].message.content, re.DOTALL).group(1).strip()
350+
except AttributeError as e:
351+
logging.error(f"Received a response without any action: {response.choices[0].message.content=}")
352+
raise e
353+
347354
action_name, *args = [value.replace("[", "").replace("]", "").strip() for value in regex_result.split("]")]
348355

349356
if action_name not in self._actions:
@@ -391,7 +398,7 @@ def get_interactions(self) -> List[Interaction]:
391398
"""
392399
return self._interaction_manager.get_agent_interactions(self)
393400

394-
def ask(self, message: str) -> None:
401+
def ask(self, message: str) -> str:
395402
"""
396403
Ask the agent a question and receive a contextually aware response.
397404
@@ -507,7 +514,7 @@ def remove_action(self, action_name: str) -> None:
507514

508515
del self._actions[action_name]
509516

510-
def talk_to(self, agent: Agent, message: str) -> None:
517+
def talk_to(self, agent: Agent | list[Agent], message: str) -> Dict[str, Any]:
511518
"""
512519
Send a message from one agent to another and record the interaction.
513520
"""
@@ -516,7 +523,10 @@ def talk_to(self, agent: Agent, message: str) -> None:
516523
# Did you really removed the default "talk" action and expect the talk_to method to work?
517524
raise RuntimeError("Talk action not found in the agent's action space.")
518525

519-
self.execute_action("talk", agent.agent_id, message)
526+
if isinstance(agent, Agent):
527+
return self.execute_action("talk", agent.agent_id, message)
528+
else:
529+
return self.execute_action("talk", ','.join([agent.agent_id for agent in agent]), message)
520530

521531
def think(self, message: str) -> None:
522532
"""

agentarium/AgentInteractionManager.py

+13-6
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ def get_agent(self, agent_id: str) -> Agent | None:
7878
"""
7979
return self._agents.get(agent_id)
8080

81-
def record_interaction(self, sender: Agent, receiver: Agent, message: str) -> None:
81+
def record_interaction(self, sender: Agent, receiver: Agent | list[Agent], message: str) -> None:
8282
"""
8383
Record a new interaction between two agents.
8484
@@ -87,15 +87,21 @@ def record_interaction(self, sender: Agent, receiver: Agent, message: str) -> No
8787
8888
Args:
8989
sender (Agent): The agent initiating the interaction.
90-
receiver (Agent): The agent receiving the interaction.
90+
receiver (Agent | list[Agent]): The agent(s) receiving the interaction.
9191
message (str): The content of the interaction.
9292
"""
9393

94+
from .Agent import Agent
95+
9496
if sender.agent_id not in self._agents:
9597
raise ValueError(f"Sender agent {sender.agent_id} is not registered in the interaction manager.")
9698

97-
if receiver.agent_id not in self._agents:
98-
raise ValueError(f"Receiver agent {receiver.agent_id} is not registered in the interaction manager.")
99+
if isinstance(receiver, Agent):
100+
receiver = [receiver]
101+
102+
if any(_receiver.agent_id not in self._agents for _receiver in receiver):
103+
invalid_receivers = [_receiver for _receiver in receiver if _receiver.agent_id not in self._agents]
104+
raise ValueError(f"Receiver agent(s) {invalid_receivers} not registered in the interaction manager.")
99105

100106
interaction = Interaction(sender=sender, receiver=receiver, message=message)
101107

@@ -105,8 +111,9 @@ def record_interaction(self, sender: Agent, receiver: Agent, message: str) -> No
105111
# Record in private interactions for both sender and receiver
106112
self._agent_private_interactions[sender.agent_id].append(interaction)
107113

108-
if receiver.agent_id != sender.agent_id:
109-
self._agent_private_interactions[receiver.agent_id].append(interaction)
114+
for _receiver in receiver:
115+
if _receiver.agent_id != sender.agent_id:
116+
self._agent_private_interactions[_receiver.agent_id].append(interaction)
110117

111118
def get_all_interactions(self) -> List[Interaction]:
112119
"""

agentarium/Interaction.py

+8-5
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,8 @@ class Interaction:
2323
sender: Agent
2424
"""The agent who initiated the interaction."""
2525

26-
receiver: Agent
27-
"""The agent who received the interaction."""
26+
receiver: list[Agent]
27+
"""The agent(s) who received the interaction."""
2828

2929
message: str
3030
"""The content of the interaction between the agents."""
@@ -35,7 +35,7 @@ def dump(self) -> dict:
3535
"""
3636
return {
3737
"sender": self.sender.agent_id,
38-
"receiver": self.receiver.agent_id,
38+
"receiver": [receiver.agent_id for receiver in self.receiver],
3939
"message": self.message,
4040
}
4141

@@ -46,7 +46,11 @@ def __str__(self) -> str:
4646
Returns:
4747
str: A formatted string showing sender, receiver, and the interaction message.
4848
"""
49-
return f"Interaction from {self.sender.name} ({self.sender.agent_id}) to {self.receiver.name} ({self.receiver.agent_id}): {self.message}"
49+
50+
if len(self.receiver) == 1:
51+
return f"{self.sender.name} ({self.sender.agent_id}) said to {self.receiver[0].name} ({self.receiver[0].agent_id}): {self.message}"
52+
else:
53+
return f"{self.sender.name} ({self.sender.agent_id}) said to {', '.join([_receiver.name + f'({_receiver.agent_id})' for _receiver in self.receiver])}: {self.message}"
5054

5155
def __repr__(self) -> str:
5256
"""
@@ -56,4 +60,3 @@ def __repr__(self) -> str:
5660
str: A formatted string showing sender, receiver, and the interaction message.
5761
"""
5862
return self.__str__()
59-

agentarium/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from .actions.Action import Action
44
from .Interaction import Interaction
55

6-
__version__ = "0.3.0"
6+
__version__ = "0.3.1"
77

88
__all__ = [
99
"Agent",

agentarium/actions/default_actions.py

+26-7
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ def talk_action_function(*args, **kwargs):
2020
Returns:
2121
dict: A dictionary containing:
2222
- sender (str): The ID of the sending agent
23-
- receiver (str): The ID of the receiving agent
23+
- receiver (str | list[str]): The ID of the receiving agent
2424
- message (str): The content of the message
2525
2626
Raises:
@@ -34,19 +34,32 @@ def talk_action_function(*args, **kwargs):
3434
if "agent" not in kwargs:
3535
raise RuntimeError(f"Couldn't find agent in {kwargs=} for TALK action")
3636

37-
receiver = kwargs["agent"]._interaction_manager.get_agent(args[0])
37+
# Retrieve the receiver IDs
38+
receiver_ids = [_rec.strip() for _rec in args[0].split(',') if len(_rec.strip()) != 0]
3839

39-
if receiver is None:
40+
# If the receiver is "all", talk to all agents
41+
if "all" in receiver_ids:
42+
receiver_ids = [agent.agent_id for agent in kwargs["agent"]._interaction_manager._agents.values()]
43+
44+
# Retrieve the receivers from their IDs
45+
receivers = [kwargs["agent"]._interaction_manager.get_agent(_receiver_id) for _receiver_id in receiver_ids]
46+
47+
# Check if the receivers are valid
48+
if any(receiver is None for receiver in receivers):
4049
logger.error(f"Received a TALK action with an invalid agent ID {args[0]=} {args=} {kwargs['agent']._interaction_manager._agents=}")
41-
raise RuntimeError(f"Received a TALK action with an invalid agent ID: {args[0]}. {args=}")
50+
raise RuntimeError(f"Received a TALK action with an invalid agent ID: {args[0]=}. {args=}")
4251

4352
message = args[1]
4453

45-
kwargs["agent"]._interaction_manager.record_interaction(kwargs["agent"], receiver, message)
54+
if len(message) == 0:
55+
logger.warning(f"Received empty message for talk_action_function: {args=}")
56+
57+
# Record the interaction
58+
kwargs["agent"]._interaction_manager.record_interaction(kwargs["agent"], receivers, message)
4659

4760
return {
4861
"sender": kwargs["agent"].agent_id,
49-
"receiver": receiver.agent_id,
62+
"receiver": [receiver.agent_id for receiver in receivers],
5063
"message": message,
5164
}
5265

@@ -88,9 +101,15 @@ def think_action_function(*args, **kwargs):
88101

89102

90103
# Create action instances at module level
104+
talk_action_description = """\
105+
Talk to agents by specifying their IDs followed by the content to say:
106+
- To talk to a single agent: Enter an agent ID (e.g. "1")
107+
- To talk to multiple agents: Enter IDs separated by commas (e.g. "1,2,3")
108+
- To talk to all agents: Enter "all"\
109+
"""
91110
talk_action = Action(
92111
name="talk",
93-
description="Talk to the agent with the given ID.",
112+
description=talk_action_description,
94113
parameters=["agent_id", "message"],
95114
function=talk_action_function,
96115
)

tests/test_agent.py

+38-6
Original file line numberDiff line numberDiff line change
@@ -67,15 +67,15 @@ def test_agent_interaction(agent_pair):
6767
# Check Alice's interactions
6868
alice_interactions = alice.get_interactions()
6969
assert len(alice_interactions) == 1
70-
assert alice_interactions[0].sender == alice
71-
assert alice_interactions[0].receiver == bob
70+
assert alice_interactions[0].sender.agent_id == alice.agent_id
71+
assert alice_interactions[0].receiver[0].agent_id == bob.agent_id
7272
assert alice_interactions[0].message == message
7373

7474
# Check Bob's interactions
7575
bob_interactions = bob.get_interactions()
7676
assert len(bob_interactions) == 1
77-
assert bob_interactions[0].sender == alice
78-
assert bob_interactions[0].receiver == bob
77+
assert bob_interactions[0].sender.agent_id == alice.agent_id
78+
assert bob_interactions[0].receiver[0].agent_id == bob.agent_id
7979
assert bob_interactions[0].message == message
8080

8181

@@ -89,8 +89,8 @@ def test_agent_think(agent_pair):
8989

9090
interactions = agent.get_interactions()
9191
assert len(interactions) == 1
92-
assert interactions[0].sender == agent
93-
assert interactions[0].receiver == agent
92+
assert interactions[0].sender.agent_id == agent.agent_id
93+
assert interactions[0].receiver[0].agent_id == agent.agent_id
9494
assert interactions[0].message == thought
9595

9696

@@ -110,3 +110,35 @@ def test_agent_reset(base_agent):
110110
base_agent.reset()
111111
assert len(base_agent.get_interactions()) == 0
112112
assert len(base_agent.storage) == 0
113+
114+
115+
def test_display_interaction(agent_pair):
116+
"""Test displaying an interaction."""
117+
alice, bob = agent_pair
118+
message = "Hello Bob!"
119+
120+
alice.talk_to(bob, message) # One receiver
121+
alice.talk_to([bob, alice], message) # Multiple receivers
122+
123+
# just check that it doesn't raise an error
124+
print(alice.get_interactions()[0]) # one receiver
125+
print(alice.get_interactions()[1]) # multiple receivers
126+
127+
128+
def test_agent_actions_manualy_executed(base_agent):
129+
"""Test agent actions."""
130+
131+
base_agent.add_action(Action("test", "A test action", ["message"], lambda message, *args, **kwargs: {"message": message}))
132+
133+
action = base_agent.execute_action("test", "test message")
134+
assert action["message"] == "test message"
135+
136+
137+
def test_agent_actions_automatically_executed(base_agent):
138+
"""Test agent actions."""
139+
140+
base_agent.add_action(Action("test", "A test action", ["message"], lambda message, *args, **kwargs: {"message": message}))
141+
base_agent.think("I need to do the action 'test' with the message 'test message'")
142+
143+
action = base_agent.act()
144+
assert action["message"] == "test message"

tests/test_interaction_manager.py

+11-7
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import pytest
2-
from agentarium import Agent, Interaction
3-
from agentarium.AgentInteractionManager import AgentInteractionManager
42

3+
from agentarium.Agent import Agent
4+
from agentarium.AgentInteractionManager import AgentInteractionManager
55

66
@pytest.fixture
77
def interaction_manager():
@@ -34,8 +34,8 @@ def test_interaction_recording(interaction_manager, agent_pair):
3434
assert len(alice.get_interactions()) == 1
3535

3636
interaction = alice.get_interactions()[0]
37-
assert interaction.sender is alice
38-
assert interaction.receiver is bob
37+
assert interaction.sender.agent_id == alice.agent_id
38+
assert interaction.receiver[0].agent_id == bob.agent_id
3939
assert interaction.message == message
4040

4141
def test_get_agent_interactions(interaction_manager, agent_pair):
@@ -89,16 +89,20 @@ def test_self_interaction(interaction_manager, agent_pair):
8989

9090
interactions = interaction_manager.get_agent_interactions(alice)
9191
assert len(interactions) == 1
92-
assert interactions[0].sender is alice
93-
assert interactions[0].receiver is alice
92+
assert interactions[0].sender.agent_id == alice.agent_id
93+
assert interactions[0].receiver[0].agent_id == alice.agent_id
9494
assert interactions[0].message == thought
9595

9696
def test_interaction_validation(interaction_manager, agent_pair):
9797
"""Test validation of interaction recording."""
9898

9999
alice, _ = agent_pair
100100

101-
unregistered_agent = type("UnregisteredAgent", (object,), {"agent_id": "some_unregistered_id"})
101+
unregistered_agent = type("UnregisteredAgent", (Agent,), {
102+
"agent_id": "some_unregistered_id",
103+
"__init__": lambda self, *args, **kwargs: None,
104+
"agent_informations": {},
105+
})()
102106

103107
with pytest.raises(ValueError):
104108
interaction_manager.record_interaction(unregistered_agent, alice, "Hello")

0 commit comments

Comments
 (0)