Commit f17215b

Merge remote-tracking branch 'origin/main' into ray_fix

2 parents d92e3b8 + 8bbda9c
29 files changed: +3894 -1087 lines

.github/workflows/testing.yml

Lines changed: 2 additions & 0 deletions
@@ -45,5 +45,7 @@ jobs:
           uv pip install --system .[testing]
           python -m nltk.downloader punkt
       - name: Test with pytest
+        env:
+          PYTHONFAULTHANDLER: 1
         run: |
           python -m pytest -sv ./tests/
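For context, `PYTHONFAULTHANDLER: 1` turns on Python's built-in `faulthandler` module for the pytest process, so hard crashes (segfaults, fatal interpreter errors) still print a Python traceback in the CI log. A minimal sketch of the equivalent in-process setup, shown only for illustration (the workflow change itself only needs the environment variable):

```python
import faulthandler

# Same effect as running the interpreter with PYTHONFAULTHANDLER=1:
# dump the traceback of all threads on segfaults and fatal errors.
faulthandler.enable()

# Optional extra (not implied by the env var): also dump tracebacks if the
# process looks hung, repeating every 5 minutes until it exits.
faulthandler.dump_traceback_later(timeout=300, repeat=True)
```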

.gitignore

Lines changed: 4 additions & 1 deletion
@@ -160,4 +160,7 @@ cython_debug/
 
 .vscode/
 
-playground/
+playground/
+
+# codex
+node_modules/

README.md

Lines changed: 48 additions & 4 deletions
@@ -321,13 +321,57 @@ Some options common to most readers:
 - `limit` read only a certain number of samples. Useful for testing/debugging
 
 ### Synthetic data generation
-We support [vLLM](https://github.com/vllm-project/vllm) and [SGLang](https://github.com/sgl-project/sglang) for inference using the [InferenceRunner block](src/datatrove/pipeline/inference/run_inference.py). Each datatrove task will spawn a replica of the target model and full asynchronous continuous batching will guarantee high model throughput.
+Install the inference extras with `pip install datatrove[inference]` to pull in the lightweight HTTP client, checkpointing dependencies and async sqlite cache.
 
-By setting `checkpoints_local_dir` and `records_per_chunk` generations will be written to a local folder until a chunk is complete, allowing for checkpointing in case tasks fail or are preempted.
+We support [vLLM](https://github.com/vllm-project/vllm), [SGLang](https://github.com/sgl-project/sglang), OpenAI-compatible HTTPS endpoints and a local `dummy` server through the [InferenceRunner block](src/datatrove/pipeline/inference/run_inference.py). Each datatrove task can spin up its own server replica (for `vllm`, `sglang` or `dummy`) or talk directly to an external endpoint while asynchronous batching keeps GPU utilization high.
 
-Tune `max_concurrent_requests` to tune batching behaviour. If you have slow pre-processing, you can also increase `max_concurrent_tasks` (to a value higher than `max_concurrent_requests`).
+Rollouts are plain async callables that receive a `Document`, a `generate(payload)` callback and any extra kwargs coming from `shared_context`. You can freely orchestrate multiple sequential or parallel `generate` calls inside the rollout. Set `rollouts_per_document` to automatically run the same rollout multiple times per sample; the runner collects successful outputs under `document.metadata["rollout_results"]`.
 
-Refer to the [example](examples/inference_example_chunked.py) for more info.
+`shared_context` lets you inject shared state into every rollout invocation. It accepts:
+- a dict (passed through as keyword arguments),
+- a callable returning a dict (handy for lazily creating resources),
+- a context manager or a callable returning one (great for pools, GPU allocators, temp dirs, etc.). Context managers are properly entered/exited once per task.
+
+Recoverable generation:
+- Setting `checkpoints_local_dir` together with `records_per_chunk` writes every `Document` to local chunk files (remember to include `${chunk_index}` in the output filename template), then uploads them via the configured writer. Failed tasks automatically resume from the last finished chunk.
+- When checkpointing is enabled a sqlite-backed `RequestCache` deduplicates individual rollouts via payload hashes (requires `xxhash` and `aiosqlite`) so completed generations are never re-sent during retries.
+
+Tune batching with `max_concurrent_generations` and, when pre/post-processing is heavy, raise `max_concurrent_documents` to allow more rollout coroutines to build payloads while requests are in flight.
+
+<details>
+<summary>Minimal end-to-end example</summary>
+
+```
+from datatrove.data import Document
+from datatrove.executor.local import LocalPipelineExecutor
+from datatrove.pipeline.inference.run_inference import InferenceConfig, InferenceRunner
+from datatrove.pipeline.writers import JsonlWriter
+
+async def simple_rollout(doc: Document, generate):
+    payload = {"messages": [{"role": "user", "content": [{"type": "text", "text": doc.text}]}], "max_tokens": 2048}
+    return await generate(payload)
+
+documents = [Document(text="What's the weather in Tokyo?", id=str(i)) for i in range(1005)]
+config = InferenceConfig(server_type="vllm", model_name_or_path="google/gemma-3-27b-it", rollouts_per_document=1, max_concurrent_generations=500)
+
+LocalPipelineExecutor(
+    pipeline=[
+        documents,
+        InferenceRunner(
+            rollout_fn=simple_rollout,
+            config=config,
+            records_per_chunk=500,
+            checkpoints_local_dir="/fsx/.../translate-checkpoints",
+            output_writer=JsonlWriter("s3://.../final_output_data", output_filename="${rank}_chunk_${chunk_index}.jsonl"),
+        ),
+    ],
+    logging_dir="/fsx/.../inference_logs",
+    tasks=1,
+).run()
+```
+</details>
+
+The extended [inference_example_chunked.py](examples/inference_example_chunked.py) script demonstrates single- and multi-rollout flows, resumable checkpoints and sharing a process pool across rollouts.
 
 ### Extracting text
 You can use [extractors](src/datatrove/pipeline/extractors) to extract text content from raw html. The most commonly used extractor in datatrove is [Trafilatura](src/datatrove/pipeline/extractors/trafilatura.py), which uses the [trafilatura](https://trafilatura.readthedocs.io/en/latest/) library.
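To make the three `shared_context` forms listed in the new README text concrete, here is a hypothetical sketch. It relies only on the `InferenceRunner`/`InferenceConfig` arguments that appear in this diff; the extra rollout keyword arguments (`system_prompt`, `scratch_dir`), the model name and the writer path are illustrative placeholders, not datatrove APIs:

```python
import tempfile
from contextlib import contextmanager
from functools import partial

from datatrove.data import Document
from datatrove.pipeline.inference.run_inference import InferenceConfig, InferenceRunner
from datatrove.pipeline.writers import JsonlWriter

# 1) plain dict: forwarded to every rollout call as keyword arguments
shared_as_dict = {"system_prompt": "Answer concisely."}


# 2) callable returning a dict: the resource is only built once the task starts
def shared_as_callable():
    return {"system_prompt": "Answer concisely."}


# 3) context manager (or a callable returning one): entered once per task, cleaned up at the end
@contextmanager
def shared_as_context_manager(tmp_root: str):
    with tempfile.TemporaryDirectory(dir=tmp_root) as scratch_dir:
        yield {"system_prompt": "Answer concisely.", "scratch_dir": scratch_dir}


async def rollout(doc: Document, generate, system_prompt: str = "", scratch_dir: str | None = None):
    # extra keyword arguments come from whichever shared_context form was configured
    payload = {
        "messages": [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": doc.text},
        ],
        "max_tokens": 512,
    }
    return await generate(payload)


runner = InferenceRunner(
    rollout_fn=rollout,
    config=InferenceConfig(server_type="dummy", model_name_or_path="my-org/my-model"),
    output_writer=JsonlWriter("my_output_path"),
    # any of the three forms works here; partial() keeps the context manager lazy
    shared_context=partial(shared_as_context_manager, "/tmp"),
)
```

Whichever form is used, the resulting dict's keys simply become keyword arguments of every rollout call within the task.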
examples/inference_example_chunked.py

Lines changed: 117 additions & 90 deletions

@@ -1,33 +1,35 @@
 """
-Chunked inference pipeline example with chunking.
+Chunked inference pipeline example with rollouts.
 
 This example shows how to run inference on documents using the InferenceRunner
-with chunking enabled. Documents are processed in chunks with checkpoint support
-for resuming from failures. Each chunk is saved to a separate output file.
+with checkpointing enabled. Documents are processed with a rollout function that
+can perform multiple generations per document before the results are written.
 """
 
-from typing import Any, AsyncGenerator
+import asyncio
+from concurrent.futures import ProcessPoolExecutor
+from contextlib import contextmanager
+from functools import partial
+from typing import Any, Awaitable, Callable
 
 from datatrove.data import Document
 from datatrove.executor.local import LocalPipelineExecutor
-from datatrove.pipeline.inference.run_inference import InferenceConfig, InferenceRunner
+from datatrove.executor.slurm import SlurmPipelineExecutor
+from datatrove.pipeline.inference.run_inference import InferenceConfig, InferenceResult, InferenceRunner
 from datatrove.pipeline.writers import JsonlWriter
 
 
-# For creating query payloads, you have 2 options:
-# 1. Create a simple query builder that returns a dict
-def simple_query_builder(runner: InferenceRunner, document: Document) -> dict[str, Any] | None:
+async def simple_rollout(
+    document: Document,
+    generate: Callable[[dict[str, Any]], Awaitable[InferenceResult]],
+) -> InferenceResult:
     """
-    Simple query builder that extracts text from document for OCR processing.
+    Basic rollout that sends a single request per document.
 
-    Args:
-        runner: Inference runner instance
-        document: Input document with text content
-
-    Returns:
-        Query payload for the inference server
+    Returns the InferenceResult directly, which will be stored under document.metadata["rollout_results"].
     """
-    return {
+
+    payload = {
         "messages": [
             {
                 "role": "user",

@@ -39,110 +41,98 @@ def simple_query_builder(runner: InferenceRunner, document: Document) -> dict[st
         "max_tokens": 2048,
     }
 
+    return await generate(payload)
 
-def large_sample_query_builder(runner: InferenceRunner, document: Document) -> dict[str, Any] | None:
-    """Query builder that chunks long samples and requests callbacks for continuation."""
 
-    MAX_CHARS_PER_PART = 4000
+async def chunked_rollout(
+    document: Document,
+    generate: Callable[[dict[str, Any]], Awaitable[InferenceResult]],
+) -> str:
+    """
+    Rollout that chunks long inputs and stitches the generations together.
+    """
+
     instruction = "Rewrite this in a more formal style:"
-    chunks = document.metadata.get("chunks")
-    if not chunks:
-        text = document.text
-        if len(text) > MAX_CHARS_PER_PART:
-            chunks = [text[i : i + MAX_CHARS_PER_PART] for i in range(0, len(text), MAX_CHARS_PER_PART)]
-            document.metadata["chunks"] = chunks
-        else:
-            chunks = [text]
-
-    inference_results = document.metadata.get("inference_results") or []
-    total_parts = len(chunks)
-    current_index = min(len(inference_results), total_parts - 1)
-    current_chunk = chunks[current_index]
-
-    if current_index == 0:
-        payload = {
-            "messages": [
-                {
-                    "role": "user",
-                    "content": f"{instruction}\n\n{current_chunk}",
-                }
-            ],
-        }
-    else:
-        previous_chunk = chunks[current_index - 1]
-        previous_result = inference_results[-1]
-        previous_generation = getattr(previous_result, "text", str(previous_result))
+    max_chars_per_part = 4000
+    text = document.text
+    chunks = [text[i : i + max_chars_per_part] for i in range(0, len(text), max_chars_per_part)] or [text]
+
+    generations: list[str] = []
+    prev_chunk = None
+
+    for chunk in chunks:
+        # here we just ask the model to continue the previous generation or an empty msg if there isn't anything
         payload = {
             "messages": [
                 {
                     "role": "user",
-                    "content": f"{instruction}\n\n{previous_chunk}{current_chunk}",
+                    "content": f"{instruction}\n\n{prev_chunk if prev_chunk else ''}{chunk}",
                 },
                 {
                     "role": "assistant",
-                    "content": previous_generation,
+                    "content": generations[-1] if generations else "",
                 },
             ],
-            # see these params here https://docs.vllm.ai/en/v0.7.2/api/offline_inference/llm.html#vllm.LLM.chat
+            # see https://docs.vllm.ai/en/v0.7.2/api/offline_inference/llm.html#vllm.LLM.chat
            "continue_final_message": True,
            "add_generation_prompt": False,
            "echo": False,
        }
 
-        # if we have a bunch of chunks for this sample, we want this function to be called again after the next generation is completed
-        payload["callback"] = len(inference_results) < total_parts - 1
-        return payload
-
+        # could potentially have some error handling here
+        result: InferenceResult = await generate(payload)
+        generations.append(result.text)
+        prev_chunk = chunk
+    return "\n".join(generations)
 
-# 2. Create an async query builder that returns an async generator of dicts. Use this option if you need
-# a) Create multiple requests per document
-# b) Your query function is IO/CPU heavy
 
-
-def heavy_cpu_task(document: Document, page: int):
-    # block sleep
+def cpu_heavy_build_payload(doc: Document, page: int) -> dict[str, Any]:
+    # simulate heavy work
     import time
 
+    # not async on purpose
    time.sleep(10)
    return {
        "messages": [
            {
                "role": "user",
-                "content": [{"type": "text", "text": document.text}],
+                "content": [{"type": "text", "text": f"[page {page}] {doc.text}"}],
            }
        ],
        "max_tokens": 4096,
    }
 
 
-async def async_query_builder(runner: InferenceRunner, document: Document) -> AsyncGenerator[dict[str, Any], None]:
-    """
-    Query builder for Language Model.
+@contextmanager
+def process_pool_context(max_workers: int = 100):
+    """Context manager for ProcessPoolExecutor that ensures proper cleanup."""
+    with ProcessPoolExecutor(max_workers=max_workers) as pool:
+        # This resource will be accessible in the rollout function as a keyword argument
+        # (and shared for all rollout invocations). try/finally syntax works too
+        yield {"process_pool": pool}
 
-    Args:
-        document: Input document with image URL or content
 
-    Returns:
-        Async generator of query payloads for the inference server
+async def heavy_cpu_rollout(
+    document: Document,
+    generate: Callable[[dict[str, Any]], Awaitable[InferenceResult]],
+    process_pool: ProcessPoolExecutor,
+) -> list[InferenceResult]:
+    """
+    Example rollout that offloads heavy preprocessing to a process pool.
+
+    The process_pool should be provided via shared_context when creating the InferenceRunner.
+    See example usage below.
     """
-    import asyncio
-    import atexit
-    from concurrent.futures import ProcessPoolExecutor
 
-    # Because it's async, you can run IO heavy tasks with little to no overhead (simply use await)
-    # If you need to run CPU heavy tasks, it's a bit more complicated
-    # 1. create a process pool executor and bind it to the runner
-    # 2. access the process pool, then using asyncio.run_in_executor
+    loop = asyncio.get_running_loop()
 
-    # If we didn't run with this the whole execution would take at least 1000*2*10 seconds
-    if not hasattr(runner, "process_pool"):
-        runner.process_pool = ProcessPoolExecutor(max_workers=100)
-        runner.process_pool.__enter__()
-        # Register cleanup
-        atexit.register(runner.process_pool.__exit__, None, None, None)
+    async def process_page(page: int) -> InferenceResult:
+        payload = await loop.run_in_executor(process_pool, cpu_heavy_build_payload, document, page)
+        return await generate(payload)
 
-    for page in [1, 2]:
-        yield await asyncio.get_running_loop().run_in_executor(runner.process_pool, heavy_cpu_task, document, page)
+    page_results = await asyncio.gather(*[process_page(page) for page in [1, 2]], return_exceptions=True)
+
+    return page_results
 
 
 # Configuration

@@ -158,32 +148,69 @@ async def async_query_builder(runner: InferenceRunner, document: Document) -> As
 config: InferenceConfig = InferenceConfig(
     server_type="vllm",  # Options: "sglang", "vllm", "dummy"
     model_name_or_path="reducto/RolmOCR",
-    temperature=0.0,
     model_max_context=8192,
-    max_concurrent_requests=500,
-    max_concurrent_tasks=500,
     metric_interval=120,
+    default_generation_params={"temperature": 0.0},
+    rollouts_per_document=1,
+    max_concurrent_generations=500,
 )
 
 # Create the pipeline with chunking
+# Example 1: Simple rollout without shared context
 pipeline_executor: LocalPipelineExecutor = LocalPipelineExecutor(
     pipeline=[
-        # Read input documents
         documents,
         InferenceRunner(
-            query_builder=large_sample_query_builder,
+            rollout_fn=chunked_rollout,
            config=config,
            records_per_chunk=500,  # Enable chunking with 500 documents per chunk
-            checkpoints_local_dir=CHECKPOINTS_PATH,  # leave unset to disable checkpointing behaviour
+            checkpoints_local_dir=CHECKPOINTS_PATH,  # Leave unset to disable checkpointing
            output_writer=JsonlWriter(OUTPUT_PATH, output_filename="${rank}_chunk_${chunk_index}.jsonl"),
-            # you can also pass a postprocess_fn(document) -> document|None to modify/filter the document after inference. Return None to remove the document
-            postprocess_fn=None,
        ),
    ],
    logging_dir=LOGS_PATH,
    tasks=1,  # Number of parallel tasks
 )
 
+# Example 2: Rollout with shared context (process pool)
+pipeline_executor_with_pool = LocalPipelineExecutor(
+    pipeline=[
+        documents,
+        InferenceRunner(
+            rollout_fn=heavy_cpu_rollout,
+            config=config,
+            records_per_chunk=500,
+            checkpoints_local_dir=CHECKPOINTS_PATH,
+            output_writer=JsonlWriter(OUTPUT_PATH, output_filename="${rank}_chunk_${chunk_index}.jsonl"),
+            # we could call it without partial, but this way the pool is initialized lazily and not before the job starts
+            shared_context=partial(process_pool_context, max_workers=100),
+        ),
+    ],
+    logging_dir=LOGS_PATH,
+    tasks=1,
+)
+
+# Example 3: Distributed inference
+pipeline_executor_distributed = SlurmPipelineExecutor(
+    tasks=100,
+    time="10:00:00",
+    partition="hopper-prod",
+    gpus_per_task=8,
+    nodes_per_task=2,
+    logging_dir=LOGS_PATH,
+    pipeline=[
+        documents,
+        InferenceRunner(
+            rollout_fn=chunked_rollout,
+            config=InferenceConfig(
+                server_type="vllm",
+                model_name_or_path="deepseek-ai/DeepSeek-R1",
+                tp=16,
+            ),
+            output_writer=JsonlWriter(OUTPUT_PATH),
+        ),
+    ],
+)
 if __name__ == "__main__":
     # Run the pipeline
     pipeline_executor.run()
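As a small companion to the examples above, a hypothetical post-run snippet for checking what the writer produced. It assumes uncompressed `.jsonl` output at the configured output location and the default serialization with `id` and `metadata` fields per record, where (per the README change above) successful rollout outputs are collected under `metadata["rollout_results"]`:

```python
import glob
import json

OUTPUT_PATH = "/path/to/output"  # placeholder: same location passed to JsonlWriter above

# Hypothetical inspection of the chunked output files, assuming uncompressed
# JSONL records that carry "id" and "metadata" fields.
for path in sorted(glob.glob(f"{OUTPUT_PATH}/*_chunk_*.jsonl")):
    with open(path) as f:
        for line in f:
            record = json.loads(line)
            rollouts = record.get("metadata", {}).get("rollout_results", [])
            print(record.get("id"), f"{len(rollouts)} rollout result(s)")
```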
