Skip to content

Commit cd462c8

Browse files
committed
refactor(workforce): Refactor WorkforceCallback and all related callbacks to async interface
1 parent d64e6e9 commit cd462c8

21 files changed

+442
-379
lines changed

camel/benchmarks/browsecomp.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
# See the License for the specific language governing permissions and
1212
# limitations under the License.
1313
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
14-
14+
import asyncio
1515
import base64
1616
import hashlib
1717
import json
@@ -593,7 +593,7 @@ def process_benchmark_row(row: Dict[str, Any]) -> Dict[str, Any]:
593593
elif isinstance(pipeline_template, Workforce):
594594
pipeline = pipeline_template.clone() # type: ignore[assignment]
595595
task = Task(content=input_message, id="0")
596-
task = pipeline.process_task(task) # type: ignore[attr-defined]
596+
task = asyncio.run(pipeline.process_task_async(task)) # type: ignore[attr-defined]
597597
if task_json_formatter:
598598
formatter_in_process = task_json_formatter.clone()
599599
else:

camel/societies/workforce/workforce.py

Lines changed: 40 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -509,9 +509,9 @@ def _initialize_callbacks(
509509
)
510510

511511
for child in self._children:
512-
self._notify_worker_created(child)
512+
asyncio.run(self._notify_worker_created(child))
513513

514-
def _notify_worker_created(
514+
async def _notify_worker_created(
515515
self,
516516
worker_node: BaseNode,
517517
*,
@@ -527,7 +527,7 @@ def _notify_worker_created(
527527
metadata=metadata,
528528
)
529529
for cb in self._callbacks:
530-
cb.log_worker_created(event)
530+
await cb.log_worker_created(event)
531531

532532
def _get_or_create_shared_context_utility(
533533
self,
@@ -1189,7 +1189,7 @@ async def _apply_recovery_strategy(
11891189
subtask_ids=[st.id for st in subtasks],
11901190
)
11911191
for cb in self._callbacks:
1192-
cb.log_task_decomposed(task_decomposed_event)
1192+
await cb.log_task_decomposed(task_decomposed_event)
11931193
for subtask in subtasks:
11941194
task_created_event = TaskCreatedEvent(
11951195
task_id=subtask.id,
@@ -1199,7 +1199,7 @@ async def _apply_recovery_strategy(
11991199
metadata=subtask.additional_info,
12001200
)
12011201
for cb in self._callbacks:
1202-
cb.log_task_created(task_created_event)
1202+
await cb.log_task_created(task_created_event)
12031203

12041204
# Insert subtasks at the head of the queue
12051205
self._pending_tasks.extendleft(reversed(subtasks))
@@ -1719,7 +1719,7 @@ async def handle_decompose_append_task(
17191719
return [task]
17201720

17211721
if reset and self._state != WorkforceState.RUNNING:
1722-
self.reset()
1722+
await self.reset()
17231723
logger.info("Workforce reset before handling task.")
17241724

17251725
# Focus on the new task
@@ -1733,7 +1733,7 @@ async def handle_decompose_append_task(
17331733
metadata=task.additional_info,
17341734
)
17351735
for cb in self._callbacks:
1736-
cb.log_task_created(task_created_event)
1736+
await cb.log_task_created(task_created_event)
17371737

17381738
# The agent tend to be overconfident on the whole task, so we
17391739
# decompose the task into subtasks first
@@ -1754,7 +1754,7 @@ async def handle_decompose_append_task(
17541754
subtask_ids=[st.id for st in subtasks],
17551755
)
17561756
for cb in self._callbacks:
1757-
cb.log_task_decomposed(task_decomposed_event)
1757+
await cb.log_task_decomposed(task_decomposed_event)
17581758
for subtask in subtasks:
17591759
task_created_event = TaskCreatedEvent(
17601760
task_id=subtask.id,
@@ -1764,7 +1764,7 @@ async def handle_decompose_append_task(
17641764
metadata=subtask.additional_info,
17651765
)
17661766
for cb in self._callbacks:
1767-
cb.log_task_created(task_created_event)
1767+
await cb.log_task_created(task_created_event)
17681768

17691769
if subtasks:
17701770
# _pending_tasks will contain both undecomposed
@@ -2027,7 +2027,7 @@ def _start_child_node_when_paused(
20272027
# Close the coroutine to prevent RuntimeWarning
20282028
start_coroutine.close()
20292029

2030-
def add_single_agent_worker(
2030+
async def add_single_agent_worker(
20312031
self,
20322032
description: str,
20332033
worker: ChatAgent,
@@ -2083,13 +2083,13 @@ def add_single_agent_worker(
20832083
# If workforce is paused, start the worker's listening task
20842084
self._start_child_node_when_paused(worker_node.start())
20852085

2086-
self._notify_worker_created(
2086+
await self._notify_worker_created(
20872087
worker_node,
20882088
worker_type='SingleAgentWorker',
20892089
)
20902090
return self
20912091

2092-
def add_role_playing_worker(
2092+
async def add_role_playing_worker(
20932093
self,
20942094
description: str,
20952095
assistant_role_name: str,
@@ -2160,7 +2160,7 @@ def add_role_playing_worker(
21602160
# If workforce is paused, start the worker's listening task
21612161
self._start_child_node_when_paused(worker_node.start())
21622162

2163-
self._notify_worker_created(
2163+
await self._notify_worker_created(
21642164
worker_node,
21652165
worker_type='RolePlayingWorker',
21662166
)
@@ -2202,7 +2202,7 @@ async def _async_reset(self) -> None:
22022202
self._pause_event.set()
22032203

22042204
@check_if_running(False)
2205-
def reset(self) -> None:
2205+
async def reset(self) -> None:
22062206
r"""Reset the workforce and all the child nodes under it. Can only
22072207
be called when the workforce is not running.
22082208
"""
@@ -2229,9 +2229,7 @@ def reset(self) -> None:
22292229
if self._loop and not self._loop.is_closed():
22302230
# If we have a loop, use it to set the event safely
22312231
try:
2232-
asyncio.run_coroutine_threadsafe(
2233-
self._async_reset(), self._loop
2234-
).result()
2232+
await self._async_reset()
22352233
except RuntimeError as e:
22362234
logger.warning(f"Failed to reset via existing loop: {e}")
22372235
# Fallback to direct event manipulation
@@ -2242,7 +2240,7 @@ def reset(self) -> None:
22422240

22432241
for cb in self._callbacks:
22442242
if isinstance(cb, WorkforceMetrics):
2245-
cb.reset_task_data()
2243+
await cb.reset_task_data()
22462244

22472245
def save_workflow_memories(
22482246
self,
@@ -3093,7 +3091,7 @@ async def _post_task(self, task: Task, assignee_id: str) -> None:
30933091
task_id=task.id, worker_id=assignee_id
30943092
)
30953093
for cb in self._callbacks:
3096-
cb.log_task_started(task_started_event)
3094+
await cb.log_task_started(task_started_event)
30973095

30983096
try:
30993097
await self._channel.post_task(task, self.node_id, assignee_id)
@@ -3240,15 +3238,13 @@ async def _create_worker_node_for_task(self, task: Task) -> Worker:
32403238

32413239
self._children.append(new_node)
32423240

3243-
self._notify_worker_created(
3241+
await self._notify_worker_created(
32443242
new_node,
32453243
worker_type='SingleAgentWorker',
32463244
role=new_node_conf.role,
32473245
metadata={'description': new_node_conf.description},
32483246
)
3249-
self._child_listening_tasks.append(
3250-
asyncio.create_task(new_node.start())
3251-
)
3247+
self._child_listening_tasks.append(await new_node.start())
32523248
return new_node
32533249

32543250
async def _create_new_agent(self, role: str, sys_msg: str) -> ChatAgent:
@@ -3363,7 +3359,7 @@ async def _post_ready_tasks(self) -> None:
33633359
for cb in self._callbacks:
33643360
# queue_time_seconds can be derived by logger if task
33653361
# creation time is logged
3366-
cb.log_task_assigned(task_assigned_event)
3362+
await cb.log_task_assigned(task_assigned_event)
33673363

33683364
# Step 2: Iterate through all pending tasks and post those that are
33693365
# ready
@@ -3492,7 +3488,7 @@ async def _post_ready_tasks(self) -> None:
34923488
},
34933489
)
34943490
for cb in self._callbacks:
3495-
cb.log_task_failed(task_failed_event)
3491+
await cb.log_task_failed(task_failed_event)
34963492

34973493
self._completed_tasks.append(task)
34983494
self._cleanup_task_tracking(task.id)
@@ -3555,7 +3551,7 @@ async def _handle_failed_task(self, task: Task) -> bool:
35553551
},
35563552
)
35573553
for cb in self._callbacks:
3558-
cb.log_task_failed(task_failed_event)
3554+
await cb.log_task_failed(task_failed_event)
35593555

35603556
# Check for immediate halt conditions
35613557
if task.failure_count >= MAX_TASK_RETRIES:
@@ -3699,7 +3695,7 @@ async def _handle_completed_task(self, task: Task) -> None:
36993695
metadata={'current_state': task.state.value},
37003696
)
37013697
for cb in self._callbacks:
3702-
cb.log_task_completed(task_completed_event)
3698+
await cb.log_task_completed(task_completed_event)
37033699

37043700
# Find and remove the completed task from pending tasks
37053701
tasks_list = list(self._pending_tasks)
@@ -3815,7 +3811,7 @@ async def _graceful_shutdown(self, failed_task: Task) -> None:
38153811
# Wait for the full timeout period
38163812
await asyncio.sleep(self.graceful_shutdown_timeout)
38173813

3818-
def get_workforce_log_tree(self) -> str:
3814+
async def get_workforce_log_tree(self) -> str:
38193815
r"""Returns an ASCII tree representation of the task hierarchy and
38203816
worker status.
38213817
"""
@@ -3825,19 +3821,19 @@ def get_workforce_log_tree(self) -> str:
38253821
if len(metrics_cb) == 0:
38263822
return "Metrics Callback not initialized."
38273823
else:
3828-
return metrics_cb[0].get_ascii_tree_representation()
3824+
return await metrics_cb[0].get_ascii_tree_representation()
38293825

3830-
def get_workforce_kpis(self) -> Dict[str, Any]:
3826+
async def get_workforce_kpis(self) -> Dict[str, Any]:
38313827
r"""Returns a dictionary of key performance indicators."""
38323828
metrics_cb: List[WorkforceMetrics] = [
38333829
cb for cb in self._callbacks if isinstance(cb, WorkforceMetrics)
38343830
]
38353831
if len(metrics_cb) == 0:
38363832
return {"error": "Metrics Callback not initialized."}
38373833
else:
3838-
return metrics_cb[0].get_kpis()
3834+
return await metrics_cb[0].get_kpis()
38393835

3840-
def dump_workforce_logs(self, file_path: str) -> None:
3836+
async def dump_workforce_logs(self, file_path: str) -> None:
38413837
r"""Dumps all collected logs to a JSON file.
38423838
38433839
Args:
@@ -3849,7 +3845,7 @@ def dump_workforce_logs(self, file_path: str) -> None:
38493845
if len(metrics_cb) == 0:
38503846
print("Logger not initialized. Cannot dump logs.")
38513847
return
3852-
metrics_cb[0].dump_to_json(file_path)
3848+
await metrics_cb[0].dump_to_json(file_path)
38533849
# Use logger.info or print, consistent with existing style
38543850
logger.info(f"Workforce logs dumped to {file_path}")
38553851

@@ -4325,7 +4321,7 @@ async def _listen_to_channel(self) -> None:
43254321
logger.info("All tasks completed.")
43264322
all_tasks_completed_event = AllTasksCompletedEvent()
43274323
for cb in self._callbacks:
4328-
cb.log_all_tasks_completed(all_tasks_completed_event)
4324+
await cb.log_all_tasks_completed(all_tasks_completed_event)
43294325

43304326
# shut down the whole workforce tree
43314327
self.stop()
@@ -4413,7 +4409,7 @@ async def cleanup():
44134409

44144410
self._running = False
44154411

4416-
def clone(self, with_memory: bool = False) -> 'Workforce':
4412+
async def clone(self, with_memory: bool = False) -> 'Workforce':
44174413
r"""Creates a new instance of Workforce with the same configuration.
44184414
44194415
Args:
@@ -4444,13 +4440,13 @@ def clone(self, with_memory: bool = False) -> 'Workforce':
44444440
for child in self._children:
44454441
if isinstance(child, SingleAgentWorker):
44464442
cloned_worker = child.worker.clone(with_memory)
4447-
new_instance.add_single_agent_worker(
4443+
await new_instance.add_single_agent_worker(
44484444
child.description,
44494445
cloned_worker,
44504446
pool_max_size=10,
44514447
)
44524448
elif isinstance(child, RolePlayingWorker):
4453-
new_instance.add_role_playing_worker(
4449+
await new_instance.add_role_playing_worker(
44544450
child.description,
44554451
child.assistant_role_name,
44564452
child.user_role_name,
@@ -4460,7 +4456,7 @@ def clone(self, with_memory: bool = False) -> 'Workforce':
44604456
child.chat_turn_limit,
44614457
)
44624458
elif isinstance(child, Workforce):
4463-
new_instance.add_workforce(child.clone(with_memory))
4459+
new_instance.add_workforce(await child.clone(with_memory))
44644460
else:
44654461
logger.warning(f"{type(child)} is not being cloned.")
44664462
continue
@@ -4695,7 +4691,7 @@ def get_children_info():
46954691
return children_info
46964692

46974693
# Add single agent worker
4698-
def add_single_agent_worker(
4694+
async def add_single_agent_worker(
46994695
description,
47004696
system_message=None,
47014697
role_name="Assistant",
@@ -4759,7 +4755,9 @@ def add_single_agent_worker(
47594755
"message": str(e),
47604756
}
47614757

4762-
workforce_instance.add_single_agent_worker(description, agent)
4758+
await workforce_instance.add_single_agent_worker(
4759+
description, agent
4760+
)
47634761

47644762
return {
47654763
"status": "success",
@@ -4770,7 +4768,7 @@ def add_single_agent_worker(
47704768
return {"status": "error", "message": str(e)}
47714769

47724770
# Add role playing worker
4773-
def add_role_playing_worker(
4771+
async def add_role_playing_worker(
47744772
description,
47754773
assistant_role_name,
47764774
user_role_name,
@@ -4827,7 +4825,7 @@ def add_role_playing_worker(
48274825
"message": "Cannot add workers while workforce is running", # noqa: E501
48284826
}
48294827

4830-
workforce_instance.add_role_playing_worker(
4828+
await workforce_instance.add_role_playing_worker(
48314829
description=description,
48324830
assistant_role_name=assistant_role_name,
48334831
user_role_name=user_role_name,

camel/societies/workforce/workforce_callback.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -35,40 +35,42 @@ class WorkforceCallback(ABC):
3535
"""
3636

3737
@abstractmethod
38-
def log_task_created(
38+
async def log_task_created(
3939
self,
4040
event: TaskCreatedEvent,
4141
) -> None:
4242
pass
4343

4444
@abstractmethod
45-
def log_task_decomposed(self, event: TaskDecomposedEvent) -> None:
45+
async def log_task_decomposed(self, event: TaskDecomposedEvent) -> None:
4646
pass
4747

4848
@abstractmethod
49-
def log_task_assigned(self, event: TaskAssignedEvent) -> None:
49+
async def log_task_assigned(self, event: TaskAssignedEvent) -> None:
5050
pass
5151

5252
@abstractmethod
53-
def log_task_started(self, event: TaskStartedEvent) -> None:
53+
async def log_task_started(self, event: TaskStartedEvent) -> None:
5454
pass
5555

5656
@abstractmethod
57-
def log_task_completed(self, event: TaskCompletedEvent) -> None:
57+
async def log_task_completed(self, event: TaskCompletedEvent) -> None:
5858
pass
5959

6060
@abstractmethod
61-
def log_task_failed(self, event: TaskFailedEvent) -> None:
61+
async def log_task_failed(self, event: TaskFailedEvent) -> None:
6262
pass
6363

6464
@abstractmethod
65-
def log_worker_created(self, event: WorkerCreatedEvent) -> None:
65+
async def log_worker_created(self, event: WorkerCreatedEvent) -> None:
6666
pass
6767

6868
@abstractmethod
69-
def log_worker_deleted(self, event: WorkerDeletedEvent) -> None:
69+
async def log_worker_deleted(self, event: WorkerDeletedEvent) -> None:
7070
pass
7171

7272
@abstractmethod
73-
def log_all_tasks_completed(self, event: AllTasksCompletedEvent) -> None:
73+
async def log_all_tasks_completed(
74+
self, event: AllTasksCompletedEvent
75+
) -> None:
7476
pass

0 commit comments

Comments
 (0)