Skip to content

Commit 29e6b91

Browse files
committed
refactor(workforce): Enhance worker initialization and callback handling in async context
1 parent 96e6a03 commit 29e6b91

File tree

2 files changed

+144
-179
lines changed

2 files changed

+144
-179
lines changed

camel/societies/workforce/workforce.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -339,6 +339,9 @@ def __init__(
339339
self.snapshot_interval: float = 30.0
340340
# Shared memory UUID tracking to prevent re-sharing duplicates
341341
self._shared_memory_uuids: Set[str] = set()
342+
# Defer initial worker-created callbacks until an event loop is
343+
# available in async context.
344+
self._pending_worker_created: Deque[BaseNode] = deque(self._children)
342345
self._initialize_callbacks(callbacks)
343346

344347
# Set up coordinator agent with default system message
@@ -532,9 +535,6 @@ def _initialize_callbacks(
532535
"WorkforceLogger addition."
533536
)
534537

535-
for child in self._children:
536-
asyncio.run(self._notify_worker_created(child))
537-
538538
async def _notify_worker_created(
539539
self,
540540
worker_node: BaseNode,
@@ -553,6 +553,18 @@ async def _notify_worker_created(
553553
for cb in self._callbacks:
554554
await cb.log_worker_created(event)
555555

556+
async def _flush_initial_worker_created_callbacks(self) -> None:
557+
r"""Flush pending worker-created callbacks that were queued during
558+
initialization before an event loop was available."""
559+
if not self._pending_worker_created:
560+
return
561+
562+
pending = list(self._pending_worker_created)
563+
self._pending_worker_created.clear()
564+
565+
for child in pending:
566+
await self._notify_worker_created(child)
567+
556568
def _get_or_create_shared_context_utility(
557569
self,
558570
session_id: Optional[str] = None,
@@ -2258,6 +2270,9 @@ async def process_task_async(
22582270
Returns:
22592271
Task: The updated task.
22602272
"""
2273+
# Emit worker-created callbacks lazily once an event loop is present.
2274+
await self._flush_initial_worker_created_callbacks()
2275+
22612276
# Delegate to intervention pipeline when requested to keep
22622277
# backward-compat.
22632278
if interactive:

test/workforce/test_workforce_callbacks.py

Lines changed: 126 additions & 176 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818
from camel.agents import ChatAgent
1919
from camel.messages import BaseMessage
2020
from camel.models import ModelFactory
21+
from camel.societies.workforce import SingleAgentWorker
22+
from camel.societies.workforce.base import BaseNode
2123
from camel.societies.workforce.events import (
2224
AllTasksCompletedEvent,
2325
TaskAssignedEvent,
@@ -142,6 +144,80 @@ def _build_stub_agent() -> ChatAgent:
142144
return ChatAgent(model=model)
143145

144146

147+
def _build_persona_agent(role_name: str, content: str) -> ChatAgent:
148+
"""Construct a stub-backed ChatAgent with a system persona."""
149+
return ChatAgent(
150+
system_message=BaseMessage.make_assistant_message(
151+
role_name=role_name,
152+
content=content,
153+
),
154+
model=ModelFactory.create(
155+
model_platform=ModelPlatformType.OPENAI,
156+
model_type=ModelType.STUB,
157+
),
158+
)
159+
160+
161+
def _build_worker_specs() -> list[tuple[str, str, ChatAgent]]:
162+
"""Build the standard trio of workers used across tests."""
163+
return [
164+
(
165+
"A researcher who can search online for information.",
166+
"SearchWork",
167+
_build_persona_agent(
168+
"Research Specialist",
169+
"You are a research specialist who excels at finding and "
170+
"gathering information from the web.",
171+
),
172+
),
173+
(
174+
"An analyst who can process research findings.",
175+
"AnalystWorker",
176+
_build_persona_agent(
177+
"Business Analyst",
178+
"You are an expert business analyst. Your job is "
179+
"to analyze research findings, identify key insights, "
180+
"opportunities, and challenges.",
181+
),
182+
),
183+
(
184+
"A writer who can create a final report from the analysis.",
185+
"WriterWorker",
186+
_build_persona_agent(
187+
"Report Writer",
188+
"You are a professional report writer. You take "
189+
"analytical insights and synthesize them into a clear, "
190+
"concise, and well-structured final report.",
191+
),
192+
),
193+
]
194+
195+
196+
async def _assert_metrics_callbacks(
197+
workforce: Workforce, cb: _MetricsCallback
198+
):
199+
"""Verify metrics callback toggles across reset cycles."""
200+
assert not cb.dump_to_json_called
201+
assert not cb.get_ascii_tree_called
202+
assert not cb.get_kpis_called
203+
await workforce.dump_workforce_logs("foo.log")
204+
assert cb.dump_to_json_called
205+
206+
await workforce.reset()
207+
assert not cb.dump_to_json_called
208+
assert not cb.get_ascii_tree_called
209+
assert not cb.get_kpis_called
210+
await workforce.get_workforce_kpis()
211+
assert cb.get_kpis_called
212+
213+
await workforce.reset()
214+
assert not cb.dump_to_json_called
215+
assert not cb.get_ascii_tree_called
216+
assert not cb.get_kpis_called
217+
await workforce.get_workforce_log_tree()
218+
assert cb.get_ascii_tree_called
219+
220+
145221
@pytest.mark.asyncio
146222
async def test_workforce_callback_registration_and_metrics_handling():
147223
"""Verify default logger addition and metrics-callback skip logic.
@@ -176,165 +252,48 @@ async def test_workforce_callback_registration_and_metrics_handling():
176252
Workforce("CB Test - Invalid", callbacks=[object()])
177253

178254

179-
def assert_event_sequence(events: list[str], min_worker_count: int):
180-
"""
181-
Validate that the given event sequence follows the expected logical order.
182-
This version is flexible to handle:
183-
- Task retries and dynamic worker creation
184-
- Cases where tasks are not decomposed (e.g., when using stub models)
185-
"""
186-
idx = 0
187-
n = len(events)
188-
189-
# 1. Expect at least min_worker_count WorkerCreatedEvent events first
190-
initial_worker_count = 0
191-
while idx < n and events[idx] == "WorkerCreatedEvent":
192-
initial_worker_count += 1
193-
idx += 1
194-
assert initial_worker_count >= min_worker_count, (
195-
f"Expected at least {min_worker_count} initial "
196-
f"WorkerCreatedEvents, got {initial_worker_count}"
197-
)
198-
199-
# 2. Expect one main TaskCreatedEvent
200-
assert idx < n and events[idx] == "TaskCreatedEvent", (
201-
f"Event {idx} should be TaskCreatedEvent, got "
202-
f"{events[idx] if idx < n else 'END'}"
203-
)
204-
idx += 1
205-
206-
# 3. TaskDecomposedEvent may or may not be present
207-
# (depends on coordinator behavior)
208-
# If the coordinator can't parse stub responses, it may skip
209-
# decomposition
210-
has_decomposition = idx < n and events[idx] == "TaskDecomposedEvent"
211-
if has_decomposition:
212-
idx += 1
213-
214-
# 4. Count all event types in the remaining events
215-
all_events = events[idx:]
216-
task_assigned_count = all_events.count("TaskAssignedEvent")
217-
task_started_count = all_events.count("TaskStartedEvent")
218-
task_completed_count = all_events.count("TaskCompletedEvent")
219-
all_tasks_completed_count = all_events.count("AllTasksCompletedEvent")
220-
221-
# 5. Validate basic invariants
222-
# At minimum, the main task should be assigned and processed
223-
assert (
224-
task_assigned_count >= 1
225-
), f"Expected at least 1 TaskAssignedEvent, got {task_assigned_count}"
226-
assert (
227-
task_started_count >= 1
228-
), f"Expected at least 1 TaskStartedEvent, got {task_started_count}"
229-
assert (
230-
task_completed_count >= 1
231-
), f"Expected at least 1 TaskCompletedEvent, got {task_completed_count}"
232-
233-
# 6. Expect exactly one AllTasksCompletedEvent at the end
234-
assert all_tasks_completed_count == 1, (
235-
f"Expected exactly 1 AllTasksCompletedEvent, got "
236-
f"{all_tasks_completed_count}"
237-
)
238-
assert (
239-
events[-1] == "AllTasksCompletedEvent"
240-
), "Last event should be AllTasksCompletedEvent"
241-
242-
# 7. All events should be of expected types
243-
allowed_events = {
244-
"WorkerCreatedEvent",
245-
"WorkerDeletedEvent",
246-
"TaskCreatedEvent",
247-
"TaskDecomposedEvent",
248-
"TaskAssignedEvent",
249-
"TaskStartedEvent",
250-
"TaskCompletedEvent",
251-
"TaskFailedEvent",
252-
"AllTasksCompletedEvent",
253-
}
254-
for i, e in enumerate(events):
255-
assert e in allowed_events, f"Unexpected event type at {i}: {e}"
256-
257-
258255
@pytest.mark.asyncio
259-
async def test_workforce_emits_expected_event_sequence():
260-
# Use STUB model to avoid real API calls and ensure fast,
261-
# deterministic execution
262-
search_agent = ChatAgent(
263-
system_message=BaseMessage.make_assistant_message(
264-
role_name="Research Specialist",
265-
content="You are a research specialist who excels at finding and "
266-
"gathering information from the web.",
267-
),
268-
model=ModelFactory.create(
269-
model_platform=ModelPlatformType.OPENAI,
270-
model_type=ModelType.STUB,
271-
),
272-
)
273-
274-
analyst_agent = ChatAgent(
275-
system_message=BaseMessage.make_assistant_message(
276-
role_name="Business Analyst",
277-
content="You are an expert business analyst. Your job is "
278-
"to analyze research findings, identify key insights, "
279-
"opportunities, and challenges.",
280-
),
281-
model=ModelFactory.create(
282-
model_platform=ModelPlatformType.OPENAI,
283-
model_type=ModelType.STUB,
284-
),
285-
)
286-
287-
writer_agent = ChatAgent(
288-
system_message=BaseMessage.make_assistant_message(
289-
role_name="Report Writer",
290-
content="You are a professional report writer. You take "
291-
"analytical insights and synthesize them into a clear, "
292-
"concise, and well-structured final report.",
293-
),
294-
model=ModelFactory.create(
295-
model_platform=ModelPlatformType.OPENAI,
296-
model_type=ModelType.STUB,
297-
),
298-
)
299-
256+
@pytest.mark.parametrize(
257+
"preconfigure_children",
258+
[False, True],
259+
ids=["add_workers_at_runtime", "preconfigure_children"],
260+
)
261+
async def test_workforce_emits_expected_events_for_worker_init_modes(
262+
preconfigure_children: bool,
263+
):
264+
"""Validate event ordering for both worker setup paths."""
300265
cb = _MetricsCallback()
301-
302-
# Use STUB models for coordinator and task agents to avoid real API calls
303-
coordinator_agent = ChatAgent(
304-
model=ModelFactory.create(
305-
model_platform=ModelPlatformType.OPENAI,
306-
model_type=ModelType.STUB,
266+
coordinator_agent = _build_stub_agent()
267+
task_agent = _build_stub_agent()
268+
worker_specs = _build_worker_specs()
269+
270+
if preconfigure_children:
271+
children: list[BaseNode] = [
272+
SingleAgentWorker(description=child_desc, worker=agent)
273+
for _, child_desc, agent in worker_specs
274+
]
275+
workforce = Workforce(
276+
'Business Analysis Team',
277+
graceful_shutdown_timeout=30.0,
278+
callbacks=[cb],
279+
coordinator_agent=coordinator_agent,
280+
task_agent=task_agent,
281+
children=children,
307282
)
308-
)
309-
task_agent = ChatAgent(
310-
model=ModelFactory.create(
311-
model_platform=ModelPlatformType.OPENAI,
312-
model_type=ModelType.STUB,
283+
else:
284+
workforce = Workforce(
285+
'Business Analysis Team',
286+
graceful_shutdown_timeout=30.0,
287+
callbacks=[cb],
288+
coordinator_agent=coordinator_agent,
289+
task_agent=task_agent,
313290
)
314-
)
315-
316-
workforce = Workforce(
317-
'Business Analysis Team',
318-
graceful_shutdown_timeout=30.0,
319-
callbacks=[cb],
320-
coordinator_agent=coordinator_agent,
321-
task_agent=task_agent,
322-
)
291+
for add_desc, _, agent in worker_specs:
292+
await workforce.add_single_agent_worker(
293+
add_desc,
294+
worker=agent,
295+
)
323296

324-
await workforce.add_single_agent_worker(
325-
"A researcher who can search online for information.",
326-
worker=search_agent,
327-
)
328-
await workforce.add_single_agent_worker(
329-
"An analyst who can process research findings.",
330-
worker=analyst_agent,
331-
)
332-
await workforce.add_single_agent_worker(
333-
"A writer who can create a final report from the analysis.",
334-
worker=writer_agent,
335-
)
336-
337-
# Use a simpler task to ensure fast and deterministic execution
338297
human_task = Task(
339298
content=(
340299
"Create a simple report about electric scooters. "
@@ -348,27 +307,18 @@ async def test_workforce_emits_expected_event_sequence():
348307

349308
await workforce.process_task_async(human_task)
350309

351-
# test that the event sequence is as expected
310+
expected_events = [
311+
"WorkerCreatedEvent",
312+
"WorkerCreatedEvent",
313+
"WorkerCreatedEvent",
314+
"TaskCreatedEvent",
315+
"WorkerCreatedEvent",
316+
"TaskAssignedEvent",
317+
"TaskStartedEvent",
318+
"TaskCompletedEvent",
319+
"AllTasksCompletedEvent",
320+
]
352321
actual_events = [e.__class__.__name__ for e in cb.events]
353-
assert_event_sequence(actual_events, min_worker_count=3)
322+
assert actual_events == expected_events
354323

355-
# test that metrics callback methods work as expected
356-
assert not cb.dump_to_json_called
357-
assert not cb.get_ascii_tree_called
358-
assert not cb.get_kpis_called
359-
await workforce.dump_workforce_logs("foo.log")
360-
assert cb.dump_to_json_called
361-
362-
await workforce.reset()
363-
assert not cb.dump_to_json_called
364-
assert not cb.get_ascii_tree_called
365-
assert not cb.get_kpis_called
366-
await workforce.get_workforce_kpis()
367-
assert cb.get_kpis_called
368-
369-
await workforce.reset()
370-
assert not cb.dump_to_json_called
371-
assert not cb.get_ascii_tree_called
372-
assert not cb.get_kpis_called
373-
await workforce.get_workforce_log_tree()
374-
assert cb.get_ascii_tree_called
324+
await _assert_metrics_callbacks(workforce, cb)

0 commit comments

Comments
 (0)