Skip to content

Commit b1be8a7

Browse files
authored
Merge pull request #403 from huggingface/ray_fix
ray nits
2 parents 8bbda9c + f17215b commit b1be8a7

File tree

3 files changed

+77
-7
lines changed

3 files changed

+77
-7
lines changed

src/datatrove/executor/ray.py

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from typing import TYPE_CHECKING, Callable, Optional, Sequence
1111

1212
from datatrove.executor.base import DistributedEnvVars, PipelineExecutor
13-
from datatrove.io import DataFolderLike, get_datafolder
13+
from datatrove.io import DataFolderLike, file_is_local, get_datafolder
1414
from datatrove.pipeline.base import PipelineStep
1515
from datatrove.utils._import_utils import check_required_dependencies
1616
from datatrove.utils.logging import add_task_logger, close_task_logger, log_pipeline, logger
@@ -431,18 +431,26 @@ def __init__(
431431
tasks_per_job: int = 1,
432432
time: Optional[int] = None,
433433
):
434+
# Check if the logging_dir is local fs and if so issue a warning that for synchronization it has to be a shared filesystem
435+
if logging_dir and file_is_local(logging_dir):
436+
logger.warning(
437+
"Logging directory points to a local filesystem. For correct synchronization to work this "
438+
"filesystem needs to be shared across the submitting node as well as the workers and needs "
439+
"to be persistent across node restarts."
440+
)
441+
434442
super().__init__(pipeline, logging_dir, skip_completed, randomize_start_duration)
435443
self.tasks = tasks
436444
self.workers = workers if workers != -1 else tasks
437445
self.depends = depends
438-
# track whether run() has been called
439446
self.cpus_per_task = cpus_per_task
440447
self.gpus_per_task = gpus_per_task
441448
self.mem_per_cpu_gb = mem_per_cpu_gb
442449
self.ray_remote_kwargs = ray_remote_kwargs
443450
self.tasks_per_job = tasks_per_job
444451
self.log_first = log_first
445452
self.time = time
453+
self._launched = False
446454
self.nodes_per_task = nodes_per_task
447455

448456
def get_distributed_env(self, node_rank: int = -1) -> DistributedEnvVars:
@@ -472,12 +480,22 @@ def run(self):
472480
check_required_dependencies("ray", ["ray"])
473481
import ray
474482

475-
# 1) If there is a depends=, ensure it has run and is finished
483+
assert not self.depends or (isinstance(self.depends, RayPipelineExecutor)), (
484+
"depends= must be a RayPipelineExecutor"
485+
)
476486
if self.depends:
477-
logger.info(f'Launching dependency job "{self.depends}"')
478-
self.depends.run()
479-
480-
# 2) Check if all tasks are already completed
487+
# take care of launching any unlaunched dependencies
488+
if not self.depends._launched:
489+
logger.info(f'Launching dependency job "{self.depends}"')
490+
self.depends.run()
491+
while (
492+
incomplete := len(self.depends.get_incomplete_ranks(skip_completed=True))
493+
) > 0: # set skip_completed=True to get *real* incomplete task count
494+
logger.info(f"Dependency job still has {incomplete}/{self.depends.world_size} tasks. Waiting...")
495+
time.sleep(2 * 60)
496+
497+
self._launched = True
498+
# 3) Check if all tasks are already completed
481499
incomplete_ranks = self.get_incomplete_ranks(range(self.world_size))
482500
if not incomplete_ranks:
483501
logger.info(f"All {self.world_size} tasks appear to be completed already. Nothing to run.")

src/datatrove/io.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -293,6 +293,10 @@ def get_fs_with_filepath(data: DataFileLike) -> tuple[AbstractFileSystem, str]:
293293
# (str path, initialized fs object)
294294
if isinstance(data, tuple) and isinstance(data[0], str) and isinstance(data[1], AbstractFileSystem):
295295
return (data[1], data[0]) # yeah yeah this is a bit weird I agree
296+
297+
if isinstance(data, DataFolder):
298+
return (data.fs, data.path)
299+
296300
raise ValueError(
297301
"You must pass a DataFileLike instance, a str path, a (str path, fs_init_kwargs) or (str path, fs object)"
298302
)

tests/executor/test_ray.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,54 @@ def run(self, data, rank=None, world_size=None):
6565
f"Expected file {file} was not found in {log_dir}",
6666
)
6767

68+
def test_dependencies(self):
69+
"""Test that multiple executors can depend on the same parent executor and the parent only runs once."""
70+
71+
parent_log_dir = get_datafolder(f"{self.tmp_dir}/parent")
72+
73+
class ParentSimpleStep(PipelineStep):
74+
def run(self, data, rank=None, world_size=None):
75+
with open(parent_log_dir.resolve_paths("parent.txt"), "a") as f:
76+
f.write(f"called {rank}\n")
77+
78+
class ChildSimpleStep(PipelineStep):
79+
def run(self, data, rank=None, world_size=None):
80+
pass
81+
82+
# Create parent executor
83+
parent_executor = RayPipelineExecutor(
84+
pipeline=[ParentSimpleStep()],
85+
tasks=2,
86+
workers=2,
87+
logging_dir=parent_log_dir,
88+
)
89+
90+
# Create two child executors that depend on the same parent
91+
child1_log_dir = get_datafolder(f"{self.tmp_dir}/child1")
92+
child1_executor = RayPipelineExecutor(
93+
pipeline=[ChildSimpleStep()],
94+
tasks=2,
95+
workers=2,
96+
logging_dir=child1_log_dir,
97+
depends=parent_executor,
98+
)
99+
100+
child2_log_dir = get_datafolder(f"{self.tmp_dir}/child2")
101+
child2_executor = RayPipelineExecutor(
102+
pipeline=[ChildSimpleStep()],
103+
tasks=2,
104+
workers=2,
105+
logging_dir=child2_log_dir,
106+
depends=parent_executor,
107+
)
108+
109+
# Run child1 - this should launch the parent first
110+
child1_executor.run()
111+
child2_executor.run()
112+
with open(parent_log_dir.resolve_paths("parent.txt"), "r") as f:
113+
# Two calls because of two tasks
114+
self.assertEqual(sorted(f.read().strip().splitlines()), ["called 0", "called 1"])
115+
68116
def test_placement_group_creation(self):
69117
"""Test that placement groups are created when nodes_per_task > 1"""
70118
from datatrove.executor.ray import RayTaskManager

0 commit comments

Comments
 (0)