Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 43 additions & 0 deletions camel/societies/workforce/task_channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,11 @@
from enum import Enum
from typing import Dict, List, Optional, Set

from camel.logger import get_logger
from camel.tasks import Task

logger = get_logger(__name__)


class PacketStatus(Enum):
r"""The status of a packet. The packet can be in one of the following
Expand Down Expand Up @@ -269,6 +272,46 @@ async def get_dependency_ids(self) -> List[str]:
async with self._condition:
return list(self._task_by_status[PacketStatus.ARCHIVED])

async def get_in_flight_tasks(self, publisher_id: str) -> List[Task]:
r"""Get all tasks that are currently in-flight (SENT, RETURNED
or PROCESSING) published by the given publisher.

Args:
publisher_id (str): The ID of the publisher whose
in-flight tasks to retrieve.

Returns:
List[Task]: List of tasks that are currently in-flight.
"""
async with self._condition:
in_flight_tasks = []
seen_task_ids = set() # Track seen IDs for duplicate detection

# Get tasks with SENT, RETURNED or PROCESSING
# status published by this publisher
for status in [
PacketStatus.SENT,
PacketStatus.PROCESSING,
PacketStatus.RETURNED,
]:
for task_id in self._task_by_status[status]:
if task_id in self._task_dict:
packet = self._task_dict[task_id]
if packet.publisher_id == publisher_id:
# Defensive check: detect if task appears in
# multiple status sets (should never happen)
if task_id in seen_task_ids:
logger.warning(
f"Task {task_id} found in multiple "
f"status sets. This indicates a bug in "
f"status management."
)
continue
in_flight_tasks.append(packet.task)
seen_task_ids.add(task_id)

return in_flight_tasks

async def get_task_by_id(self, task_id: str) -> Task:
r"""Get a task from the channel by its ID."""
async with self._condition:
Expand Down
Loading