|
16 | 16 | from enum import Enum |
17 | 17 | from typing import Dict, List, Optional, Set |
18 | 18 |
|
| 19 | +from camel.logger import get_logger |
19 | 20 | from camel.tasks import Task |
20 | 21 |
|
| 22 | +logger = get_logger(__name__) |
| 23 | + |
21 | 24 |
|
22 | 25 | class PacketStatus(Enum): |
23 | 26 | r"""The status of a packet. The packet can be in one of the following |
@@ -269,6 +272,46 @@ async def get_dependency_ids(self) -> List[str]: |
269 | 272 | async with self._condition: |
270 | 273 | return list(self._task_by_status[PacketStatus.ARCHIVED]) |
271 | 274 |
|
| 275 | + async def get_in_flight_tasks(self, publisher_id: str) -> List[Task]: |
| 276 | + r"""Get all tasks that are currently in-flight (SENT, RETURNED |
| 277 | + or PROCESSING) published by the given publisher. |
| 278 | +
|
| 279 | + Args: |
| 280 | + publisher_id (str): The ID of the publisher whose |
| 281 | + in-flight tasks to retrieve. |
| 282 | +
|
| 283 | + Returns: |
| 284 | + List[Task]: List of tasks that are currently in-flight. |
| 285 | + """ |
| 286 | + async with self._condition: |
| 287 | + in_flight_tasks = [] |
| 288 | + seen_task_ids = set() # Track seen IDs for duplicate detection |
| 289 | + |
| 290 | + # Get tasks with SENT, RETURNED or PROCESSING |
| 291 | + # status published by this publisher |
| 292 | + for status in [ |
| 293 | + PacketStatus.SENT, |
| 294 | + PacketStatus.PROCESSING, |
| 295 | + PacketStatus.RETURNED, |
| 296 | + ]: |
| 297 | + for task_id in self._task_by_status[status]: |
| 298 | + if task_id in self._task_dict: |
| 299 | + packet = self._task_dict[task_id] |
| 300 | + if packet.publisher_id == publisher_id: |
| 301 | + # Defensive check: detect if task appears in |
| 302 | + # multiple status sets (should never happen) |
| 303 | + if task_id in seen_task_ids: |
| 304 | + logger.warning( |
| 305 | + f"Task {task_id} found in multiple " |
| 306 | + f"status sets. This indicates a bug in " |
| 307 | + f"status management." |
| 308 | + ) |
| 309 | + continue |
| 310 | + in_flight_tasks.append(packet.task) |
| 311 | + seen_task_ids.add(task_id) |
| 312 | + |
| 313 | + return in_flight_tasks |
| 314 | + |
272 | 315 | async def get_task_by_id(self, task_id: str) -> Task: |
273 | 316 | r"""Get a task from the channel by its ID.""" |
274 | 317 | async with self._condition: |
|
0 commit comments