|
10 | 10 | from typing import TYPE_CHECKING, Callable, Optional, Sequence |
11 | 11 |
|
12 | 12 | from datatrove.executor.base import DistributedEnvVars, PipelineExecutor |
13 | | -from datatrove.io import DataFolderLike, get_datafolder |
| 13 | +from datatrove.io import DataFolderLike, file_is_local, get_datafolder |
14 | 14 | from datatrove.pipeline.base import PipelineStep |
15 | 15 | from datatrove.utils._import_utils import check_required_dependencies |
16 | 16 | from datatrove.utils.logging import add_task_logger, close_task_logger, log_pipeline, logger |
@@ -431,18 +431,26 @@ def __init__( |
431 | 431 | tasks_per_job: int = 1, |
432 | 432 | time: Optional[int] = None, |
433 | 433 | ): |
| 434 | + # Check if the logging_dir is local fs and if so issue a warning that for synchronization it has to be a shared filesystem |
| 435 | + if logging_dir and file_is_local(logging_dir): |
| 436 | + logger.warning( |
| 437 | + "Logging directory points to a local filesystem. For correct synchronization to work this " |
|  438 | +                "filesystem needs to be shared across the submitting node as well as the workers and needs "
| 439 | + "to be persistent across node restarts." |
| 440 | + ) |
| 441 | + |
434 | 442 | super().__init__(pipeline, logging_dir, skip_completed, randomize_start_duration) |
435 | 443 | self.tasks = tasks |
436 | 444 | self.workers = workers if workers != -1 else tasks |
437 | 445 | self.depends = depends |
438 | | - # track whether run() has been called |
439 | 446 | self.cpus_per_task = cpus_per_task |
440 | 447 | self.gpus_per_task = gpus_per_task |
441 | 448 | self.mem_per_cpu_gb = mem_per_cpu_gb |
442 | 449 | self.ray_remote_kwargs = ray_remote_kwargs |
443 | 450 | self.tasks_per_job = tasks_per_job |
444 | 451 | self.log_first = log_first |
445 | 452 | self.time = time |
| 453 | + self._launched = False |
446 | 454 | self.nodes_per_task = nodes_per_task |
447 | 455 |
|
448 | 456 | def get_distributed_env(self, node_rank: int = -1) -> DistributedEnvVars: |
@@ -472,12 +480,22 @@ def run(self): |
472 | 480 | check_required_dependencies("ray", ["ray"]) |
473 | 481 | import ray |
474 | 482 |
|
475 | | - # 1) If there is a depends=, ensure it has run and is finished |
| 483 | + assert not self.depends or (isinstance(self.depends, RayPipelineExecutor)), ( |
| 484 | + "depends= must be a RayPipelineExecutor" |
| 485 | + ) |
476 | 486 | if self.depends: |
477 | | - logger.info(f'Launching dependency job "{self.depends}"') |
478 | | - self.depends.run() |
479 | | - |
480 | | - # 2) Check if all tasks are already completed |
| 487 | + # take care of launching any unlaunched dependencies |
| 488 | + if not self.depends._launched: |
| 489 | + logger.info(f'Launching dependency job "{self.depends}"') |
| 490 | + self.depends.run() |
| 491 | + while ( |
| 492 | + incomplete := len(self.depends.get_incomplete_ranks(skip_completed=True)) |
| 493 | + ) > 0: # set skip_completed=True to get *real* incomplete task count |
| 494 | + logger.info(f"Dependency job still has {incomplete}/{self.depends.world_size} tasks. Waiting...") |
| 495 | + time.sleep(2 * 60) |
| 496 | + |
| 497 | + self._launched = True |
|  498 | +        # 2) Check if all tasks are already completed
481 | 499 | incomplete_ranks = self.get_incomplete_ranks(range(self.world_size)) |
482 | 500 | if not incomplete_ranks: |
483 | 501 | logger.info(f"All {self.world_size} tasks appear to be completed already. Nothing to run.") |
|
0 commit comments