Skip to content

Commit 8bbda9c

Browse files
authored
Merge pull request #406 from huggingface/multi-node-inference
Multi-Node Distributed Inference Support
2 parents c337fce + 0715bff commit 8bbda9c

File tree

22 files changed

+2338
-657
lines changed

22 files changed

+2338
-657
lines changed

examples/inference_example_chunked.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
from datatrove.data import Document
1616
from datatrove.executor.local import LocalPipelineExecutor
17+
from datatrove.executor.slurm import SlurmPipelineExecutor
1718
from datatrove.pipeline.inference.run_inference import InferenceConfig, InferenceResult, InferenceRunner
1819
from datatrove.pipeline.writers import JsonlWriter
1920

@@ -189,6 +190,27 @@ async def process_page(page: int) -> InferenceResult:
189190
tasks=1,
190191
)
191192

193+
# Example 3: Distributed inference
# Runs the same pipeline on Slurm across multiple nodes: each of the 100 tasks
# gets 2 nodes with 8 GPUs each (16 GPUs total per task).
pipeline_executor_distributed = SlurmPipelineExecutor(
    tasks=100,
    time="10:00:00",
    partition="hopper-prod",
    gpus_per_task=8,
    nodes_per_task=2,  # multi-node inference: one task spans two nodes
    logging_dir=LOGS_PATH,
    pipeline=[
        documents,
        InferenceRunner(
            rollout_fn=chunked_rollout,
            config=InferenceConfig(
                server_type="vllm",
                model_name_or_path="deepseek-ai/DeepSeek-R1",
                # presumably tensor parallelism across all 16 GPUs of the
                # 2-node allocation — TODO confirm against InferenceConfig docs
                tp=16,
            ),
            output_writer=JsonlWriter(OUTPUT_PATH),
        ),
    ],
)
192214
if __name__ == "__main__":
193215
# Run the pipeline
194216
pipeline_executor.run()

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ inference = [
8989
"aiosqlite",
9090
]
9191
ray = [
92-
"ray"
92+
"ray[default]"
9393
]
9494
quality = [
9595
"ruff>=0.1.5"

src/datatrove/executor/base.py

Lines changed: 53 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
import dataclasses
22
import json
3+
import os
34
import random
45
import time
56
from abc import ABC, abstractmethod
67
from collections import deque
78
from collections.abc import Sequence
8-
from typing import Callable
9+
from typing import Callable, TypedDict
910

1011
from datatrove.io import DataFolderLike, get_datafolder
1112
from datatrove.pipeline.base import PipelineStep
@@ -20,6 +21,19 @@
2021
from datatrove.utils.stats import PipelineStats
2122

2223

24+
class DistributedEnvVars(TypedDict):
    """Distributed-execution settings an executor must return from `get_distributed_env`.

    All values must be strings: they are copied verbatim into `os.environ`
    (as the DATATROVE_* variables) by `_set_distributed_environment`.
    """

    datatrove_node_ips: str  # comma-separated list of node IPs/hostnames
    datatrove_cpus_per_task: str  # number of CPUs per task
    datatrove_mem_per_cpu: str  # memory per CPU in GB
    datatrove_gpus_on_node: str  # number of GPUs on the node
    datatrove_executor: str  # executor type identifier (e.g. "LOCAL")
35+
36+
2337
class PipelineExecutor(ABC):
2438
"""Base class for pipeline executors (local, slurm, etc.)
2539
@@ -62,22 +76,50 @@ def world_size(self) -> int:
6276
"""
6377
return 0
6478

65-
def _run_for_rank(self, rank: int, local_rank: int = 0) -> PipelineStats:
79+
@abstractmethod
def get_distributed_env(self, node_rank: int = -1) -> DistributedEnvVars:
    """Build the distributed-execution settings for this executor.

    Invoked from `_run_for_rank` (through `_set_distributed_environment`)
    so every pipeline step can read the cluster layout from the environment.

    Args:
        node_rank: node rank/ID; -1 (default) means single-node mode.

    Returns:
        A DistributedEnvVars mapping with all required keys.
        Every value must be a string.
    """
    ...
92+
93+
def _set_distributed_environment(self, node_rank: int):
94+
env_vars = self.get_distributed_env(node_rank)
95+
os.environ["DATATROVE_NODE_RANK"] = str(node_rank)
96+
os.environ["DATATROVE_EXECUTOR"] = env_vars["datatrove_executor"]
97+
os.environ["DATATROVE_NODE_IPS"] = env_vars["datatrove_node_ips"]
98+
os.environ["DATATROVE_CPUS_PER_TASK"] = env_vars["datatrove_cpus_per_task"]
99+
os.environ["DATATROVE_MEM_PER_CPU"] = env_vars["datatrove_mem_per_cpu"]
100+
os.environ["DATATROVE_GPUS_ON_NODE"] = env_vars["datatrove_gpus_on_node"]
101+
102+
def _run_for_rank(self, rank: int, local_rank: int = 0, node_rank: int = -1) -> PipelineStats:
66103
"""
67104
Main executor's method. Sets up logging, pipes data from each pipeline step to the next, saves statistics
68-
and marks tasks as completed.
105+
and marks tasks as completed. We assume node_rank == 0 is the master node. node_rank == -1 means single node mode.
106+
Completion is only marked on the master node, all other nodes are ignored in terms of job completion as we use 1-master, many-workers mode.
107+
In this case it's master responsibility to check for workers completion and mark the job as complete.
69108
Args:
70109
rank: the rank that we want to run the pipeline for
71110
local_rank: at the moment this is only used for logging.
72111
Any task with local_rank != 0 will not print logs to console.
73-
112+
node_rank: node rank/ID for logging prefix. Logs will be prefixed with [NODE X] if node_rank != -1. We assume node_rank == 0 is the master node. -1 means single node mode (default).
74113
Returns: the stats for this task
75114
76115
"""
77116
if self.is_rank_completed(rank):
78117
logger.info(f"Skipping {rank=} as it has already been completed.")
79118
return PipelineStats()
80-
logfile = add_task_logger(self.logging_dir, rank, local_rank)
119+
120+
self._set_distributed_environment(node_rank)
121+
122+
logfile = add_task_logger(self.logging_dir, rank, local_rank, node_rank=node_rank)
81123
log_pipeline(self.pipeline)
82124

83125
if self.randomize_start_duration > 0:
@@ -97,13 +139,13 @@ def _run_for_rank(self, rank: int, local_rank: int = 0) -> PipelineStats:
97139

98140
logger.success(f"Processing done for {rank=}")
99141

100-
# stats
142+
# stats - only save on master node in distributed setting (or when node_rank <= 0 for single node)
101143
stats = PipelineStats(self.pipeline)
102-
with self.logging_dir.open(f"stats/{rank:05d}.json", "w") as f:
103-
stats.save_to_disk(f)
104-
logger.info(stats.get_repr(f"Task {rank}"))
105-
# completed
106-
self.mark_rank_as_completed(rank)
144+
if node_rank <= 0:
145+
with self.logging_dir.open(f"stats/{rank:05d}.json", "w") as f:
146+
stats.save_to_disk(f)
147+
logger.info(stats.get_repr(f"Task {rank}"))
148+
self.mark_rank_as_completed(rank)
107149
except Exception as e:
108150
logger.exception(e)
109151
raise e

src/datatrove/executor/local.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
import multiprocess
77

8-
from datatrove.executor.base import PipelineExecutor
8+
from datatrove.executor.base import DistributedEnvVars, PipelineExecutor
99
from datatrove.io import DataFolderLike
1010
from datatrove.pipeline.base import PipelineStep
1111
from datatrove.utils.logging import logger
@@ -150,6 +150,19 @@ def run(self):
150150
logger.success(stats.get_repr(f"All {self.local_tasks} tasks"))
151151
return stats
152152

153+
def get_distributed_env(self, node_rank: int = -1) -> DistributedEnvVars:
    """Build the distributed-execution settings for the LOCAL executor."""
    # Everything runs on one machine, hence "localhost"; "-1" is the sentinel
    # for resource counts the local executor does not track.
    untracked = "-1"
    return DistributedEnvVars(
        datatrove_node_ips="localhost",
        datatrove_cpus_per_task=untracked,
        datatrove_mem_per_cpu=untracked,
        datatrove_gpus_on_node=untracked,
        datatrove_executor="LOCAL",
    )
165+
153166
@property
154167
def world_size(self) -> int:
155168
"""

0 commit comments

Comments (0)