diff --git a/examples/specdec_bench/specdec_bench/utils.py b/examples/specdec_bench/specdec_bench/utils.py index 9a52d0ceac2..73d1e048c80 100644 --- a/examples/specdec_bench/specdec_bench/utils.py +++ b/examples/specdec_bench/specdec_bench/utils.py @@ -196,6 +196,10 @@ def _checkpoint_provenance(model_dir): def _is_sensitive_key(key): + # Engine configs can carry non-string dict keys (e.g. int layer ids in a + # serving_config); those are never sensitive field *names*, so skip them. + if not isinstance(key, str): + return False klow = key.lower() if klow in _SENSITIVE_KEY_ALLOWLIST: return False diff --git a/examples/speculative_decoding/eagle_utils.py b/examples/speculative_decoding/eagle_utils.py index f9675e54161..626ea786237 100644 --- a/examples/speculative_decoding/eagle_utils.py +++ b/examples/speculative_decoding/eagle_utils.py @@ -59,7 +59,6 @@ def make_speculative_data_module( train_len=None, answer_only_loss=False, shift_labels=True, - seed: int = 0, ) -> dict: """Create data module for speculative decoding training. @@ -88,14 +87,16 @@ def make_speculative_data_module( ds = load_dataset("json", data_files=data_args.data_path, split="train") if data_args.sample_size > 0: ds = ds.select(range(data_args.sample_size)) + # Map-style dataset: each rank fetches its own DistributedSampler shard. + # Fetch concurrency comes from the DataLoader's num_workers, not a config knob; + # shuffling/order is the sampler's job (seeded by training_args.seed). + # ``server_urls`` accepts a comma-separated string for multi-server fan-out. 
streaming_cfg = EagleVllmStreamingConfig( - server_url=data_args.streaming_server_url, + server_urls=data_args.streaming_server_url, model=data_args.streaming_model_name, shared_storage_root=data_args.streaming_shared_storage_path, max_seq_len=train_len, answer_only_loss=answer_only_loss, - prefetch=data_args.streaming_prefetch, - seed=seed, ) train_dataset = EagleVllmStreamingDataset( entries=ds, @@ -138,7 +139,9 @@ def make_speculative_data_module( raise ValueError("sample_size must be -1 (use all samples) or a positive integer") if data_args.sample_size > 0: dumped_files = dumped_files[: data_args.sample_size] - train_dataset = OfflineSupervisedDataset(dumped_files, answer_only_loss=answer_only_loss) + train_dataset = OfflineSupervisedDataset( + dumped_files, answer_only_loss=answer_only_loss, tokenizer=tokenizer + ) data_collator = EagleOfflineDataCollator(train_len=train_len) return { diff --git a/examples/speculative_decoding/launch_train.sh b/examples/speculative_decoding/launch_train.sh index 41d71d14173..a6104f35fe6 100755 --- a/examples/speculative_decoding/launch_train.sh +++ b/examples/speculative_decoding/launch_train.sh @@ -19,9 +19,8 @@ # Multi-node: ./launch_train.sh --config ../../modelopt_recipes/general/speculative_decoding/eagle3.yaml --num_nodes 2 --head_node_ip # With overrides: ./launch_train.sh --config my.yaml model.model_name_or_path=xxx training.output_dir=yyy # -# Extra key=value args are forwarded as OmegaConf dotlist overrides to main.py. -# All training config (model, data, hyperparams, eagle, fsdp) lives in the YAML file. -# Only multi-node routing args are passed here; mixed_precision is fixed to bf16. +# Extra key=value args are forwarded as OmegaConf dotlist overrides to main.py; all +# training config lives in the YAML. mixed_precision is fixed to bf16. 
set -eo pipefail @@ -30,12 +29,14 @@ SCRIPT_DIR="$(dirname "$(readlink -f "$0")")" CONFIG_FILE="" NUM_NODES=1 HEAD_NODE_IP="" +MACHINE_RANK="" EXTRA_ARGS=() while [ $# -gt 0 ]; do case "$1" in --config*) if [[ "$1" != *=* ]]; then shift; fi; CONFIG_FILE="${1#*=}" ;; --num_nodes*) if [[ "$1" != *=* ]]; then shift; fi; NUM_NODES="${1#*=}" ;; --head_node_ip*) if [[ "$1" != *=* ]]; then shift; fi; HEAD_NODE_IP="${1#*=}" ;; + --machine_rank*) if [[ "$1" != *=* ]]; then shift; fi; MACHINE_RANK="${1#*=}" ;; *) EXTRA_ARGS+=("$1") ;; esac shift @@ -46,7 +47,6 @@ if [ -z "$CONFIG_FILE" ]; then exit 1 fi -# GPU count detection if [[ "$NUM_NODES" != "1" ]]; then GPU_PER_NODE=${GPU_PER_NODE:-$(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l)} TOTAL_GPU=$((NUM_NODES * GPU_PER_NODE)) @@ -56,20 +56,28 @@ else echo "Total GPUs: $TOTAL_GPU (single node)" fi -# Multi-node routing args (accelerate only; training config comes from the YAML) -MULTI_NODE_ARGS="" +MULTI_NODE_ARGS=() if [[ "$NUM_NODES" != "1" ]]; then - MULTI_NODE_ARGS="--num_processes $TOTAL_GPU \ - --num_machines $NUM_NODES \ - --machine_rank $SLURM_PROCID \ - --rdzv_backend c10d \ - --main_process_ip $HEAD_NODE_IP \ - --main_process_port 29500" + # --multi_gpu is required even at 1 GPU/node, else accelerate won't form the DDP group. + # machine_rank defaults to $SLURM_PROCID; override --machine_rank if node 0 isn't a trainer. + MULTI_NODE_ARGS=( + --multi_gpu + --num_processes "$TOTAL_GPU" + --num_machines "$NUM_NODES" + --machine_rank "${MACHINE_RANK:-$SLURM_PROCID}" + --main_process_ip "$HEAD_NODE_IP" + --main_process_port 29500 + ) fi export TOKENIZERS_PARALLELISM=False +# argv array, not `sh -c` (which would word-split overrides and run embedded substitutions). 
+CMD=(accelerate launch --mixed_precision bf16 + "${MULTI_NODE_ARGS[@]}" + "${SCRIPT_DIR}/main.py" --config "$CONFIG_FILE" "${EXTRA_ARGS[@]}") + set -x start_time=$(date +%s) -sh -c "accelerate launch --mixed_precision bf16 $MULTI_NODE_ARGS ${SCRIPT_DIR}/main.py --config $CONFIG_FILE ${EXTRA_ARGS[*]}" +"${CMD[@]}" echo "Total time: $(( $(date +%s) - $start_time )) seconds" diff --git a/examples/speculative_decoding/main.py b/examples/speculative_decoding/main.py index 9b7a9f44d2e..f62b099121d 100644 --- a/examples/speculative_decoding/main.py +++ b/examples/speculative_decoding/main.py @@ -267,7 +267,6 @@ def train(): train_len=training_args.training_seq_len, answer_only_loss=training_args.answer_only_loss, shift_labels=not is_dflash, - seed=training_args.seed, ) callbacks = [EagleTrainingPlot(training_args.ar_validate_steps, training_args.estimate_ar)] @@ -277,13 +276,10 @@ def train(): and recipe.eagle.eagle_base_lora_warmup_steps > 0 ): callbacks.append(LoRAWarmupCallback(recipe.eagle.eagle_base_lora_warmup_steps)) - if recipe.data.mode == "streaming": - # Skip-on-resume happens inside the dataset (no re-fetch from server); - # disable HF Trainer's own data skip so the offset isn't applied twice. - from modelopt.torch.speculative.plugins.hf_streaming_dataset import StreamingResumeCallback - - training_args.ignore_data_skip = True - callbacks.append(StreamingResumeCallback()) + # Leave training_args.ignore_data_skip at its default (False). The dataset is + # map-style, so HF Trainer's resume skips consumed indices at the batch-sampler + # level (accelerate.skip_first_batches) without re-fetching them, landing at the + # exact data position. Setting it True would restart the data order from the top. 
trainer = EagleTrainerWithAccLog( model=model, diff --git a/modelopt/recipe/config.py b/modelopt/recipe/config.py index 4bf91b52d6f..0932095f6d4 100644 --- a/modelopt/recipe/config.py +++ b/modelopt/recipe/config.py @@ -31,6 +31,18 @@ TrainingArguments as SpecTrainingArgs, ) +__all__ = [ + "RECIPE_TYPE_TO_CLASS", + "ModelOptDFlashRecipe", + "ModelOptEagleRecipe", + "ModelOptMedusaRecipe", + "ModelOptPTQRecipe", + "ModelOptRecipeBase", + "ModelOptSpeculativeRecipeBase", + "RecipeMetadataConfig", + "RecipeType", +] + class RecipeType(str, Enum): """List of recipe types. See ``RECIPE_TYPE_TO_CLASS`` at the bottom for the schema mapping.""" @@ -178,7 +190,11 @@ class ModelOptDFlashRecipe(ModelOptSpeculativeRecipeBase): @model_validator(mode="after") def _derive_dflash_offline(self) -> ModelOptDFlashRecipe: - self.dflash.dflash_offline = self.data.offline_data_path is not None + # offline (dumped .pt) and streaming (hidden states over HTTP from a vLLM + # serve) both feed pre-computed base hidden states to the DFlash module, so + # both set dflash_offline. Only fully-online training runs the base model. + # Mirrors ModelOptEagleRecipe._derive_eagle_offline. 
+ self.dflash.dflash_offline = self.data.mode != "online" return self diff --git a/modelopt/torch/speculative/config.py b/modelopt/torch/speculative/config.py index 6b2c9396ce7..708deafc0d1 100644 --- a/modelopt/torch/speculative/config.py +++ b/modelopt/torch/speculative/config.py @@ -23,6 +23,18 @@ from .eagle.default_config import default_eagle_config, default_kimik2_eagle_config +__all__ = [ + "DFLASH_DEFAULT_CFG", + "EAGLE3_DEFAULT_CFG", + "EAGLE_MTP_DEFAULT_CFG", + "DFlashConfig", + "EagleConfig", + "MedusaConfig", + "eagle3_default_config", + "eagle_mtp_default_config", + "kimik2_eagle_default_config", +] + kimik2_eagle_default_config = deepcopy(default_kimik2_eagle_config) eagle3_default_config = deepcopy(default_eagle_config) @@ -68,8 +80,10 @@ class DFlashConfig(ModeloptBaseConfig): dflash_offline: bool = ModeloptField( default=False, description=( - "Whether to use detached DFlash (offline training from pre-computed hidden states). " - "Derived by ModelOptDFlashRecipe from data.offline_data_path; not user-configurable." + "Whether the DFlash module consumes pre-computed hidden states (offline from " + "dumped .pt files, or streaming over HTTP from a vLLM serve) instead of running " + "the base model. Derived by ModelOptDFlashRecipe from data.mode (True unless " + "online); not user-configurable." ), ) diff --git a/modelopt/torch/speculative/eagle/utils.py b/modelopt/torch/speculative/eagle/utils.py index f74fcb1e9fb..2c536d04991 100644 --- a/modelopt/torch/speculative/eagle/utils.py +++ b/modelopt/torch/speculative/eagle/utils.py @@ -41,6 +41,8 @@ from torch.utils.data import Dataset from transformers.trainer_pt_utils import LabelSmoother +from modelopt.torch.utils.loss_mask import get_loss_mask_recovery + IGNORE_TOKEN_ID = LabelSmoother.ignore_index @@ -96,20 +98,27 @@ class OfflineSupervisedDataset(Dataset): dumped_files (list): A list of file paths to the dumped .pt files. 
answer_only_loss (bool): If True, use the ``loss_mask`` stored in each .pt file so that only assistant-produced tokens contribute to the loss. - Raises ``ValueError`` on ``__getitem__`` if the file lacks ``loss_mask``. + If a file lacks ``loss_mask`` and ``tokenizer`` has a registered + model-specific recovery (see ``modelopt.torch.utils.loss_mask``), the + mask is rebuilt from ``input_ids``; otherwise ``__getitem__`` raises + ``ValueError``. If False (default), a uniform all-ones mask is used regardless of what is stored in the file (backward compatible). + tokenizer: Optional tokenizer used to recover the assistant mask for dumps + that lack a stored ``loss_mask``. """ def __init__( self, dumped_files, answer_only_loss: bool = False, + tokenizer=None, ): """Initialize with a list of .pt file paths.""" super().__init__() self.dumped_files = dumped_files self.answer_only_loss = answer_only_loss + self.tokenizer = tokenizer def __len__(self): return len(self.dumped_files) @@ -121,13 +130,22 @@ def __getitem__(self, i) -> dict[str, torch.Tensor]: labels[..., :-1] = offline_data["input_ids"][..., 1:] if self.answer_only_loss: - if "loss_mask" not in offline_data: + recovery = get_loss_mask_recovery(self.tokenizer) if self.tokenizer else None + if "loss_mask" in offline_data: + loss_mask = offline_data["loss_mask"].to(offline_data["input_ids"].dtype) + elif recovery is not None: + # Dumps from tokenizers that cannot emit assistant masks carry no + # loss_mask; rebuild it from the token ids. + loss_mask = recovery.compute(self.tokenizer, offline_data["input_ids"]).to( + offline_data["input_ids"].dtype + ) + else: raise ValueError( f"answer_only_loss=True requires a 'loss_mask' entry in the offline " f".pt file, but {self.dumped_files[i]} does not have one. Re-dump " - f"with --answer-only-loss in compute_hidden_states_*.py." + f"with --answer-only-loss in compute_hidden_states_*.py, or pass a " + f"tokenizer with a registered loss-mask recovery." 
) - loss_mask = offline_data["loss_mask"].to(offline_data["input_ids"].dtype) else: loss_mask = torch.ones_like(offline_data["input_ids"]) diff --git a/modelopt/torch/speculative/plugins/hf_streaming_dataset.py b/modelopt/torch/speculative/plugins/hf_streaming_dataset.py index 31adbc96bf4..c1be45e9e56 100644 --- a/modelopt/torch/speculative/plugins/hf_streaming_dataset.py +++ b/modelopt/torch/speculative/plugins/hf_streaming_dataset.py @@ -13,11 +13,27 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Streaming datasets that fetch per-sample hidden states from a running inference server. - -The base class :class:`StreamingDataset` owns all the backend-/algorithm- -agnostic plumbing: threading, queue, tokenization, the bounded sliding-window -producer, loss_mask alignment, and HTTP-client lifecycle. Concrete subclasses +"""Map-style datasets that fetch per-sample hidden states from a running inference server. + +This is the streaming sibling of :class:`OfflineSupervisedDataset`: instead of +reading a pre-dumped ``.pt`` file in ``__getitem__``, it fetches the per-sample +hidden states from a live inference server over HTTP. It is a plain +``torch.utils.data.Dataset`` (map-style), so DDP sharding is handled the standard +way -- HF Trainer wraps it in a ``DistributedSampler`` and each rank's DataLoader +calls ``__getitem__`` only for that rank's indices. Each rank therefore fetches +**only its own shard** (no rank-0 funnel, no broadcast); aggregate read bandwidth +scales with the number of trainer ranks. + +Fetch concurrency comes from the DataLoader's ``num_workers`` (each worker process +issues one blocking request at a time); there is no in-process producer thread. 
+Keep ``num_workers`` modest and bounded so the per-server in-flight request count +(``ranks-hitting-a-server x num_workers``) stays near the server's ``max_num_seqs`` +-- flooding a cold NVFP4 MoE server can stall a worker past vLLM's execute-model +timeout and kill EngineCore. + +The base class :class:`StreamingDataset` owns the backend-/algorithm-agnostic +plumbing: tokenization, the resample-on-failure ``__getitem__`` loop, the +consecutive-failure circuit breaker, and loss_mask alignment. Concrete subclasses specialize along two axes: - **Backend** (how to talk to the server, how to decode the response): override @@ -25,39 +41,51 @@ - **Algorithm** (how to shape the per-sample dict for the trainer): override :meth:`_format`. -:class:`EagleVllmStreamingDataset` is currently the only concrete -combination (Eagle algorithm × vLLM backend); future combinations live as -sibling subclasses. - -Requires ``dataloader_num_workers=0``: multiple workers would each spawn their -own asyncio loop and issue duplicate requests against the server. +:class:`EagleVllmStreamingDataset` is currently the only concrete combination +(Eagle algorithm x vLLM backend); future combinations live as sibling subclasses. 
""" from __future__ import annotations -import asyncio import contextlib import os -import queue -import random -import threading +import time from pathlib import Path from typing import TypedDict import httpx import torch from pydantic import BaseModel, ConfigDict, Field, field_validator -from safetensors import safe_open -from torch.utils.data import IterableDataset, get_worker_info -from transformers import TrainerCallback +from safetensors import SafetensorError, safe_open +from torch.utils.data import Dataset from transformers.trainer_pt_utils import LabelSmoother -from modelopt.torch.utils import distributed as dist_utils from modelopt.torch.utils import print_rank_0, warn_rank_0 +from modelopt.torch.utils.loss_mask import get_loss_mask_recovery + +__all__ = [ + "EagleFetchPayload", + "EagleVllmStreamingConfig", + "EagleVllmStreamingDataset", + "StreamingConfig", + "StreamingDataset", +] IGNORE_TOKEN_ID = LabelSmoother.ignore_index -_SENTINEL = object() +# The vLLM connector writes the safetensors file asynchronously (writer thread pool) +# and returns its path before the write is durably visible, so an immediate read can +# race the writer. Retry the open with linear backoff until the file lands +# (worst case ~_READ_RETRIES * (_READ_RETRIES+1)/2 * _READ_BACKOFF s). +_READ_RETRIES = 10 +_READ_BACKOFF = 0.05 # seconds + +# Errors from ``_fetch`` that are genuinely transient (server overloaded / connection +# reset / timeout, or the safetensors writer race) and so count against the circuit +# breaker and trigger a resample. Anything else -- notably the ``RuntimeError`` raised +# on server token drift, or a programming/contract bug (``ValueError``/``KeyError``) -- +# is a real fault and propagates instead of being silently masked as a fetch miss. +_TRANSIENT_FETCH_ERRORS = (httpx.HTTPError, OSError, SafetensorError) def _tokenize_with_loss_mask( @@ -73,64 +101,72 @@ def _tokenize_with_loss_mask( tags so the tokenizer can return ``assistant_masks``. 
When ``max_seq_len`` is set, truncation is delegated to the tokenizer so ids and assistant_masks are truncated in lockstep. + + ``assistant_masks`` requires a fast tokenizer (it needs ``char_to_token``). For + tokenizers without it, the mask is rebuilt from token ids via a registered + model-specific recovery (see ``modelopt.torch.utils.loss_mask``) if one matches. """ + recovery = None + if answer_only_loss and not getattr(tokenizer, "is_fast", False): + recovery = get_loss_mask_recovery(tokenizer) out = tokenizer.apply_chat_template( conversations, tokenize=True, return_tensors="pt", return_dict=True, - return_assistant_tokens_mask=answer_only_loss, + return_assistant_tokens_mask=answer_only_loss and recovery is None, add_generation_prompt=False, truncation=max_seq_len is not None, max_length=max_seq_len, ) input_ids = out["input_ids"] seq_len = input_ids.shape[-1] - if answer_only_loss: + if not answer_only_loss: + loss_mask = torch.ones(seq_len, dtype=torch.long) + elif recovery is not None: + loss_mask = recovery.compute(tokenizer, input_ids[0]) + else: mask = out["assistant_masks"] if not isinstance(mask, torch.Tensor): mask = torch.tensor(mask, dtype=torch.long) loss_mask = mask.squeeze(0).to(torch.long) - if loss_mask.shape[0] != seq_len: - raise RuntimeError( - f"assistant_masks length {loss_mask.shape[0]} does not match " - f"input_ids length {seq_len}" - ) - else: - loss_mask = torch.ones(seq_len, dtype=torch.long) + if loss_mask.shape[0] != seq_len: + raise RuntimeError( + f"loss_mask length {loss_mask.shape[0]} does not match input_ids length {seq_len}" + ) return input_ids, loss_mask class StreamingConfig(BaseModel): """Static tuning knobs for :class:`StreamingDataset`. - Bundles the rarely-changing settings (loss masking, concurrency, HTTP timeout) - so the dataset ctor takes only ``entries`` + ``tokenizer`` + this config. 
+ Bundles the rarely-changing settings (loss masking, HTTP timeout) so the dataset + ctor takes only ``entries`` + ``tokenizer`` + this config. """ model_config = ConfigDict(extra="forbid") answer_only_loss: bool = False - prefetch: int = Field(default=64, ge=1) request_timeout: float = Field(default=600.0, gt=0) # Token-level cap applied during tokenization (right-truncation). Must hold # ``max_seq_len <= vllm.max_model_len``. ``None`` disables truncation. max_seq_len: int | None = None - # Must be identical on every rank — the dataset shuffles with this seed then - # stripes by rank, so equal seeds are required for the partition to be disjoint. - seed: int = 0 - # Circuit breaker: raise after this many consecutive _fetch failures so a dead - # server doesn't silently drain the corpus. + # Circuit breaker: raise after this many consecutive _fetch failures (per worker + # process) so a dead server doesn't silently resample the whole corpus. fail_after_consecutive_skips: int = Field(default=16, ge=1) -class StreamingDataset(IterableDataset): - """Base class: stream per-sample hidden states from a running inference server. +class StreamingDataset(Dataset): + """Base class: map-style dataset that streams per-sample hidden states from a server. Backend- and algorithm-agnostic; subclasses implement :meth:`_fetch` (backend) and :meth:`_format` (algorithm). The dict shape exchanged between them is the algorithm-level contract, declared as a ``TypedDict`` in :attr:`fetch_payload_cls` and validated against the actual ``_fetch`` output on every sample. + + ``__getitem__`` must always return a valid sample for the sampler's index, so it + resamples forward through the corpus on an unfit entry or a fetch failure rather + than skipping (a skip would shrink the batch and desync DDP). 
""" config_cls: type[StreamingConfig] = StreamingConfig @@ -145,217 +181,80 @@ def __init__( tokenizer, config: StreamingConfig | None = None, ): - """Hold the *full* corpus on every rank; fetch lazily, rank 0 only. + """Hold the full corpus; fetch lazily, per index, in ``__getitem__``. - DDP sharding is delegated to Accelerate's ``DataLoaderDispatcher``: rank 0 - consumes the dataset and broadcasts each batch; non-zero ranks rely on - :meth:`__iter__`'s rank guard. The corpus is held in full on every rank -- - the dispatcher reads only rank 0's stream, so sharding here would just - shrink that view. Shuffling with ``config.seed`` runs on every rank so - the order is reproducible regardless of which rank ends up fetching. + DDP sharding is handled by HF Trainer's ``DistributedSampler``: each rank's + DataLoader requests only its own indices, so each rank fetches only its + shard. The corpus order is left as given -- the sampler shuffles indices + (seeded by ``training_args.seed``), so no shuffle is needed here. Args: entries: Untokenized per-sample dicts from the input jsonl. Schema is - subclass-defined (see :meth:`_tokenize_entry`); passed through to :meth:`_fetch`. + subclass-defined (see :meth:`_tokenize_entry`); passed to :meth:`_fetch`. tokenizer: HF tokenizer; used for client-side tokenization and the server/client loss-mask alignment in :meth:`_fetch`. - config: Tuning knobs (prefetch, timeout, seed, ...); defaults to + config: Tuning knobs (timeout, answer_only_loss, ...); defaults to ``self.config_cls()``. See :class:`StreamingConfig`. """ if not entries: raise ValueError("entries is empty") self.tokenizer = tokenizer self.config = config if config is not None else self.config_cls() - # One-shot, consumed by the next __iter__. 
- self._resume_skip = 0 - - indices = list(range(len(entries))) - random.Random(self.config.seed).shuffle(indices) - self.entries = [entries[i] for i in indices] - rank, world = dist_utils.rank(), dist_utils.size() - print_rank_0( - f"[{type(self).__name__}] rank {rank}/{world}: " - f"holds {len(self.entries)} entries (full corpus; rank 0 fetches)" - ) + # Materialize to a plain list so DataLoader worker processes fork it cheaply. + self.entries = list(entries) + # Per-process consecutive-failure counter for the circuit breaker. Reset to 0 + # on every successful fetch; tripped only by fetch failures (not unfit entries). + self._consecutive_fail = 0 + print_rank_0(f"[{type(self).__name__}] map-style dataset over {len(self.entries)} entries") def __len__(self) -> int: return len(self.entries) - def set_resume_position(self, skip: int) -> None: - """Drop the first ``skip`` entries on the next ``__iter__`` without fetching. + def __getitem__(self, idx: int) -> dict[str, torch.Tensor]: + """Tokenize -> fetch -> format the sample at ``idx``, resampling on miss. - One-shot; cleared once iteration starts. Used by - :class:`StreamingResumeCallback` on HF Trainer checkpoint resume so the - server is not re-queried for already-consumed samples. + Always returns a valid sample. An unfit entry (tokenization yields nothing) or + a fetch failure causes a forward probe to the next index; fetch failures bump + the circuit breaker, which raises once ``fail_after_consecutive_skips`` is hit. """ - self._resume_skip = skip - - @staticmethod - def _verify_accelerate_dispatcher() -> None: - """Raise if Accelerate is initialized for DDP with ``dispatch_batches=False``. - - Best-effort: no-op when Accelerate isn't installed/initialized or in single-process. 
- """ - try: - from accelerate.state import AcceleratorState - except ImportError: - return - if not AcceleratorState._shared_state: - return - state = AcceleratorState() - if getattr(state, "num_processes", 1) <= 1: - return - # Field moved to ``dataloader_config`` in newer Accelerate; check both. - dispatch = getattr(state, "dispatch_batches", None) - if dispatch is None: - dl_cfg = getattr(state, "dataloader_config", None) - if dl_cfg is not None: - dispatch = getattr(dl_cfg, "dispatch_batches", None) - if dispatch is False: - raise RuntimeError( - "StreamingDataset requires Accelerate's DataLoaderDispatcher " - "(dispatch_batches=True); got False — non-zero ranks would receive no data." - ) - - def __iter__(self): - # IterableDataset with DataLoader workers > 0 would spawn one asyncio loop - # per worker, each issuing the full request set — silent Nx duplication - # against the server. Fail loud instead. - if get_worker_info() is not None: - raise RuntimeError( - f"{type(self).__name__} requires dataloader_num_workers=0; " - "multiple workers would each spawn an asyncio loop and duplicate requests." - ) - # Without dispatch_batches the rank-0 guard below would silently starve - # non-zero ranks; fail loud instead. - self._verify_accelerate_dispatcher() - # Only rank 0 fetches; non-zero ranks receive batches via the dispatcher's broadcast. - if dist_utils.rank() != 0: - return - # Fresh producer per __iter__ call so re-iteration (which shouldn't - # happen in 1-epoch streaming) at least doesn't deadlock. 
- q: queue.Queue = queue.Queue(maxsize=self.config.prefetch) - stop = threading.Event() - skip = self._resume_skip - self._resume_skip = 0 # one-shot - entries = self.entries[skip:] if skip else self.entries - - def run(): + n = len(self.entries) + for offset in range(n): + entry = self.entries[(idx + offset) % n] + sample = self._tokenize_entry(entry) + if sample is None: + continue # entry unfit pre-fetch; server not at fault, try the next one try: - asyncio.run(self._produce(q, stop, entries)) - except Exception as e: - q.put(e) # surface to consumer - finally: - q.put(_SENTINEL) - - thread = threading.Thread(target=run, daemon=True) - thread.start() - - try: - while True: - item = q.get() - if item is _SENTINEL: - break - if isinstance(item, Exception): - raise item - yield item - finally: - stop.set() - # Drain any leftover items so producer can exit - with contextlib.suppress(queue.Empty): - while True: - q.get_nowait() - - async def _produce(self, q: queue.Queue, stop: threading.Event, entries): - """Stream ``entries`` through a sliding window of at most ``prefetch`` in-flight tasks. - - Counter is local (single writer); ``_process`` tasks report outcome via return value. - The circuit breaker has *batch-level* (not per-task) granularity: when - ``asyncio.wait(FIRST_COMPLETED)`` returns several tasks in the same loop turn, - ``consecutive_skips`` reflects set-iteration order over ``done`` -- sufficient - for "detect a dead server" but not strict temporal ordering. - - Args: - q: Bounded queue drained by :meth:`__iter__`; full queue backpressures fetching. - stop: Set by the consumer to request shutdown; checked between samples. - entries: Resume-adjusted slice of ``self.entries`` to fetch this iteration. 
- """ - timeout = httpx.Timeout(self.config.request_timeout, connect=10.0) - threshold = self.config.fail_after_consecutive_skips - consecutive_skips = 0 - async with httpx.AsyncClient(timeout=timeout) as client: - pending: set[asyncio.Task] = set() - entries_iter = iter(entries) - exhausted = False - try: - while not stop.is_set(): - while len(pending) < self.config.prefetch and not exhausted: - try: - entry = next(entries_iter) - except StopIteration: - exhausted = True - break - pending.add(asyncio.create_task(self._process(client, entry, q, stop))) - if not pending: - break - done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED) - for task in done: - outcome = task.result() # re-raises unexpected errors - if outcome is True: - consecutive_skips = 0 - elif outcome is False: - consecutive_skips += 1 - # None -> entry unfit pre-fetch; server not at fault - if consecutive_skips >= threshold: - raise RuntimeError( - f"{consecutive_skips} consecutive _fetch failures " - f"in {type(self).__name__}; server likely down." - ) - finally: - for task in pending: - task.cancel() - if pending: - await asyncio.gather(*pending, return_exceptions=True) - - async def _process( - self, - client: httpx.AsyncClient, - entry: dict, - q: queue.Queue, - stop: threading.Event, - ) -> bool | None: - """Tokenize -> fetch -> format -> enqueue. - - Returns True on enqueue, False on fetch failure (bumps breaker), None - when the entry is unfit pre-fetch (no breaker effect). 
- """ - if stop.is_set(): - return None - sample = await asyncio.to_thread(self._tokenize_entry, entry) - if sample is None: - return None - try: - fetched = await self._fetch(client, sample) - except Exception as e: - warn_rank_0(f"[streaming] error for {sample['cid']}: {e!r}") - return False - if fetched is None: - return False - if self.fetch_payload_cls is not None: - # ``__required_keys__`` is a TypedDict runtime attribute mypy doesn't - # track on ``type``; the assignment site guarantees it's a TypedDict. - required: frozenset[str] = self.fetch_payload_cls.__required_keys__ # type: ignore[attr-defined] - missing = required - set(fetched) - if missing: - raise RuntimeError( - f"{type(self).__name__}._fetch missing required keys {missing}; " - f"{self.fetch_payload_cls.__name__} requires " - f"{set(required)}, got {set(fetched)}" - ) - data = self._format(fetched) - # Blocking put -> backpressure when trainer is slow. - await asyncio.to_thread(q.put, data) - return True + fetched = self._fetch(sample) + except _TRANSIENT_FETCH_ERRORS as e: + # Transport/IO miss: count against the circuit breaker and resample. + # Contract violations and bugs are not caught here -- they propagate. + warn_rank_0(f"[streaming] fetch error for {sample['cid']}: {e!r}") + fetched = None + if fetched is None: + self._consecutive_fail += 1 + if self._consecutive_fail >= self.config.fail_after_consecutive_skips: + raise RuntimeError( + f"{self._consecutive_fail} consecutive _fetch failures in " + f"{type(self).__name__}; server likely down." + ) + continue # resample forward + self._consecutive_fail = 0 + if self.fetch_payload_cls is not None: + # ``__required_keys__`` is a TypedDict runtime attribute mypy doesn't + # track on ``type``; the assignment site guarantees it's a TypedDict. 
+ required: frozenset[str] = self.fetch_payload_cls.__required_keys__ # type: ignore[attr-defined] + missing = required - set(fetched) + if missing: + raise RuntimeError( + f"{type(self).__name__}._fetch missing required keys {missing}; " + f"{self.fetch_payload_cls.__name__} requires " + f"{set(required)}, got {set(fetched)}" + ) + return self._format(fetched) + raise RuntimeError( + f"{type(self).__name__}: no fetchable sample found in the entire corpus " + f"({n} entries) starting at index {idx}." + ) def _tokenize_entry(self, entry: dict) -> dict | None: """Tokenize a single entry. @@ -382,14 +281,14 @@ def _tokenize_entry(self, entry: dict) -> dict | None: "loss_mask": loss_mask, } - async def _fetch(self, client: httpx.AsyncClient, sample: dict) -> dict | None: + def _fetch(self, sample: dict) -> dict | None: """Backend hook: send the request and decode the server's response. - Override in subclass. Any scratch resources (per-request files, mmap'd - buffers) must be released before returning. + Override in subclass. Synchronous (called from a DataLoader worker). Any + scratch resources (per-request files, mmap'd buffers) must be released before + returning. Args: - client: Shared async HTTP client owned by :meth:`_produce`. sample: :meth:`_tokenize_entry` output: ``{"cid": str, "token_ids": list[int], "loss_mask": LongTensor[seq]}``. @@ -431,16 +330,24 @@ class EagleFetchPayload(TypedDict): class EagleVllmStreamingConfig(StreamingConfig): """Adds vLLM endpoint info on top of :class:`StreamingConfig`.""" - server_url: str + # One or more vLLM endpoints; fetches round-robin across them so a single fetcher + # can spread load over several server replicas. Accepts a list or a single + # (optionally comma-separated) string. + server_urls: list[str] model: str - # Allowlist for ``hidden_states_path`` returned by the server. Must match the - # connector's ``shared_storage_path``; out-of-tree paths are rejected. 
+ # Allowlist for ``hidden_states_path`` returned by the server. Must match (or be a + # parent of) the connector's ``shared_storage_path``; out-of-tree paths are rejected. shared_storage_root: str - @field_validator("server_url") + @field_validator("server_urls", mode="before") @classmethod - def _strip_trailing_slash(cls, v: str) -> str: - return v.rstrip("/") + def _normalize_urls(cls, v): + if isinstance(v, str): + v = v.split(",") + urls = [u.strip().rstrip("/") for u in v if u and str(u).strip()] + if not urls: + raise ValueError("server_urls must contain at least one non-empty URL") + return urls @field_validator("shared_storage_root") @classmethod @@ -449,7 +356,7 @@ def _resolve_root(cls, v: str) -> str: class EagleVllmStreamingDataset(StreamingDataset): - """Eagle (algorithm) × vLLM (backend). + """Eagle (algorithm) x vLLM (backend). Talks to a ``vllm serve`` instance configured with the ``ExampleHiddenStatesConnector`` KV-transfer connector (the server dumps captured @@ -467,13 +374,47 @@ def __init__( tokenizer, config: EagleVllmStreamingConfig, ): - """Same as the base; ``config`` must include ``server_url`` and ``model``.""" + """Same as the base; ``config`` must include ``server_urls`` and ``model``.""" super().__init__(entries=entries, tokenizer=tokenizer, config=config) self.config: EagleVllmStreamingConfig = config - async def _fetch(self, client: httpx.AsyncClient, sample: dict) -> EagleFetchPayload | None: - r = await client.post( - f"{self.config.server_url}/v1/completions", + def _client(self) -> httpx.Client: + """Lazily build a per-process HTTP client and round-robin cursor. + + DataLoader workers are forked processes; httpx connection pools must not be + shared across a fork, so each process gets its own client (and its own + round-robin cursor over ``server_urls``), keyed by PID. The cursor starts + at a per-(rank, worker) offset so cold-start fetches fan out across + replicas instead of all hitting ``server_urls[0]``. 
+ """ + pid = os.getpid() + if getattr(self, "_client_pid", None) != pid: + self._http = httpx.Client( + timeout=httpx.Timeout(self.config.request_timeout, connect=10.0) + ) + self._client_pid = pid + # Stagger the initial cursor by (rank, worker) so cold-start fetches + # fan out instead of all pinning server_urls[0] (which can flood one + # cold replica past its execute-model timeout and kill the EngineCore). + info = torch.utils.data.get_worker_info() + worker_id = info.id if info is not None else 0 + num_workers = info.num_workers if info is not None else 1 + rank = int(os.environ.get("RANK", "0")) + self._rr = rank * num_workers + worker_id + return self._http + + def _next_url(self) -> str: + """Round-robin the next server URL (per-process cursor).""" + urls = self.config.server_urls + url = urls[self._rr % len(urls)] + self._rr += 1 + return url + + def _fetch(self, sample: dict) -> EagleFetchPayload | None: + client = self._client() + url = self._next_url() + r = client.post( + f"{url}/v1/completions", json={ "model": self.config.model, "prompt": sample["token_ids"], @@ -492,7 +433,7 @@ async def _fetch(self, client: httpx.AsyncClient, sample: dict) -> EagleFetchPay f"[streaming] path outside shared_storage_root for {sample['cid']}: {path!r}" ) return None - token_ids, hidden_states = await asyncio.to_thread(self._load_safetensors, path) + token_ids, hidden_states = self._load_safetensors(path) # Contract: the server tokenization is the client's pre-tokenized prompt # verbatim, plus at most one decode-step token at the tail (from # ``max_tokens=1``). Anything else (e.g. server-side BOS prepend, chat @@ -529,13 +470,25 @@ def _load_safetensors(path: str) -> tuple[torch.Tensor, torch.Tensor]: ``safe_open(..., framework="pt").get_tensor`` materializes an independent torch Tensor (not a view into the mmap'd file), so it is safe to unlink right after the ``with`` block exits. 
+ + Retries past the writer race (see ``_READ_RETRIES``): a missing file means + the write hasn't started; a ``SafetensorError`` means it's mid-write. Both + clear once the writer finishes, so back off and retry before giving up. """ - with safe_open(path, framework="pt") as f: - token_ids = f.get_tensor("token_ids") - hidden_states = f.get_tensor("hidden_states") # [seq, n_layers, hidden] - with contextlib.suppress(OSError): - os.unlink(path) - return token_ids, hidden_states + for attempt in range(_READ_RETRIES): + try: + with safe_open(path, framework="pt") as f: + token_ids = f.get_tensor("token_ids") + hidden_states = f.get_tensor("hidden_states") # [seq, n_layers, hidden] + with contextlib.suppress(OSError): + os.unlink(path) + return token_ids, hidden_states + except (FileNotFoundError, SafetensorError): # noqa: PERF203 -- retry-on-race loop + if attempt == _READ_RETRIES - 1: + raise + time.sleep(_READ_BACKOFF * (attempt + 1)) + # Unreachable (the last attempt above re-raises); guards _READ_RETRIES < 1. + raise RuntimeError(f"_load_safetensors exhausted {_READ_RETRIES} retries for {path}") @staticmethod def _align_loss_mask(loss_mask: torch.Tensor, n: int) -> torch.Tensor: @@ -573,36 +526,3 @@ def _format(self, fetched: EagleFetchPayload) -> dict[str, torch.Tensor]: "loss_mask": loss_mask, "labels": labels, } - - -class StreamingResumeCallback(TrainerCallback): - """Fast-forward :class:`StreamingDataset` past consumed samples on resume. - - Dispatcher pulls a *global* batch per micro-step, hence the ``world_size`` factor. - Requires ``training_args.ignore_data_skip=True``; round-trips only when - ``world_size`` and ``config.seed`` match the original run. 
- """ - - def on_train_begin(self, args, state, control, train_dataloader=None, **kwargs): - """Push the skip count into the dataset when resuming mid-training.""" - if state.global_step <= 0 or train_dataloader is None: - return - ds = train_dataloader.dataset - if not hasattr(ds, "set_resume_position"): - return - if not getattr(args, "ignore_data_skip", False): - raise RuntimeError( - "StreamingResumeCallback requires ignore_data_skip=True to avoid " - "double-skipping on resume." - ) - consumed = ( - state.global_step - * args.per_device_train_batch_size - * dist_utils.size() - * args.gradient_accumulation_steps - ) - ds.set_resume_position(consumed) - print_rank_0( - f"[StreamingResumeCallback] resuming at global_step={state.global_step}; " - f"skipping {consumed} entries" - ) diff --git a/modelopt/torch/speculative/plugins/hf_training_args.py b/modelopt/torch/speculative/plugins/hf_training_args.py index a65a3183a05..6f86a467ab2 100644 --- a/modelopt/torch/speculative/plugins/hf_training_args.py +++ b/modelopt/torch/speculative/plugins/hf_training_args.py @@ -31,7 +31,9 @@ from typing import Literal -from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator +from pydantic import BaseModel, ConfigDict, field_validator, model_validator + +__all__ = ["DataArguments", "ModelArguments", "TrainingArguments"] class ModelArguments(BaseModel): @@ -62,7 +64,6 @@ class DataArguments(BaseModel): sample_size: int = -1 streaming_server_url: str | None = None streaming_model_name: str | None = None - streaming_prefetch: int = Field(default=64, ge=1) # Mirror of the vLLM connector's ``shared_storage_path``; trainer-side allowlist. 
streaming_shared_storage_path: str | None = None diff --git a/modelopt/torch/utils/__init__.py b/modelopt/torch/utils/__init__.py index 51d02248c14..a38c80cac01 100644 --- a/modelopt/torch/utils/__init__.py +++ b/modelopt/torch/utils/__init__.py @@ -22,6 +22,7 @@ from .import_utils import * from .list import * from .logging import * +from .loss_mask import * from .network import * from .perf import * from .regex import * diff --git a/modelopt/torch/utils/loss_mask.py b/modelopt/torch/utils/loss_mask.py new file mode 100644 index 00000000000..839bce24b8e --- /dev/null +++ b/modelopt/torch/utils/loss_mask.py @@ -0,0 +1,139 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Model-specific recovery of the assistant loss mask. + +The standard way to build an answer-only loss mask is +``apply_chat_template(..., return_assistant_tokens_mask=True)``, which maps the +``{% generation %}`` template span to tokens via ``char_to_token`` -- and that is +only available on "fast" tokenizers. Some models ship only a slow/Python tokenizer +and cannot use this path. + +This module is a small registry of per-model fallbacks that recover the mask +directly from token ids, keyed by a ``detect`` predicate. Data paths consult +:func:`get_loss_mask_recovery` and stay free of any single model's chat-format +details. 
It is intentionally minimal and is meant to seed a broader model-specific +patch registry. +""" + +from collections.abc import Callable +from dataclasses import dataclass + +import torch + +__all__ = ["LossMaskRecovery", "get_loss_mask_recovery", "register_loss_mask_recovery"] + + +@dataclass(frozen=True) +class LossMaskRecovery: + """A model-specific fallback for building the assistant loss mask. + + Args: + name: Identifier for the target model family (for logging/debugging). + detect: Returns ``True`` if this recovery applies to the given tokenizer. + compute: Maps ``(tokenizer, input_ids)`` to a ``(seq_len,)`` ``LongTensor`` + mask aligned to ``input_ids`` (1 on tokens that should contribute to + the loss, 0 otherwise). + """ + + name: str + detect: Callable[[object], bool] + compute: Callable[[object, torch.Tensor], torch.Tensor] + + +_RECOVERIES: list[LossMaskRecovery] = [] + + +def register_loss_mask_recovery(recovery: LossMaskRecovery) -> None: + """Register a model-specific loss-mask recovery.""" + _RECOVERIES.append(recovery) + + +def get_loss_mask_recovery(tokenizer) -> LossMaskRecovery | None: + """Return the first registered recovery whose ``detect`` matches ``tokenizer``.""" + for recovery in _RECOVERIES: + if recovery.detect(tokenizer): + return recovery + return None + + +# --------------------------------------------------------------------------- +# Kimi +# +# Kimi ships only a Python (tiktoken) tokenizer, so it cannot emit assistant masks +# via apply_chat_template. Its chat turns are rendered as +# <|im_{role}|> {role_name} <|im_middle|> {content} <|im_end|> +# so the assistant content sits between <|im_middle|> and <|im_end|>. 
+# --------------------------------------------------------------------------- + +_KIMI_ROLE_MARKERS = ("<|im_user|>", "<|im_assistant|>", "<|im_system|>") + + +def _kimi_detect(tokenizer) -> bool: + """Whether ``tokenizer`` defines Kimi's chat role markers as real tokens.""" + unk = getattr(tokenizer, "unk_token_id", None) + try: + ids = [ + tokenizer.convert_tokens_to_ids(t) + for t in (*_KIMI_ROLE_MARKERS, "<|im_middle|>", "<|im_end|>") + ] + except Exception: + return False + return all(i is not None and i != unk for i in ids) + + +def _kimi_compute(tokenizer, input_ids) -> torch.Tensor: + """Recover the assistant-content mask from already-tokenized Kimi chat ids. + + Marks only the ``{content}`` span (between ``<|im_middle|>`` and ``<|im_end|>``, + both exclusive). This matches the ``{% generation %}`` span used for fast + tokenizers: the role header and the trailing ``<|im_end|>`` are not masked. + """ + ids = input_ids.tolist() if hasattr(input_ids, "tolist") else list(input_ids) + assistant_id = tokenizer.convert_tokens_to_ids("<|im_assistant|>") + middle_id = tokenizer.convert_tokens_to_ids("<|im_middle|>") + end_id = tokenizer.convert_tokens_to_ids("<|im_end|>") + role_ids = {tokenizer.convert_tokens_to_ids(t) for t in _KIMI_ROLE_MARKERS} + + n = len(ids) + mask = [0] * n + i = 0 + while i < n: + if ids[i] != assistant_id: + i += 1 + continue + # Skip the role header (role_name) up to its <|im_middle|> separator. + j = i + 1 + while j < n and ids[j] != middle_id and ids[j] not in role_ids and ids[j] != end_id: + j += 1 + if j >= n or ids[j] != middle_id: + # Malformed turn (no content separator) or a trailing generation prompt. + i = j + continue + # Mark the content span [middle + 1, end): excludes <|im_middle|> and <|im_end|>. 
+ start = j + 1 + k = start + while k < n and ids[k] != end_id and ids[k] not in role_ids: + k += 1 + for t in range(start, k): + mask[t] = 1 + i = k + + return torch.tensor(mask, dtype=torch.long) + + +register_loss_mask_recovery( + LossMaskRecovery(name="kimi", detect=_kimi_detect, compute=_kimi_compute) +) diff --git a/tests/examples/speculative_decoding/test_eagle_streaming.py b/tests/examples/speculative_decoding/test_eagle_streaming.py index 291aa0f7929..3c8f7573957 100644 --- a/tests/examples/speculative_decoding/test_eagle_streaming.py +++ b/tests/examples/speculative_decoding/test_eagle_streaming.py @@ -118,13 +118,12 @@ def test_streaming_eagle_training( f"data.streaming_server_url={server_url}", f"data.streaming_model_name={tiny_llama_path}", f"data.streaming_shared_storage_path={scratch}", - "data.streaming_prefetch=2", f"training.output_dir={output_dir}", "training.num_train_epochs=1", "training.learning_rate=1e-5", "training.training_seq_len=32", "training.save_steps=1", - "training.dataloader_num_workers=0", # enforced by StreamingDataset + "training.dataloader_num_workers=0", # map-style; 0 keeps this test single-process *_TINY_EAGLE_ARCH, ] diff --git a/tests/unit/torch/speculative/plugins/test_hf_streaming_dataset.py b/tests/unit/torch/speculative/plugins/test_hf_streaming_dataset.py index 27210ee7286..e6bac5b9755 100644 --- a/tests/unit/torch/speculative/plugins/test_hf_streaming_dataset.py +++ b/tests/unit/torch/speculative/plugins/test_hf_streaming_dataset.py @@ -13,13 +13,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Tests for StreamingDataset's DDP contract. +"""Tests for the map-style StreamingDataset. -We do not spin up real torch.distributed; instead we monkeypatch the helper that -reads rank/world_size. Sharding itself is delegated to Accelerate's -``DataLoaderDispatcher`` (every rank holds the full corpus; only rank 0 iterates). 
-These tests check the corpus-handling and rank-0-only-iter properties on which -that delegation relies. +The dataset is a plain ``torch.utils.data.Dataset``: DDP sharding is HF Trainer's +job (``DistributedSampler``), so there is no rank/dispatch logic to test here. +These tests cover the ``__getitem__`` contract: resample-on-miss, the +consecutive-failure circuit breaker, and the vLLM wire-format -> batch-dict chain. """ from pathlib import Path @@ -30,7 +29,7 @@ import safetensors.torch import torch -# hf_streaming_dataset imports TrainerCallback / LabelSmoother at module scope. +# hf_streaming_dataset imports LabelSmoother at module scope. pytest.importorskip("transformers") from modelopt.torch.speculative.plugins import hf_streaming_dataset @@ -47,133 +46,173 @@ def _entries(n: int) -> list[dict]: return [{"id": i} for i in range(n)] -@pytest.fixture -def patch_dist(monkeypatch): - """Return a setter; tests call it with (rank, world) to simulate a DDP rank. - - Patches ``modelopt.torch.utils.distributed.rank/size`` as imported into the - streaming dataset module (``dist_utils``). The dataset reads these in - ``__init__`` for logging and in ``__iter__`` for the rank-0-only gate. - """ - - def _set(rank: int, world: int): - # ``is_master`` etc. call ``rank(group=...)`` / ``size(group=...)`` — match the signature. - monkeypatch.setattr(hf_streaming_dataset.dist_utils, "rank", lambda group=None: rank) - monkeypatch.setattr(hf_streaming_dataset.dist_utils, "size", lambda group=None: world) - - return _set - - -def _entry_ids(ds: StreamingDataset) -> list[int]: - return [e["id"] for e in ds.entries] - - -@pytest.mark.parametrize("world", [1, 2, 3, 8]) -def test_every_rank_holds_full_corpus(patch_dist, world): - """Each rank must see all entries — Accelerate's dispatcher does the sharding, - so any per-rank pre-shard here would shrink rank 0's view to 1/N and break - ``max_steps``. 
- """ - corpus = _entries(100) - for rank in range(world): - patch_dist(rank, world) - ds = StreamingDataset(corpus, tokenizer=MagicMock(), config=StreamingConfig(seed=42)) - assert sorted(_entry_ids(ds)) == list(range(100)) - - -def test_same_seed_same_order(patch_dist): - """The shuffle is what makes rank 0's fetch order deterministic across reruns.""" - corpus = _entries(50) - patch_dist(0, 1) - a = _entry_ids(StreamingDataset(corpus, tokenizer=MagicMock(), config=StreamingConfig(seed=7))) - b = _entry_ids(StreamingDataset(corpus, tokenizer=MagicMock(), config=StreamingConfig(seed=7))) - assert a == b - - -def test_different_seed_different_order(patch_dist): - """Sanity: changing the seed actually reshuffles (else seed is vacuous).""" - corpus = _entries(50) - patch_dist(0, 1) - a = _entry_ids(StreamingDataset(corpus, tokenizer=MagicMock(), config=StreamingConfig(seed=1))) - b = _entry_ids(StreamingDataset(corpus, tokenizer=MagicMock(), config=StreamingConfig(seed=2))) - assert a != b - assert sorted(a) == sorted(b) - - -def test_non_rank_zero_iter_is_empty(patch_dist): - """Non-zero ranks must yield nothing on ``__iter__`` — their producer would burn - server requests that ``DataLoaderDispatcher`` would discard.""" - corpus = _entries(8) - patch_dist(2, 4) - ds = StreamingDataset(corpus, tokenizer=MagicMock(), config=StreamingConfig(seed=0)) - assert list(iter(ds)) == [] - - -def test_iter_rejects_dataloader_workers(patch_dist, monkeypatch): - """Iterating from within a DataLoader worker must raise — multiple workers would - each spawn an asyncio loop and N× the request load on the server.""" - patch_dist(0, 1) - ds = StreamingDataset(_entries(4), tokenizer=MagicMock(), config=StreamingConfig(seed=0)) - # Pretend we're inside a DataLoader worker. 
- monkeypatch.setattr(hf_streaming_dataset, "get_worker_info", lambda: MagicMock()) - with pytest.raises(RuntimeError, match="dataloader_num_workers=0"): - next(iter(ds)) - - -def test_empty_corpus_raises(patch_dist): - patch_dist(0, 1) +def test_empty_corpus_raises(): with pytest.raises(ValueError, match="entries is empty"): StreamingDataset([], tokenizer=MagicMock(), config=StreamingConfig()) -def test_set_resume_position_skips_entries_without_fetching(patch_dist): - """Resume should fast-forward inside the dataset without invoking _fetch. +def test_len_matches_corpus(): + ds = StreamingDataset(_entries(37), tokenizer=MagicMock(), config=StreamingConfig()) + assert len(ds) == 37 - Verifies the contract relied on by StreamingResumeCallback: skipped entries - are not sent to the server, so resume costs nothing on the inference side. - """ - patch_dist(0, 1) - fetched_ids: list[int] = [] + +def test_getitem_resamples_past_unfit_entries(): + """An unfit entry (tokenize -> None) must not be returned; __getitem__ probes + forward to the next fetchable index and returns that instead.""" + fetched_cids: list[int] = [] class _Track(StreamingDataset): def _tokenize_entry(self, entry): + # Even ids are "unfit" (e.g. truncated away / missing fields). 
+ if entry["id"] % 2 == 0: + return None return {"cid": str(entry["id"]), "token_ids": [1], "loss_mask": None} - async def _fetch(self, client, sample): - fetched_ids.append(int(sample["cid"])) + def _fetch(self, sample): + fetched_cids.append(int(sample["cid"])) + return {"ok": True} - corpus = _entries(10) - ds = _Track(corpus, tokenizer=MagicMock(), config=StreamingConfig(seed=0, prefetch=2)) - ds.set_resume_position(5) - list(ds) + def _format(self, fetched): + return {"sentinel": fetched_cids[-1]} - expected = {e["id"] for e in ds.entries[5:]} - assert set(fetched_ids) == expected - # _resume_skip is one-shot - assert ds._resume_skip == 0 + ds = _Track(_entries(10), tokenizer=MagicMock(), config=StreamingConfig()) + # idx 0 is unfit -> resamples forward to idx 1. + out = ds[0] + assert out == {"sentinel": 1} + assert fetched_cids == [1] + # An already-fit index is returned directly. + assert ds[3] == {"sentinel": 3} -def test_circuit_breaker_trips_on_consecutive_fetch_failures(patch_dist): - """When _fetch keeps failing, the producer raises after the threshold so the - trainer sees a clear error instead of a silent empty epoch.""" - patch_dist(0, 1) +def test_circuit_breaker_trips_on_consecutive_failures(): + """When _fetch keeps hitting transient errors (server down), __getitem__ raises + after the threshold instead of silently resampling the whole corpus.""" threshold = 3 class _AlwaysFails(StreamingDataset): - # Bypass tokenization so we don't need a real tokenizer. def _tokenize_entry(self, entry): return {"cid": str(entry["id"]), "token_ids": [1], "loss_mask": None} - async def _fetch(self, client, sample): - raise RuntimeError("simulated server failure") + def _fetch(self, sample): + # A down server surfaces as a transport error, which the breaker counts. 
+ raise httpx.ConnectError("simulated server down") ds = _AlwaysFails( _entries(20), tokenizer=MagicMock(), - config=StreamingConfig(seed=0, prefetch=2, fail_after_consecutive_skips=threshold), + config=StreamingConfig(fail_after_consecutive_skips=threshold), ) with pytest.raises(RuntimeError, match="consecutive _fetch failures"): - list(ds) + ds[0] + + +def test_contract_violation_propagates_not_swallowed(): + """A non-transient error from _fetch (e.g. a contract violation / bug) must + surface immediately, not be masked as a fetch miss and silently resampled.""" + + class _BadContract(StreamingDataset): + def _tokenize_entry(self, entry): + return {"cid": str(entry["id"]), "token_ids": [1], "loss_mask": None} + + def _fetch(self, sample): + raise RuntimeError("server token_ids drift") + + ds = _BadContract( + _entries(20), + tokenizer=MagicMock(), + # High threshold: if the error were (wrongly) swallowed, the breaker wouldn't + # fire, so a leaked breaker message would mask the regression. + config=StreamingConfig(fail_after_consecutive_skips=100), + ) + with pytest.raises(RuntimeError, match="server token_ids drift"): + ds[0] + + +def test_fetch_returning_none_exhausts_then_raises(): + """If every entry's fetch yields None (e.g. 
all rejected), __getitem__ raises a + clear 'no fetchable sample' error rather than hanging or returning junk.""" + + class _AllNone(StreamingDataset): + def _tokenize_entry(self, entry): + return {"cid": str(entry["id"]), "token_ids": [1], "loss_mask": None} + + def _fetch(self, sample): + return None + + ds = _AllNone( + _entries(4), + tokenizer=MagicMock(), + config=StreamingConfig(fail_after_consecutive_skips=100), + ) + with pytest.raises(RuntimeError, match="no fetchable sample"): + ds[0] + + +def test_resume_skips_consumed_samples_without_refetching(): + """Map-style resume contract: HF Trainer skips consumed batches via + accelerate.skip_first_batches, which drops their indices at the batch-sampler + level so __getitem__ (and thus _fetch) is never called for them. This is why + main.py leaves ignore_data_skip at its default (False) for streaming -- resume + lands at the exact position with no re-fetch. Guards against a regression that + would re-fetch (or re-stream) already-consumed samples on resume.""" + pytest.importorskip("accelerate") + from accelerate import skip_first_batches + from torch.utils.data import DataLoader, RandomSampler + + fetched: list[int] = [] + + class _Recording(StreamingDataset): + def _tokenize_entry(self, entry): + return {"cid": str(entry["id"]), "token_ids": [1], "loss_mask": None} + + def _fetch(self, sample): + cid = int(sample["cid"]) + fetched.append(cid) # stands in for the HTTP fetch + return {"cid": cid} + + def _format(self, payload): + return torch.tensor(payload["cid"]) + + n, batch_size, skip_batches = 20, 2, 3 + ds = _Recording(_entries(n), tokenizer=MagicMock(), config=StreamingConfig()) + + def make_dl(): + # Fresh, identically-seeded sampler -> identical permutation across runs. + return DataLoader( + ds, + batch_size=batch_size, + sampler=RandomSampler(ds, generator=torch.Generator().manual_seed(0)), + ) + + # Full pass -> ground-truth consumption order (cid == requested index here). 
+ full_order = [int(x) for batch in make_dl() for x in batch] + fetched.clear() + + # Resume: skip the first `skip_batches` batches. + tail_order = [int(x) for batch in skip_first_batches(make_dl(), skip_batches) for x in batch] + + consumed = full_order[: skip_batches * batch_size] + expected_tail = full_order[skip_batches * batch_size :] + assert tail_order == expected_tail, "resume must continue at the exact data position" + assert set(fetched).isdisjoint(consumed), "skipped (consumed) samples must not be re-fetched" + assert fetched == expected_tail, "only the un-consumed tail is fetched after resume" + + +def test_server_urls_normalization(): + """server_urls accepts a single string, a comma-separated string, or a list, and + strips trailing slashes.""" + + def _urls(v): + cfg = EagleVllmStreamingConfig( + server_urls=v, model="m", shared_storage_root=str(Path.cwd()) + ) + return cfg.server_urls + + assert _urls("http://a:8000/") == ["http://a:8000"] + assert _urls("http://a:8000, http://b:8000/") == ["http://a:8000", "http://b:8000"] + assert _urls(["http://a:8000", "http://b:8000"]) == ["http://a:8000", "http://b:8000"] + with pytest.raises(ValueError, match="at least one non-empty URL"): + EagleVllmStreamingConfig(server_urls="", model="m", shared_storage_root=".") def _write_canned_safetensors(path: Path, seq: int, n_layers: int, hidden: int) -> None: @@ -196,14 +235,23 @@ def _tokenizer_returning(seq: int) -> MagicMock: return tok -def test_eagle_vllm_dataset_end_to_end(tmp_path, patch_dist, monkeypatch): +def _patch_sync_client(monkeypatch, handler): + """Route the dataset's per-process httpx.Client through a MockTransport handler.""" + real_client = httpx.Client + + def mock_client(*args, **kwargs): + kwargs["transport"] = httpx.MockTransport(handler) + return real_client(*args, **kwargs) + + monkeypatch.setattr(hf_streaming_dataset.httpx, "Client", mock_client) + + +def test_eagle_vllm_dataset_end_to_end(tmp_path, monkeypatch): """Drive 
EagleVllmStreamingDataset against an in-process mocked server. - Verifies that the wire-format → tensor → batch-dict chain produces dicts - matching what EagleOfflineDataCollator expects, and that scratch files - are cleaned up after each fetch. + Verifies the wire-format -> tensor -> batch-dict chain produces dicts matching + what EagleOfflineDataCollator expects, and that scratch files are cleaned up. """ - patch_dist(0, 1) seq, n_layers, hidden = 8, 3, 16 # n_layers = 1 final + 2 aux scratch = tmp_path / "vllm_scratch" scratch.mkdir() @@ -219,37 +267,25 @@ def handler(request: httpx.Request) -> httpx.Response: json={"kv_transfer_params": {"hidden_states_path": str(path)}}, ) - real_async_client = httpx.AsyncClient - - def mock_async_client(*args, **kwargs): - kwargs["transport"] = httpx.MockTransport(handler) - return real_async_client(*args, **kwargs) - - monkeypatch.setattr(hf_streaming_dataset.httpx, "AsyncClient", mock_async_client) + _patch_sync_client(monkeypatch, handler) n_entries = 4 entries = [ - { - "conversation_id": f"c-{i}", - "messages": [{"role": "user", "content": "x"}], - } + {"conversation_id": f"c-{i}", "messages": [{"role": "user", "content": "x"}]} for i in range(n_entries) ] ds = EagleVllmStreamingDataset( entries=entries, tokenizer=_tokenizer_returning(seq), config=EagleVllmStreamingConfig( - server_url="http://mock:8000", + server_urls="http://mock:8000", model="mock-model", shared_storage_root=str(scratch), - prefetch=2, - seed=0, ), ) - batches = list(ds) + batches = [ds[i] for i in range(n_entries)] - assert len(batches) == n_entries expected_keys = { "input_ids", "base_model_hidden_states", @@ -275,9 +311,53 @@ def mock_async_client(*args, **kwargs): assert list(scratch.iterdir()) == [], "scratch files must be unlinked after fetch" -def test_path_outside_shared_storage_root_is_rejected(tmp_path, patch_dist, monkeypatch): - """Out-of-root path from server is not opened or unlinked.""" - patch_dist(0, 1) +def 
test_fetch_round_robins_across_server_urls(tmp_path, monkeypatch): + """With multiple server_urls, consecutive fetches alternate across endpoints so + load is spread over replicas rather than pinned to the first one.""" + seq, n_layers, hidden = 8, 3, 16 + scratch = tmp_path / "vllm_scratch" + scratch.mkdir() + + hosts: list[str] = [] + counter = {"n": 0} + + def handler(request: httpx.Request) -> httpx.Response: + hosts.append(request.url.host) + counter["n"] += 1 + path = scratch / f"req_{counter['n']}.safetensors" + _write_canned_safetensors(path, seq, n_layers, hidden) + return httpx.Response( + 200, + json={"kv_transfer_params": {"hidden_states_path": str(path)}}, + ) + + _patch_sync_client(monkeypatch, handler) + + n_entries = 4 + entries = [ + {"conversation_id": f"c-{i}", "messages": [{"role": "user", "content": "x"}]} + for i in range(n_entries) + ] + ds = EagleVllmStreamingDataset( + entries=entries, + tokenizer=_tokenizer_returning(seq), + config=EagleVllmStreamingConfig( + server_urls=["http://a:8000", "http://b:8000"], + model="mock-model", + shared_storage_root=str(scratch), + ), + ) + + for i in range(n_entries): + ds[i] + + # Per-process round-robin cursor: a, b, a, b -- one request each, alternating. 
+ assert hosts == ["a", "b", "a", "b"] + + +def test_path_outside_shared_storage_root_is_rejected(tmp_path, monkeypatch): + """Out-of-root path from the server is not opened or unlinked; the fetch yields + None, so the single-entry corpus is exhausted and __getitem__ raises.""" seq, n_layers, hidden = 8, 3, 16 allowed = tmp_path / "allowed" allowed.mkdir() @@ -292,26 +372,43 @@ def handler(request: httpx.Request) -> httpx.Response: json={"kv_transfer_params": {"hidden_states_path": str(forbidden)}}, ) - real_async_client = httpx.AsyncClient - - def mock_async_client(*args, **kwargs): - kwargs["transport"] = httpx.MockTransport(handler) - return real_async_client(*args, **kwargs) - - monkeypatch.setattr(hf_streaming_dataset.httpx, "AsyncClient", mock_async_client) + _patch_sync_client(monkeypatch, handler) ds = EagleVllmStreamingDataset( entries=[{"conversation_id": "c-0", "messages": [{"role": "user", "content": "x"}]}], tokenizer=_tokenizer_returning(seq), config=EagleVllmStreamingConfig( - server_url="http://mock:8000", + server_urls="http://mock:8000", model="mock-model", shared_storage_root=str(allowed), fail_after_consecutive_skips=100, - prefetch=1, - seed=0, ), ) - assert list(ds) == [] + with pytest.raises(RuntimeError, match="no fetchable sample"): + ds[0] assert forbidden.exists(), "rejected path must not be unlinked" + + +def test_load_safetensors_retries_past_writer_race(tmp_path, monkeypatch): + """The connector writes asynchronously, so an immediate read can race it; + _load_safetensors must retry past the transient FileNotFound/Safetensor error.""" + seq, n_layers, hidden = 4, 2, 8 + path = tmp_path / "late.safetensors" + _write_canned_safetensors(path, seq, n_layers, hidden) + + calls = {"n": 0} + real_safe_open = hf_streaming_dataset.safe_open + + def flaky_safe_open(p, framework): + calls["n"] += 1 + if calls["n"] < 3: # first 2 reads race the writer (file not ready yet) + raise FileNotFoundError(f"No such file or directory: {p}") + return 
real_safe_open(p, framework=framework) + + monkeypatch.setattr(hf_streaming_dataset, "safe_open", flaky_safe_open) + monkeypatch.setattr(hf_streaming_dataset.time, "sleep", lambda *_: None) # no real backoff + + token_ids, hidden_states = EagleVllmStreamingDataset._load_safetensors(str(path)) + assert calls["n"] == 3 + assert hidden_states.shape == (seq, n_layers, hidden) diff --git a/tools/launcher/common/eagle3/train_eagle_streaming.sh b/tools/launcher/common/eagle3/train_eagle_streaming.sh index 158bd7a0cf6..49b54709d35 100755 --- a/tools/launcher/common/eagle3/train_eagle_streaming.sh +++ b/tools/launcher/common/eagle3/train_eagle_streaming.sh @@ -15,67 +15,50 @@ # See the License for the specific language governing permissions and # limitations under the License. -# EAGLE3 streaming training: runs a `vllm serve` (KV-transfer producer of hidden -# states) alongside the trainer and routes hidden states over HTTP rather than -# dumping to disk. Sibling of train_eagle.sh. +# EAGLE3 streaming training: a `vllm serve` (KV-transfer hidden-states producer) +# runs alongside the trainer, routing hidden states over HTTP not disk. # -# Topology is chosen automatically from the Slurm allocation (the launcher yaml's -# `nodes:` field); nemo_run runs this script once per node, so it branches on -# $SLURM_NODEID: -# nodes == 1 -> co-located: vllm serve on $SERVE_GPU, trainer on the rest of -# the local GPUs (original single-node behavior). -# nodes >= 2 -> split across nodes: node 0 runs vllm serve on all its GPUs, -# node 1 runs the trainer on all its GPUs. The two roles -# rendezvous through the shared /scratchspace mount (node 0 -# publishes its address; node 1 signals completion). For large -# models whose serve needs a whole node (e.g. Kimi-K2.5 TP=8), -# allocate exactly 2 nodes. +# CANONICAL TOPOLOGY/DISPATCH (per-example YAMLs cross-reference here). 
Topology is +# auto-chosen from the Slurm allocation (yaml `nodes:`) and $SERVE_NODES; nemo_run +# runs this script once per node, branching on $SLURM_NODEID: +# nodes == 1 -> co-located: vllm serve on $SERVE_GPU, trainer on the rest. +# nodes >= 2 -> split: nodes 0..SERVE_NODES-1 each run an independent whole-node +# vllm serve replica; nodes SERVE_NODES..NNODES-1 are multi-node-DDP +# trainers. SERVE_NODES default 1. Rendezvous over shared +# /scratchspace: each serve i publishes .serve_addr.i; head trainer +# (first trainer node = accelerate machine_rank 0) publishes its IP; +# trainers collect every serve address. +# Map-style dataset: DistributedSampler shards the corpus across trainer ranks, each +# rank fetches only its shard round-robin across the SERVE_NODES replicas +# (data.streaming_server_url = comma-joined list). # # Env vars (required): -# HF_MODEL_CKPT Target model path. Used by both vllm serve (as the -# model arg, becomes the served-model-name) and the -# trainer (data.streaming_model_name). -# EAGLE_CAPTURE_IDS JSON list of 1-based layer ids vllm should capture. -# Must equal default_eagle_aux_layer_ids(L) shifted by +1, -# plus the final layer L. For Qwen3-8B (L=36): -# default = [1,17,32] -> capture = [2,18,33,36]. +# HF_MODEL_CKPT Target model path; vllm serve model arg (= served-model-name) +# and trainer data.streaming_model_name. +# EAGLE_CAPTURE_IDS JSON 1-based layer ids to capture = default_eagle_aux_layer_ids(L) +# +1, plus final layer L. Qwen3-8B (L=36): [1,17,32]->[2,18,33,36]. # # Env vars (optional): -# SERVE_PORT default 8765 -# SERVE_GPU_MEM_UTIL default 0.4 (single-node) / 0.9 (multi-node serve node) -# SERVE_READY_TIMEOUT seconds to wait for the server to come up. default 900 -# SERVE_EXTRA_ARGS extra flags appended to `vllm serve` (e.g. --trust-remote-code) -# SERVE_CPU_OFFLOAD_GB GB of weights/GPU to offload to host RAM (fits big models -# on too-few GPUs; slower). e.g. 
"10" -# SERVE_MAX_MODEL_LEN cap vllm context length (trims KV/activation). e.g. "4096" -# SERVE_MAX_NUM_SEQS cap concurrent sequences (trims KV/activation). e.g. "8" -# SERVE_HOST single-node only: bind/connect host. default 127.0.0.1 -# SERVE_GPU single-node only: CUDA_VISIBLE_DEVICES for vllm. default "0" -# SERVE_TP tensor-parallel size. default 1 (single-node) / all GPUs -# on the serve node (multi-node) -# TRAIN_GPUS single-node only: CUDA_VISIBLE_DEVICES for the trainer. -# default = all local GPUs except SERVE_GPU. -# SERVE_ADVERTISE_IP multi-node only: address node 1 should dial. default is -# node 0's first `hostname -I` IP. -# -# All script args are forwarded to launch_train.sh (typically: --config -# plus OmegaConf dotlist overrides). +# SERVE_NODES multi-node: dedicated serve replica nodes (0..SERVE_NODES-1). default 1 +# SERVE_GPU_MEM_UTIL default 0.4 single-node / 0.9 multi-node serve node +# SERVE_READY_TIMEOUT server startup wait, seconds. default 900 +# SERVE_EXTRA_ARGS extra `vllm serve` flags (e.g. --trust-remote-code) +# SERVE_CPU_OFFLOAD_GB GB/GPU offloaded to host RAM (fits big models on too-few GPUs; slower) +# SERVE_MAX_MODEL_LEN cap context length (trims KV/activation) +# SERVE_MAX_NUM_SEQS cap concurrent sequences (trims KV/activation) +# SERVE_HOST single-node: bind/connect host. default 127.0.0.1 +# SERVE_GPU single-node: CUDA_VISIBLE_DEVICES for vllm. default "0" +# SERVE_TP tensor-parallel size. default 1 single-node / all serve-node GPUs +# TRAIN_GPUS single-node: trainer CUDA_VISIBLE_DEVICES. default = all but SERVE_GPU +# SERVE_ADVERTISE_IP multi-node: address node 1 dials. default node 0's routable IP SCRIPT_DIR="$(dirname "$(readlink -f "$0")")" source "${SCRIPT_DIR}/../service_utils.sh" ################################################################################################### -# Container provisioning -# -# vllm/vllm-openai:* has vllm and torch but not modelopt or the speculative -# trainer's deps. 
modelopt is bind-mounted at -# /usr/local/lib/python3.12/dist-packages/modelopt, but it has no .dist-info -# (so `importlib.metadata.version('nvidia-modelopt')` would fail). nemo_run -# only ships modelopt subdirs, not the real pyproject.toml, so we synthesize -# a minimal one with a correctly-scoped setuptools.packages.find include — -# without `include = ["modelopt*"]`, setuptools sees both `modelopt/` and -# `modelopt_recipes/` at the top level and refuses with a "flat-layout" -# error. We then `pip install -e .` to register the dist-info. +# Container provisioning: the vllm image lacks modelopt's .dist-info and the real +# pyproject, so synthesize a minimal pyproject (scoped `include` avoids setuptools' +# flat-layout error) and `pip install -e .`. TOML=modules/Model-Optimizer/pyproject.toml if [ ! -f "$TOML" ]; then @@ -112,7 +95,7 @@ export PATH=$PATH:/workspace/.local/bin ################################################################################################### -trap 'error_handler $0 $LINENO' ERR # ERROR HANDLER +trap 'error_handler $0 $LINENO' ERR if [ -z "$HF_MODEL_CKPT" ]; then echo "ERROR: HF_MODEL_CKPT must be set." >&2; exit 1 @@ -121,18 +104,21 @@ if [ -z "$EAGLE_CAPTURE_IDS" ]; then echo "ERROR: EAGLE_CAPTURE_IDS must be set (e.g. '[2, 18, 33, 36]' for Qwen3-8B)." >&2; exit 1 fi -# Everything passed to this script (--config + OmegaConf dotlist) is -# forwarded verbatim to the trainer. Capture it before the helpers below run. +# Forwarded verbatim to the trainer; capture before the helpers below run. SCRIPT_ARGS=("$@") SERVE_PORT="${SERVE_PORT:-8765}" SERVE_READY_TIMEOUT="${SERVE_READY_TIMEOUT:-900}" +SERVE_NODES="${SERVE_NODES:-1}" +# Shared scratch; per-request safetensors keyed by vllm request id, so no collision. 
SERVE_SCRATCH="/scratchspace/streaming_serve_scratch" -SERVE_LOG="/scratchspace/vllm_serve.log" -# Multi-node rendezvous over the shared /scratchspace mount (lustre, visible on -# every node): node 0 publishes its address here, node 1 signals completion here. -SERVE_ADDR_FILE="/scratchspace/.serve_addr" -DONE_FILE="/scratchspace/.training_done" +SERVE_LOG="/scratchspace/vllm_serve.log" # serve nodes override with a per-node path +# Namespace rendezvous/sentinel files per Slurm job (SLURM_JOB_ID: same across an +# allocation's nodes, unique across allocations) so concurrent allocations on the +# shared mount don't clobber each other's addresses. Fixed token off-Slurm. +RUN_ID="${SLURM_JOB_ID:-local}" +SERVE_ADDR_FILE="/scratchspace/.serve_addr.${RUN_ID}" +DONE_FILE="/scratchspace/.training_done.${RUN_ID}" SERVE_PID="" mkdir -p "$SERVE_SCRATCH" @@ -145,34 +131,32 @@ cleanup() { gpus_on_node() { nvidia-smi --query-gpu=count --format=csv,noheader,nounits | head -n1; } +# Resolve a routable IP (other nodes must dial it). `hostname -I` can list a +# link-local/loopback first, so prefer the Slurm node name, then first non-lo/non-ll IP. +# $1 = optional override (SERVE_ADVERTISE_IP / TRAINER_ADVERTISE_IP) +resolve_routable_ip() { + local ip="$1" + [ -z "$ip" ] && ip=$(getent hosts "${SLURMD_NODENAME:-$(hostname)}" 2>/dev/null | awk '{print $1}' | head -1) + [ -z "$ip" ] && ip=$(hostname -I | tr ' ' '\n' | grep -vE '^(127\.|169\.254\.|fe80:|::1)' | head -1) + [ -z "$ip" ] && ip=$(hostname -I | awk '{print $1}') + echo "$ip" +} + # Start vllm serve in the background. Sets SERVE_PID. # $1 = bind host $2 = tensor-parallel size $3 = CUDA_VISIBLE_DEVICES ("" -> all) launch_vllm() { local bind_host="$1" tp="$2" cvd="$3" echo "Launching vllm serve on ${bind_host}:${SERVE_PORT} (TP=${tp}, CUDA_VISIBLE_DEVICES=${cvd:-all}, mem=${SERVE_GPU_MEM_UTIL}, log: $SERVE_LOG)..." 
- # Only pin GPUs when a non-empty set is given; an empty CUDA_VISIBLE_DEVICES - # would expose *zero* GPUs (not all), so leave it unset to use the whole node. + # Pin GPUs only for a non-empty set; empty CUDA_VISIBLE_DEVICES hides ALL, so unset = whole node. local -a gpu_env=() [ -n "$cvd" ] && gpu_env=(env "CUDA_VISIBLE_DEVICES=$cvd") - # Optional single-value memory knobs (each a space-free env value, so they - # survive nemo_run's unquoted `export FOO=value`; assembled into --flag value - # pairs here). --cpu-offload-gb spills N GB of weights/GPU to host RAM, the - # key lever for fitting a large model on too-few GPUs (slower, prefill-only - # use tolerates it). --max-model-len / --max-num-seqs trim KV/activation. + # Optional memory knobs (see header). Space-free env values to survive nemo_run's unquoted export. local -a opt_args=() [ -n "${SERVE_CPU_OFFLOAD_GB:-}" ] && opt_args+=(--cpu-offload-gb "$SERVE_CPU_OFFLOAD_GB") [ -n "${SERVE_MAX_MODEL_LEN:-}" ] && opt_args+=(--max-model-len "$SERVE_MAX_MODEL_LEN") [ -n "${SERVE_MAX_NUM_SEQS:-}" ] && opt_args+=(--max-num-seqs "$SERVE_MAX_NUM_SEQS") - # --no-enable-chunked-prefill / --no-enable-prefix-caching: the - # ExampleHiddenStatesConnector captures hidden states during prefill; both - # features skip recomputing cached/partial prefixes, which yields short or - # empty hidden_states. Required, not optional. - # --no-enable-flashinfer-autotune: on big NVFP4 MoE (Kimi) the flashinfer - # trtllm_fp4_block_scale_moe autotuner re-tunes on the first real serving - # step and stalls a worker past vLLM's execute-model timeout -> EngineCore - # dies with "RPC call to sample_tokens timed out" -> 500s -> trainer aborts. - # Disabling autotune keeps kernels static (and pairs with the larger - # VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS set in the example env). 
+ # --no-enable-chunked-prefill / --no-enable-prefix-caching: connector captures hidden states during prefill; both skip recomputing cached/partial prefixes, yielding short/empty hidden_states. Required. + # --no-enable-flashinfer-autotune: on NVFP4 MoE the autotuner re-tunes on the first serving step and stalls a worker past vLLM's execute-model timeout, killing EngineCore. "${gpu_env[@]}" vllm serve "$HF_MODEL_CKPT" \ --host "$bind_host" \ --port "$SERVE_PORT" \ @@ -218,36 +202,63 @@ wait_vllm_ready() { # Run the trainer then export the HF checkpoint. # $1 = streaming server base URL $2 = CUDA_VISIBLE_DEVICES ("" -> all) -# dataloader_num_workers must be 0: the streaming dataset owns one asyncio loop -# per process; multiple workers would duplicate requests against the server. +# DataLoader workers = in-flight fetches per rank; keep modest so (ranks x workers) stays near the serve's max_num_seqs. run_trainer_and_export() { local url="$1" cvd="$2" - echo "Launching trainer (server=${url}, CUDA_VISIBLE_DEVICES=${cvd:-all})..." - # Empty cvd -> use all GPUs on the node (don't set the var; "" would hide all). + # Optional multi-node trainer routing (see dispatch). Defaults: 1 node, no --num_nodes, export on rank 0. + local num_tnodes="${3:-1}" head_ip="${4:-}" mrank="${5:-0}" + echo "Launching trainer (server=${url}, CUDA_VISIBLE_DEVICES=${cvd:-all}, trainer_nodes=${num_tnodes}, machine_rank=${mrank})..." + # Empty cvd -> all GPUs (don't set the var; "" hides all). local -a gpu_env=() [ -n "$cvd" ] && gpu_env=(env "CUDA_VISIBLE_DEVICES=$cvd") + # accelerate multi-node routing only when >1 trainer node. 
+ local -a mn_args=() + if [ "${num_tnodes}" -gt 1 ]; then + mn_args=(--num_nodes "$num_tnodes" --head_node_ip "$head_ip" --machine_rank "$mrank") + fi "${gpu_env[@]}" bash modules/Model-Optimizer/examples/speculative_decoding/launch_train.sh \ "${SCRIPT_ARGS[@]}" \ + "${mn_args[@]}" \ data.streaming_server_url="$url" \ data.streaming_model_name="$HF_MODEL_CKPT" \ data.streaming_shared_storage_path="$SERVE_SCRATCH" \ - training.dataloader_num_workers=0 || { echo "ERROR: trainer failed." >&2; return 1; } + training.dataloader_num_workers="${STREAMING_NUM_WORKERS:-4}" \ + || { echo "ERROR: trainer failed." >&2; return 1; } + # Export only on the head trainer (machine_rank 0); non-head nodes would race the same export dir. Export reads training.output_dir, not the serve. + if [ "${mrank}" -ne 0 ]; then + echo "machine_rank=${mrank}: training done, skipping export (head trainer handles it)." + return 0 + fi + + # Derive checkpoint dir from the forwarded training.output_dir= dotlist (EAGLE default) + # so EAGLE/DFlash runs each export their own dir. EXPORT_EXTRA_ARGS lets DFlash on a + # custom-modeling base (e.g. Kimi) pass --trust_remote_code; empty by default. + local out_dir + out_dir=$(printf '%s\n' "${SCRIPT_ARGS[@]}" | sed -n 's/^training\.output_dir=//p' | tail -1) + # Fail loud rather than guess a default: a wrong dir would silently export the + # wrong checkpoint. Every streaming yaml already forwards training.output_dir=. + if [ -z "$out_dir" ]; then + echo "ERROR: no training.output_dir= forwarded in SCRIPT_ARGS; cannot locate checkpoint to export." 
>&2 + return 1 + fi python3 modules/Model-Optimizer/examples/speculative_decoding/scripts/export_hf_checkpoint.py \ - --model_path /scratchspace/eagle3 \ - --export_path /scratchspace/export + --model_path "$out_dir" \ + --export_path "${EXPORT_PATH:-/scratchspace/export}" \ + ${EXPORT_EXTRA_ARGS:-} } -# --------------------------------------------------------------------------- -# Topology dispatch (driven by the Slurm allocation, i.e. the yaml `nodes:`): -# SLURM_NNODES == 1 -> co-located: vllm on $SERVE_GPU, trainer on the rest. -# SLURM_NNODES >= 2 -> split: node 0 serves on all its GPUs, node 1 trains on -# all its GPUs; they rendezvous via /scratchspace. -# nemo_run runs this script once per node, so we branch on $SLURM_NODEID. -# --------------------------------------------------------------------------- +# Topology dispatch (see header): branch on $SLURM_NNODES / $SLURM_NODEID. NNODES="${SLURM_NNODES:-1}" NODEID="${SLURM_NODEID:-0}" +# Need >=1 trainer node: with SERVE_NODES >= NNODES every node takes the serve branch, +# so nobody publishes the rendezvous/DONE_FILE and serve nodes block forever. +if [ "$NNODES" -gt 1 ] && [ "$SERVE_NODES" -ge "$NNODES" ]; then + echo "ERROR: SERVE_NODES ($SERVE_NODES) must be < SLURM_NNODES ($NNODES); need >=1 trainer node." 
>&2 + exit 1 +fi + if [ "$NNODES" -le 1 ]; then # ----------------------------- single node ----------------------------- SERVE_HOST="${SERVE_HOST:-127.0.0.1}" @@ -272,56 +283,75 @@ PY wait_vllm_ready "http://${SERVE_HOST}:${SERVE_PORT}" || exit 1 run_trainer_and_export "http://${SERVE_HOST}:${SERVE_PORT}" "$TRAIN_GPUS" || exit 1 -elif [ "$NODEID" -eq 0 ]; then - # ----------------------- multi-node: serve node ------------------------ - SERVE_GPU_MEM_UTIL="${SERVE_GPU_MEM_UTIL:-0.9}" # dedicated node -> use most of it - SERVE_TP="${SERVE_TP:-$(gpus_on_node)}" # default: all GPUs on this node - rm -f "$SERVE_ADDR_FILE" "$DONE_FILE" # clear stale rendezvous state +elif [ "$NODEID" -lt "$SERVE_NODES" ]; then + # ---------------------- multi-node: serve node(s) ---------------------- + # Each runs a whole-node vllm serve replica and publishes ${SERVE_ADDR_FILE}.${NODEID}. + SERVE_GPU_MEM_UTIL="${SERVE_GPU_MEM_UTIL:-0.9}" # dedicated node -> use most of it + SERVE_TP="${SERVE_TP:-$(gpus_on_node)}" # default: all GPUs on this node + SERVE_LOG="/scratchspace/vllm_serve.${NODEID}.log" # per-node log (avoid collision) + rm -f "${SERVE_ADDR_FILE}.${NODEID}" # clear own stale address + [ "$NODEID" -eq 0 ] && rm -f "$DONE_FILE" # node 0 clears the shared sentinel once trap cleanup INT TERM EXIT launch_vllm "0.0.0.0" "$SERVE_TP" "" wait_vllm_ready "http://127.0.0.1:${SERVE_PORT}" || exit 1 - # Publish a *routable* address for the trainer node. `hostname -I` can list a - # link-local (169.254.x) or loopback address first, which is unreachable from - # the other node, so resolve the Slurm node name and fall back to the first - # non-link-local / non-loopback IP. 
- serve_addr="${SERVE_ADVERTISE_IP:-}" - if [ -z "$serve_addr" ]; then - serve_addr=$(getent hosts "${SLURMD_NODENAME:-$(hostname)}" 2>/dev/null | awk '{print $1}' | head -1) - fi - if [ -z "$serve_addr" ]; then - serve_addr=$(hostname -I | tr ' ' '\n' | grep -vE '^(127\.|169\.254\.|fe80:|::1)' | head -1) - fi - [ -z "$serve_addr" ] && serve_addr=$(hostname -I | awk '{print $1}') - echo "$serve_addr" > "$SERVE_ADDR_FILE" - echo "Serve node published ${serve_addr}; holding the server up until the trainer signals done..." + serve_addr=$(resolve_routable_ip "${SERVE_ADVERTISE_IP:-}") + echo "$serve_addr" > "${SERVE_ADDR_FILE}.${NODEID}" + echo "Serve node ${NODEID}/${SERVE_NODES} published ${serve_addr}; holding up until training signals done..." while [ ! -f "$DONE_FILE" ]; do sleep 10; done - echo "Training-done sentinel seen; serve node exiting (EXIT trap stops vllm)." + echo "Training-done sentinel seen; serve node ${NODEID} exiting (EXIT trap stops vllm)." -elif [ "$NODEID" -eq 1 ]; then - # ---------------------- multi-node: trainer node ----------------------- - # Release the serve node on any exit (success or failure) so it doesn't hang. - trap 'touch "$DONE_FILE" 2>/dev/null || true' EXIT +else + # -------------------- multi-node: trainer node(s) ---------------------- + # Trainer nodes SERVE_NODES..NNODES-1 -> 0-based accelerate machine ranks. + NUM_TRAINER_NODES=$(( NNODES - SERVE_NODES )) + TRAINER_RANK=$(( NODEID - SERVE_NODES )) + TRAINER_ADDR_FILE="/scratchspace/.trainer_addr.${RUN_ID}" # per-job (see RUN_ID) - echo "Trainer node waiting (up to ${SERVE_READY_TIMEOUT}s) for the serve address..." - for ((i = 0; i < SERVE_READY_TIMEOUT; i++)); do - [ -f "$SERVE_ADDR_FILE" ] && break - sleep 1 - done - [ -f "$SERVE_ADDR_FILE" ] || { echo "ERROR: serve node never published its address." 
>&2; exit 1; } - URL="http://$(cat "$SERVE_ADDR_FILE"):${SERVE_PORT}" + # Only head trainer (rank 0) signals serves to release on exit; a non-head node + # exiting first must NOT tear them down early. + if [ "$TRAINER_RANK" -eq 0 ]; then + trap 'touch "$DONE_FILE" 2>/dev/null || true' EXIT + rm -f "$TRAINER_ADDR_FILE" # clear stale rendezvous state + fi - wait_vllm_ready "$URL" || exit 1 - run_trainer_and_export "$URL" "" || exit 1 + # Collect serve addresses into the comma-joined URL list the dataset round-robins across. + echo "Trainer node (rank ${TRAINER_RANK}/${NUM_TRAINER_NODES}) waiting for ${SERVE_NODES} serve address(es)..." + URLS="" + for ((s = 0; s < SERVE_NODES; s++)); do + af="${SERVE_ADDR_FILE}.${s}" + for ((i = 0; i < SERVE_READY_TIMEOUT; i++)); do + [ -f "$af" ] && break + sleep 1 + done + [ -f "$af" ] || { echo "ERROR: serve node ${s} never published its address." >&2; exit 1; } + surl="http://$(cat "$af"):${SERVE_PORT}" + wait_vllm_ready "$surl" || exit 1 + URLS="${URLS:+$URLS,}$surl" + done + echo "Trainer rank ${TRAINER_RANK} using serve URLs: ${URLS}" -else - # ------------- multi-node: extra nodes (unused by default) ------------- - echo "Node rank ${NODEID} idle: the default split uses node 0 = vllm serve, node 1 = trainer." - echo "Multi-node *training* (>1 trainer node) is not wired up yet; allocate exactly 2 nodes." - while [ ! -f "$DONE_FILE" ]; do sleep 10; done + if [ "$NUM_TRAINER_NODES" -le 1 ]; then + # 1 trainer node: single-node DDP. + run_trainer_and_export "$URLS" "" || exit 1 + else + # >1 trainer node: head publishes its routable IP for accelerate rendezvous (29500); all read and join. + if [ "$TRAINER_RANK" -eq 0 ]; then + head_addr=$(resolve_routable_ip "${TRAINER_ADVERTISE_IP:-}") + echo "$head_addr" > "$TRAINER_ADDR_FILE" + echo "Head trainer (rank 0) published ${head_addr} for accelerate rendezvous." + else + echo "Trainer rank ${TRAINER_RANK} waiting for head-trainer address..." 
+ for ((i = 0; i < SERVE_READY_TIMEOUT; i++)); do + [ -f "$TRAINER_ADDR_FILE" ] && break + sleep 1 + done + [ -f "$TRAINER_ADDR_FILE" ] || { echo "ERROR: head trainer never published its address." >&2; exit 1; } + fi + HEAD_IP=$(cat "$TRAINER_ADDR_FILE") + run_trainer_and_export "$URLS" "" "$NUM_TRAINER_NODES" "$HEAD_IP" "$TRAINER_RANK" || exit 1 + fi fi ################################################################################################### - -#exit_handler $0 diff --git a/tools/launcher/core.py b/tools/launcher/core.py index aa60bbad9e9..9154c1427bc 100644 --- a/tools/launcher/core.py +++ b/tools/launcher/core.py @@ -257,8 +257,8 @@ def build_slurm_executor( # use a LocalTunnel: nemo_run then runs sbatch and copies artifacts via local # subprocess/shutil instead of ssh+rsync. This avoids flaky/hanging ssh-to- # localhost (e.g. MaxStartups throttling on a shared login node, or clusters - # like HSG that are only reachable through an sss proxy so paramiko can't - # tunnel in from outside). For real remote hosts, keep the SSHTunnel. + # only reachable through a login proxy so paramiko can't tunnel in from + # outside). For real remote hosts, keep the SSHTunnel. if slurm_config.host in ("localhost", "127.0.0.1"): tunnel = run.LocalTunnel(job_dir=job_dir) else: @@ -270,6 +270,15 @@ def build_slurm_executor( identity=identity, ) + # --segment=: pin all nodes into one topology block (one NVL72 / NVLink domain). + # getattr (not attribute access) keeps older/custom SlurmConfig types patched in via + # set_slurm_config_type that predate the `segment` field from raising AttributeError. + # None -> omit the kwarg entirely so the scheduler places freely (default behavior). 
+ optional_kwargs = {} + segment = getattr(slurm_config, "segment", None) + if segment is not None: + optional_kwargs["segment"] = segment + executor = run.SlurmExecutor( account=slurm_config.account, partition=slurm_config.partition, @@ -286,6 +295,7 @@ def build_slurm_executor( retries=0, packager=packager, srun_args=slurm_config.srun_args, + **optional_kwargs, ) return executor diff --git a/tools/launcher/examples/Qwen/Qwen3-8B/hf_streaming_eagle3.yaml b/tools/launcher/examples/Qwen/Qwen3-8B/hf_streaming_eagle3.yaml index d93525632f3..d46e0eee68b 100644 --- a/tools/launcher/examples/Qwen/Qwen3-8B/hf_streaming_eagle3.yaml +++ b/tools/launcher/examples/Qwen/Qwen3-8B/hf_streaming_eagle3.yaml @@ -2,14 +2,11 @@ # # 3-step pipeline: # task_0: Build input conversations (jsonl) -# task_1: Streaming train — vllm serve + trainer; hidden states are fetched -# per sample over HTTP (no on-disk dump) -# task_2: Benchmark — evaluate speculative decoding speedup via VLLM +# task_1: Streaming train — vllm serve + trainer; hidden states fetched over HTTP +# task_2: Benchmark — speculative decoding speedup via VLLM # -# task_1 here uses the multi-node split (nodes=2): node 0 runs vllm serve, node 1 -# runs the trainer; they rendezvous via the shared /scratchspace mount. (Set -# nodes=1 to co-locate both on one node instead.) All tasks share /scratchspace -# to pass artifacts between steps. +# task_1 uses nodes=2: node 0 runs vllm serve, node 1 the trainer. Tasks share +# /scratchspace to pass artifacts. 
# # Usage: # uv run launch.py --yaml examples/Qwen/Qwen3-8B/hf_streaming_eagle3.yaml --yes @@ -23,7 +20,6 @@ pipeline: global_vars: hf_model: /hf-local/Qwen/Qwen3-8B - # Step 1: Build input conversations task_0: script: common/eagle3/make_dataset.sh args: @@ -36,11 +32,8 @@ pipeline: gpus_per_node: 1 container: nvcr.io/nvidia/tensorrt-llm/release:1.3.0rc10 - # Step 2: Streaming EAGLE3 training - # - # Qwen3-8B has 36 hidden layers; default_eagle_aux_layer_ids(36) = [1, 17, 32]; - # vllm capture ids are those shifted by +1, plus the final layer: - # [2, 18, 33] + [36] = [2, 18, 33, 36]. + # capture ids = default_eagle_aux_layer_ids(36)=[1,17,32] shifted +1, plus final + # layer 36 -> [2,18,33,36]. task_1: script: common/eagle3/train_eagle_streaming.sh args: @@ -48,7 +41,6 @@ pipeline: - model.model_name_or_path=<> - data.mode=streaming - data.data_path=/scratchspace/data/train.jsonl - - data.streaming_prefetch=64 - training.output_dir=/scratchspace/eagle3 - training.training_seq_len=4096 - training.disable_tqdm=true @@ -57,8 +49,7 @@ pipeline: - eagle.eagle_use_torch_compile=false environment: - HF_MODEL_CKPT: <> - # No spaces: nemo_run emits `export FOO=value` without quotes, so a - # space-separated value would be split by the shell. + # No spaces: nemo_run emits unquoted `export FOO=value`, so spaces would split. 
- EAGLE_CAPTURE_IDS: "[2,18,33,36]" - SERVE_TP: "1" slurm_config: @@ -68,7 +59,6 @@ pipeline: gpus_per_node: 1 container: vllm/vllm-openai:latest - # Step 3: Benchmark speculative decoding (VLLM backend) task_2: script: common/specdec_bench/quick_check.sh args: diff --git a/tools/launcher/examples/Qwen/Qwen3-8B/hf_streaming_eagle3_multi_node.yaml b/tools/launcher/examples/Qwen/Qwen3-8B/hf_streaming_eagle3_multi_node.yaml new file mode 100644 index 00000000000..3751ecbe96a --- /dev/null +++ b/tools/launcher/examples/Qwen/Qwen3-8B/hf_streaming_eagle3_multi_node.yaml @@ -0,0 +1,87 @@ +# EAGLE3 streaming speculative decoding pipeline for Qwen3-8B — MULTI-NODE. +# +# task_1 splits N nodes into K serve replicas + (N-K) DDP trainers via SERVE_NODES; +# see common/eagle3/train_eagle_streaming.sh for dispatch, rendezvous, and sharding. +# +# 3-step pipeline: +# task_0: Build input conversations (jsonl) +# task_1: Streaming train — 2 serve nodes (2 GPU, TP=2) + 2 trainer nodes (2 GPU) +# task_2: Benchmark — evaluate speculative decoding speedup via VLLM +# +# Usage: +# uv run launch.py --yaml examples/Qwen/Qwen3-8B/hf_streaming_eagle3_multi_node.yaml --yes + +job_name: Qwen3-8B_EAGLE3_streaming_multi_node +pipeline: + allow_to_fail: false + skip: false + note: + + global_vars: + hf_model: /hf-local/Qwen/Qwen3-8B + + # Step 1: Build input conversations + task_0: + script: common/eagle3/make_dataset.sh + args: + - -f modules/Model-Optimizer/examples/dataset/example_data_config.yaml + - --full-conversations + slurm_config: + _factory_: "slurm_factory" + nodes: 1 + ntasks_per_node: 1 + gpus_per_node: 1 + container: nvcr.io/nvidia/tensorrt-llm/release:1.3.0rc10 + + # Step 2: Streaming EAGLE3 training — 2 serve replicas (TP=2) + 2 trainer nodes (2 GPU each). + # Capture ids: default_eagle_aux_layer_ids(36)=[1,17,32] +1, plus final layer 36. 
+ task_1: + script: common/eagle3/train_eagle_streaming.sh + args: + - --config modules/Model-Optimizer/modelopt_recipes/general/speculative_decoding/eagle3.yaml + - model.model_name_or_path=<> + - data.mode=streaming + - data.data_path=/scratchspace/data/train.jsonl + - training.output_dir=/scratchspace/eagle3 + - training.training_seq_len=4096 + - training.disable_tqdm=true + - training.ar_validate_steps=500000 + - training.num_train_epochs=1 + - eagle.eagle_use_torch_compile=false + environment: + - HF_MODEL_CKPT: <> + # No spaces: nemo_run emits `export FOO=value` unquoted. + - EAGLE_CAPTURE_IDS: "[2,18,33,36]" + - SERVE_TP: "2" + # K serve replica nodes (Slurm nodes 0..K-1); the rest are trainers. + - SERVE_NODES: "2" + # Per-rank in-flight fetches (forwarded as training.dataloader_num_workers); keep modest so (ranks x workers) doesn't flood the serve replicas. + - STREAMING_NUM_WORKERS: "4" + slurm_config: + _factory_: "slurm_factory" + nodes: 4 + ntasks_per_node: 1 + gpus_per_node: 2 + container: vllm/vllm-openai:latest + + # Step 3: Benchmark speculative decoding (VLLM backend) + task_2: + script: common/specdec_bench/quick_check.sh + args: + - --draft_model_dir /scratchspace/export + - --draft_length 3 + - --output_length 4096 + - --engine VLLM + - --tp_size 1 + - --ep_size 1 + - --speculative_algorithm EAGLE3 + - --mtbench /hf-local/HuggingFaceH4/mt_bench_prompts/raw/question.jsonl + - --concurrency 1 + environment: + - HF_MODEL_CKPT: <> + slurm_config: + _factory_: "slurm_factory" + nodes: 1 + ntasks_per_node: 1 + gpus_per_node: 1 + container: vllm/vllm-openai:latest diff --git a/tools/launcher/examples/moonshotai/Kimi-K2.5/hf_dflash_dryrun.yaml b/tools/launcher/examples/moonshotai/Kimi-K2.5/hf_dflash_dryrun.yaml new file mode 100644 index 00000000000..5cb467b3f6a --- /dev/null +++ b/tools/launcher/examples/moonshotai/Kimi-K2.5/hf_dflash_dryrun.yaml @@ -0,0 +1,39 @@ +# DFlash dry-run smoke test for Kimi-K2.5 (NVFP4): exercises the full +# convert->save->export path
WITHOUT training, to validate launcher/export +# plumbing and downstream loaders. Exported draft has untrained weights. +# +# Usage: +# uv run launch.py --yaml examples/moonshotai/Kimi-K2.5/hf_dflash_dryrun.yaml --yes + +job_name: Kimi-K2.5_DFlash_dryrun +pipeline: + allow_to_fail: false + skip: false + note: + + global_vars: + hf_model: /hf-local/nvidia/Kimi-K2.5-NVFP4/ + + # Convert -> save -> export (no training). + task_0: + script: common/specdec/dflash_online_training.sh + args: + # Skips trainer.train(), saves the untrained checkpoint right after convert. + - --dry_run + - --config modules/Model-Optimizer/modelopt_recipes/general/speculative_decoding/dflash.yaml + - model.model_name_or_path=<> + # FakeBaseModel (embed + lm_head only) so the base fits one GPU; never read in dry-run. + - model.use_fake_base_for_offline=true + - model.trust_remote_code=true + # An offline path forces mode=offline; value unused in dry-run. + - data.offline_data_path=/tmp/dryrun-placeholder + - training.output_dir=/scratchspace/dflash + - training.disable_tqdm=true + # Kimi has no dedicated mask token; 163838 is a reserved slot used as the mask. + - dflash.dflash_mask_token_id=163838 + slurm_config: + _factory_: "slurm_factory" + nodes: 1 + ntasks_per_node: 1 + gpus_per_node: 1 + container: nvcr.io/nvidia/tensorrt-llm/release:1.3.0rc10 diff --git a/tools/launcher/examples/moonshotai/Kimi-K2.5/hf_streaming_dflash.yaml b/tools/launcher/examples/moonshotai/Kimi-K2.5/hf_streaming_dflash.yaml new file mode 100644 index 00000000000..8f82b1919b1 --- /dev/null +++ b/tools/launcher/examples/moonshotai/Kimi-K2.5/hf_streaming_dflash.yaml @@ -0,0 +1,97 @@ +# DFlash streaming speculative-decoding training for Kimi-K2.5-NVFP4 on +# GB200/Blackwell: node 0 = vllm serve (TP=4, whole node), node 1 = DFlash +# trainer. See common/eagle3/train_eagle_streaming.sh header for the mechanism. +# +# Requires GB200: native NVFP4 + 192 GB/GPU fits ~551 GB Kimi at TP=4 on one node. 
+# +# data.mode=streaming sets dflash_offline so the DFlash module consumes streamed +# hidden states instead of running the fake base. +# Capture ids = [2,16,31,45,59,60] (kimi_k25/deepseek_v3, 61 layers): 5 DFlash +# target layers + base 60. n_captured = num_target_layers + 1. +# +# answer_only_loss=true: Kimi ships only a slow tokenizer, so it can't derive the +# assistant mask the standard way (return_assistant_tokens_mask needs a fast +# tokenizer's char_to_token). The mask is instead recovered from token ids by the +# registered model-specific recovery in modelopt.torch.utils.loss_mask. +# +# Run ON the cluster login node (paramiko can't reach the cluster through its login proxy): +# export SLURM_HOST=localhost SLURM_ACCOUNT= \ +# SLURM_PARTITION=batch \ +# SLURM_HF_LOCAL= \ +# SLURM_JOB_DIR= \ +# NEMORUN_HOME=$PWD +# uv run launch.py --yaml examples/moonshotai/Kimi-K2.5/hf_streaming_dflash.yaml \ +# identity=$HOME/.ssh/id_ecdsa detach=True --yes +# +# The export lands in /scratchspace/export. To benchmark it, point +# specdec_bench.yaml's --draft_model_dir there (or copy it under /hf-local). + +job_name: Kimi-K2.5-NVFP4_DFlash_streaming +pipeline: + allow_to_fail: false + skip: false + note: + + global_vars: + hf_model: /hf-local/nvidia/Kimi-K2.5-NVFP4 + + task_0: + script: common/eagle3/make_dataset.sh + args: + - -f modules/Model-Optimizer/examples/dataset/example_data_config.yaml + - --full-conversations + slurm_config: + _factory_: "slurm_factory" + nodes: 1 + ntasks_per_node: 1 + # The cluster QOS requires whole-node GPU allocation though make_dataset is CPU-only. 
+ gpus_per_node: 4 + container: nvcr.io/nvidia/tensorrt-llm/release:1.3.0rc10 + + task_1: + script: common/eagle3/train_eagle_streaming.sh + args: + - --config modules/Model-Optimizer/modelopt_recipes/general/speculative_decoding/dflash.yaml + - model.model_name_or_path=<> + - model.use_fake_base_for_offline=true + - model.trust_remote_code=true + - data.mode=streaming + - data.data_path=/scratchspace/data/train.jsonl + - training.output_dir=/scratchspace/dflash + # Must be divisible by dflash_block_size (8). + - training.training_seq_len=4096 + - training.disable_tqdm=true + - training.num_train_epochs=1 + - training.max_steps=3000 + # Assistant mask recovered from token ids for Kimi's slow tokenizer (see header). + - training.answer_only_loss=true + # vLLM container has no tensorboard (dflash.yaml's default report_to); disable. + - training.report_to=none + # Kimi-K2.5 has no dedicated mask token; 163838 is a reserved slot used as one. + - dflash.dflash_mask_token_id=163838 + environment: + - HF_MODEL_CKPT: <> + # No spaces in values: nemo_run emits `export FOO=value` unquoted. + - EAGLE_CAPTURE_IDS: "[2,16,31,45,59,60]" + - SERVE_TP: "4" + # Per-rank in-flight fetches; keep low so the cold NVFP4-MoE serve isn't flooded past its execute-model timeout (kills EngineCore). + - STREAMING_NUM_WORKERS: "1" + # DFlash on a custom-modeling base (Kimi) needs --trust_remote_code at export. + - EXPORT_EXTRA_ARGS: "--trust_remote_code" + # Cap context to the train seq len; the model's native 262144 KV-cache OOMs at TP=4. + - SERVE_MAX_MODEL_LEN: "4096" + - SERVE_MAX_NUM_SEQS: "4" + - SERVE_GPU_MEM_UTIL: "0.8" + - SERVE_READY_TIMEOUT: "2400" + - SERVE_EXTRA_ARGS: "--trust-remote-code" + # Cold NVFP4-MoE kernels stall the first serving step past vLLM's default execute-model timeout; raise it (seconds). 
+ - VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS: "1200" + - VLLM_ENGINE_ITERATION_TIMEOUT_S: "1200" + slurm_config: + _factory_: "slurm_factory" + nodes: 2 + # Pin nodes into one NVL72 block (latency/locality; inter-node is HTTP + lustre, not NCCL). + segment: 2 + ntasks_per_node: 1 + gpus_per_node: 4 + container: vllm/vllm-openai:latest diff --git a/tools/launcher/examples/moonshotai/Kimi-K2.5/hf_streaming_dflash_multi_node.yaml b/tools/launcher/examples/moonshotai/Kimi-K2.5/hf_streaming_dflash_multi_node.yaml new file mode 100644 index 00000000000..6b70e94d262 --- /dev/null +++ b/tools/launcher/examples/moonshotai/Kimi-K2.5/hf_streaming_dflash_multi_node.yaml @@ -0,0 +1,98 @@ +# DFlash streaming speculative-decoding training for Kimi-K2.5-NVFP4 on +# GB200/Blackwell — MULTI-NODE: both serve and trainer sides scale out. +# nodes=N, SERVE_NODES=K -> K serve replicas (TP=4, whole node) + (N-K) trainer +# nodes. See common/eagle3/train_eagle_streaming.sh for dispatch/sharding/scaling. +# +# Requires GB200: native NVFP4 + 192 GB/GPU fits ~551 GB Kimi-K2.5-NVFP4 at TP=4 +# on one 4-GPU node, so each serve replica owns a whole node. +# +# Capture ids: build_target_layer_ids(num_orig=61, num_draft=5)=[1,15,30,44,58] +# -> +1 for embedding = [2,16,31,45,59], append base 60 (final layer uncapturable). +# 6 captured = 5 aux layers (matching the 5-layer DFlash draft block) + 1 base. +# +# Run ON the cluster login node (paramiko can't reach the cluster through its login proxy): +# export SLURM_HOST=localhost SLURM_ACCOUNT= \ +# SLURM_PARTITION=batch \ +# SLURM_HF_LOCAL= \ +# SLURM_JOB_DIR= \ +# NEMORUN_HOME=$PWD +# uv run launch.py --yaml examples/moonshotai/Kimi-K2.5/hf_streaming_dflash_multi_node.yaml \ +# identity=$HOME/.ssh/id_ecdsa detach=True --yes +# +# The export lands in /scratchspace/export. To benchmark it, point +# specdec_bench.yaml's --draft_model_dir there (or copy it under /hf-local). 
+ +job_name: Kimi-K2.5-NVFP4_DFlash_streaming_multi_node +pipeline: + allow_to_fail: false + skip: false + note: + + global_vars: + hf_model: /hf-local/nvidia/Kimi-K2.5-NVFP4 + + # Build input conversations. + task_0: + script: common/eagle3/make_dataset.sh + args: + - -f modules/Model-Optimizer/examples/dataset/example_data_config.yaml + - --full-conversations + slurm_config: + _factory_: "slurm_factory" + nodes: 1 + ntasks_per_node: 1 + # The cluster QOS requires whole-node GPU alloc even though make_dataset is CPU-only. + gpus_per_node: 4 + container: nvcr.io/nvidia/tensorrt-llm/release:1.3.0rc10 + + # Streaming DFlash training: 2 serve replicas (TP=4) + 2 trainer nodes. + task_1: + script: common/eagle3/train_eagle_streaming.sh + args: + - --config modules/Model-Optimizer/modelopt_recipes/general/speculative_decoding/dflash.yaml + - model.model_name_or_path=<> + - model.use_fake_base_for_offline=true + - model.trust_remote_code=true + - data.mode=streaming + - data.data_path=/scratchspace/data/train.jsonl + - training.output_dir=/scratchspace/dflash + # Must be divisible by dflash_block_size (8). + - training.training_seq_len=4096 + - training.disable_tqdm=true + - training.ar_validate_steps=500000 + - training.num_train_epochs=1 + - training.max_steps=500 + # Kimi's slow tokenizer can't emit assistant masks the standard way; the mask + # is recovered from token ids (modelopt.torch.utils.loss_mask). + - training.answer_only_loss=true + # vLLM container has no tensorboard (dflash.yaml's default) -> init crash. + - training.report_to=none + # Kimi has no dedicated mask token; 163838 is a reserved slot used as the mask. + - dflash.dflash_mask_token_id=163838 + environment: + - HF_MODEL_CKPT: <> + # See header for derivation. + - EAGLE_CAPTURE_IDS: "[2,16,31,45,59,60]" + - SERVE_NODES: "2" + - SERVE_TP: "4" + # Per-rank in-flight fetches; keep low so the cold NVFP4-MoE serve isn't flooded past its execute-model timeout (kills EngineCore). 
+ - STREAMING_NUM_WORKERS: "1" + # Kimi's custom-modeling base needs --trust_remote_code at export. + - EXPORT_EXTRA_ARGS: "--trust_remote_code" + # Cap context to the train seq len; the model's native 262144 KV-cache OOMs at TP=4. + - SERVE_MAX_MODEL_LEN: "4096" + - SERVE_MAX_NUM_SEQS: "4" + - SERVE_GPU_MEM_UTIL: "0.8" + - SERVE_READY_TIMEOUT: "2400" + - SERVE_EXTRA_ARGS: "--trust-remote-code" + # Cold NVFP4-MoE kernels stall the first serving step past vLLM's default execute-model timeout; raise it (seconds). + - VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS: "1200" + - VLLM_ENGINE_ITERATION_TIMEOUT_S: "1200" + slurm_config: + _factory_: "slurm_factory" + nodes: 4 + # Pin nodes into one NVL72 block (essential for cross-node trainer DDP). + segment: 4 + ntasks_per_node: 1 + gpus_per_node: 4 + container: vllm/vllm-openai:latest diff --git a/tools/launcher/examples/moonshotai/Kimi-K2.5/hf_streaming_eagle3.yaml b/tools/launcher/examples/moonshotai/Kimi-K2.5/hf_streaming_eagle3.yaml index 24487ab8621..5ace5e83847 100644 --- a/tools/launcher/examples/moonshotai/Kimi-K2.5/hf_streaming_eagle3.yaml +++ b/tools/launcher/examples/moonshotai/Kimi-K2.5/hf_streaming_eagle3.yaml @@ -1,22 +1,16 @@ -# EAGLE3 streaming speculative decoding for Kimi-K2.5-NVFP4 on GB200/Blackwell -# (HSG). This is the streaming config that actually runs end-to-end: on CW H100 -# the ~551 GB model needed cpu-offload (-> ~1 tok/s -> vLLM EngineCore -# TimeoutError), so the working path is GB200. +# EAGLE3 streaming speculative decoding for Kimi-K2.5-NVFP4 on GB200/Blackwell. # -# Why GB200: nodes have only 4 GPUs each (vs CW's 8), but 192 GB/GPU and native -# NVFP4. Kimi-K2.5-NVFP4 (~551 GB) fits at TP=4 on ONE node (4 x 192 = 768 GB, -# ~138 GB/GPU of weights) with NO cpu-offload. So here: node 0 = vllm serve -# (TP=4, whole node), node 1 = EAGLE3 trainer (fake base), 4 GPUs each, 2 nodes. +# Requires GB200: native NVFP4 + 192 GB/GPU fits the ~551 GB model at TP=4 on one node. 
+# node 0 = vllm serve (TP=4), node 1 = EAGLE3 trainer (fake base); 4 GPUs each. # -# Capture ids: kimi_k25 (deepseek_v3 arch), 61 layers. aux states are indexed -# by layer INPUT (0..60); the final layer is NOT capturable, so the base is 60. -# captured = [2,30,58] aux + [60] base = 4, matching the trainer's 3-aux+base. +# Capture ids: deepseek_v3 arch, 61 layers, indexed by layer input (0..60); +# [2,30,58] aux + [60] base (final layer not capturable). # -# Run ON the HSG login node (paramiko can't reach HSG through the sss proxy): -# export SLURM_HOST=localhost SLURM_ACCOUNT=coreai_dlalgo_modelopt \ +# Run ON the cluster login node (paramiko can't reach the cluster through its login proxy): +# export SLURM_HOST=localhost SLURM_ACCOUNT= \ # SLURM_PARTITION=batch \ -# SLURM_HF_LOCAL=/lustre/fs1/portfolios/coreai/projects/coreai_dlalgo_modelopt/hf-local \ -# SLURM_JOB_DIR=/lustre/fs1/portfolios/coreai/projects/coreai_dlalgo_modelopt/users/haoguo/experiments \ +# SLURM_HF_LOCAL= \ +# SLURM_JOB_DIR= \ # NEMORUN_HOME=$PWD # uv run launch.py --yaml examples/moonshotai/Kimi-K2.5/hf_streaming_eagle3.yaml \ # identity=$HOME/.ssh/id_ecdsa detach=True --yes @@ -40,8 +34,7 @@ pipeline: _factory_: "slurm_factory" nodes: 1 ntasks_per_node: 1 - # HSG QOS (QOSMinGRES) requires whole-node GPU allocation (4 on GB200), - # so request 4 even though make_dataset is CPU-only. + # The cluster QOS requires whole-node GPU alloc (4) even though make_dataset is CPU-only. gpus_per_node: 4 container: nvcr.io/nvidia/tensorrt-llm/release:1.3.0rc10 @@ -55,10 +48,6 @@ pipeline: - model.trust_remote_code=true - data.mode=streaming - data.data_path=/scratchspace/data/train.jsonl - # Keep concurrent in-flight requests low: a 64-wide flood made cold NVFP4 - # MoE kernels/flashinfer autotune stall a worker past vLLM's engine<->worker - # timeout, killing EngineCore (TimeoutError) mid-serve -> 500s -> trainer abort. 
- - data.streaming_prefetch=8 - training.output_dir=/scratchspace/eagle3 - training.training_seq_len=4096 - training.disable_tqdm=true @@ -68,28 +57,25 @@ pipeline: - eagle.eagle_use_torch_compile=false environment: - HF_MODEL_CKPT: <> - # No spaces in values: nemo_run emits `export FOO=value` unquoted. + # No spaces: nemo_run emits `export FOO=value` unquoted. - EAGLE_CAPTURE_IDS: "[2,30,58,60]" - SERVE_TP: "4" - # Kimi-K2.5-NVFP4 ~138 GB weights/GPU at TP=4; GB200 has 184 GB. The model's - # native max_seq_len is 262144, whose KV cache OOMs (first attempt died with - # 183/184 GB used). Cap context to the training seq len and leave headroom - # for activation spikes during the profiling forward. + # Per-rank in-flight fetches; keep low so the cold NVFP4-MoE serve isn't flooded past its execute-model timeout (kills EngineCore). + - STREAMING_NUM_WORKERS: "1" + # Cap context to the train seq len; the model's native 262144 KV-cache OOMs at TP=4. - SERVE_MAX_MODEL_LEN: "4096" - # Small batches: smaller per-step MoE compute stays under the engine timeout. - SERVE_MAX_NUM_SEQS: "4" - SERVE_GPU_MEM_UTIL: "0.8" - SERVE_READY_TIMEOUT: "2400" - SERVE_EXTRA_ARGS: "--trust-remote-code" - # The killer was "RPC call to sample_tokens timed out" — a worker stalls on - # the first real serving step (cold NVFP4 MoE kernels) past vLLM's default - # execute-model timeout, so EngineCore dies. Extend the timeouts that govern - # that path (seconds). VLLM_RPC_TIMEOUT (ms) is a different RPC and didn't help. + # Cold NVFP4-MoE kernels stall the first serving step past vLLM's default execute-model timeout; raise it (seconds). - VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS: "1200" - VLLM_ENGINE_ITERATION_TIMEOUT_S: "1200" slurm_config: _factory_: "slurm_factory" nodes: 2 + # Pin nodes into one NVL72 block (latency nicety here; essential when trainers do cross-node DDP). 
+ segment: 2 ntasks_per_node: 1 gpus_per_node: 4 container: vllm/vllm-openai:latest diff --git a/tools/launcher/examples/moonshotai/Kimi-K2.5/hf_streaming_eagle3_multi_node.yaml b/tools/launcher/examples/moonshotai/Kimi-K2.5/hf_streaming_eagle3_multi_node.yaml new file mode 100644 index 00000000000..e57c78f3cc1 --- /dev/null +++ b/tools/launcher/examples/moonshotai/Kimi-K2.5/hf_streaming_eagle3_multi_node.yaml @@ -0,0 +1,103 @@ +# EAGLE3 streaming speculative decoding for Kimi-K2.5-NVFP4 on GB200/Blackwell +# MULTI-NODE: K serve replicas (TP=4, whole node) + (N-K) DDP trainer nodes. +# This file: nodes=4, SERVE_NODES=2 -> 2 serve + 2 trainer. See dispatch/scaling in +# common/eagle3/train_eagle_streaming.sh header. +# +# Requires GB200: native NVFP4 + 192 GB/GPU fits ~551 GB Kimi at TP=4 on one node. +# Capture ids = [2,30,58] aux + [60] base = 4 (kimi_k25/deepseek_v3, 61 layers; +# layer 60 is the last capturable, used as base). +# +# Run ON the cluster login node (paramiko can't reach the cluster through its login proxy): +# export SLURM_HOST=localhost SLURM_ACCOUNT= \ +# SLURM_PARTITION=batch \ +# SLURM_HF_LOCAL= \ +# SLURM_JOB_DIR= \ +# NEMORUN_HOME=$PWD +# uv run launch.py --yaml examples/moonshotai/Kimi-K2.5/hf_streaming_eagle3_multi_node.yaml \ +# identity=$HOME/.ssh/id_ecdsa detach=True --yes + +job_name: Kimi-K2.5-NVFP4_EAGLE3_streaming_multi_node +pipeline: + allow_to_fail: false + skip: false + note: + + global_vars: + hf_model: /hf-local/nvidia/Kimi-K2.5-NVFP4 + + task_0: + script: common/eagle3/make_dataset.sh + args: + - -f modules/Model-Optimizer/examples/dataset/example_data_config.yaml + - --full-conversations + slurm_config: + _factory_: "slurm_factory" + nodes: 1 + ntasks_per_node: 1 + # The cluster QOS requires whole-node GPU allocation even though make_dataset is CPU-only. 
+ gpus_per_node: 4 + container: nvcr.io/nvidia/tensorrt-llm/release:1.3.0rc10 + + task_1: + script: common/eagle3/train_eagle_streaming.sh + args: + - --config modules/Model-Optimizer/modelopt_recipes/general/speculative_decoding/eagle3.yaml + - model.model_name_or_path=<> + - model.use_fake_base_for_offline=true + - model.trust_remote_code=true + - data.mode=streaming + - data.data_path=/scratchspace/data/train.jsonl + - training.output_dir=/scratchspace/eagle3 + - training.training_seq_len=4096 + - training.disable_tqdm=true + - training.ar_validate_steps=500000 + - training.num_train_epochs=1 + - training.max_steps=500 + - eagle.eagle_use_torch_compile=false + environment: + - HF_MODEL_CKPT: <> + # No spaces in values: nemo_run emits `export FOO=value` unquoted. + - EAGLE_CAPTURE_IDS: "[2,30,58,60]" + - SERVE_NODES: "2" + - SERVE_TP: "4" + # Per-rank in-flight fetches; keep low so the cold NVFP4-MoE serve isn't flooded past its execute-model timeout (kills EngineCore). + - STREAMING_NUM_WORKERS: "1" + # Cap context to the train seq len; the model's native 262144 KV-cache OOMs at TP=4. + - SERVE_MAX_MODEL_LEN: "4096" + - SERVE_MAX_NUM_SEQS: "4" + - SERVE_GPU_MEM_UTIL: "0.8" + - SERVE_READY_TIMEOUT: "2400" + - SERVE_EXTRA_ARGS: "--trust-remote-code" + # Cold NVFP4-MoE kernels stall the first serving step past vLLM's default execute-model timeout; raise it (seconds). + - VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS: "1200" + - VLLM_ENGINE_ITERATION_TIMEOUT_S: "1200" + slurm_config: + _factory_: "slurm_factory" + nodes: 4 + # Pin nodes into one NVL72 block (essential for cross-node trainer DDP). 
+ segment: 4 + ntasks_per_node: 1 + gpus_per_node: 4 + container: vllm/vllm-openai:latest + + task_2: + script: common/specdec_bench/quick_check.sh + args: + - --draft_model_dir /scratchspace/export + - --draft_length 3 + - --output_length 4096 + - --engine VLLM + - --tp_size 4 + - --ep_size 1 + - --speculative_algorithm EAGLE3 + - --mtbench /hf-local/HuggingFaceH4/mt_bench_prompts/raw/question.jsonl + - --concurrency 32 + - --trust_remote_code + environment: + - HF_MODEL_CKPT: <> + slurm_config: + _factory_: "slurm_factory" + nodes: 1 + ntasks_per_node: 1 + gpus_per_node: 4 + container: vllm/vllm-openai:latest diff --git a/tools/launcher/examples/moonshotai/Kimi-K2.5/specdec_bench.yaml b/tools/launcher/examples/moonshotai/Kimi-K2.5/specdec_bench.yaml new file mode 100644 index 00000000000..7c37015d90d --- /dev/null +++ b/tools/launcher/examples/moonshotai/Kimi-K2.5/specdec_bench.yaml @@ -0,0 +1,63 @@ +# DFLASH speculative-decoding benchmark for Kimi-K2.5-NVFP4 via vLLM (in-process +# AsyncLLM, TP=4 + EP), benchmarking on MT-Bench. Outputs to /scratchspace/specdec_bench/. +# +# Requires GB200: native NVFP4 + 192 GB/GPU fits ~551 GB Kimi-K2.5-NVFP4 at TP=4 +# on one 4-GPU node. +# +# DFLASH: draft tokens default to 8 (=block_size); --draft_length does NOT apply. +# To override sampling/engine args, add `- --runtime_params ` (see +# examples/specdec_bench/README.md). +# +# NOTE on dataset: MT-Bench needs no data-prep. For SPEED-Bench instead, first run +# `prepare_data.py --dataset speed --config all`, then replace --mtbench with +# `--dataset speed` + `--dataset_path .../data/speed/`. +# +# NOTE on container: vllm/vllm-openai:latest is x86 and may lack DFLASH; on +# GB200/aarch64 use an aarch64 DFLASH-capable image (e.g. a 0511 nightly), via +# pipeline.task_0.slurm_config.container=. UNRESOLVED. 
+# +# Run ON the cluster login node (paramiko can't reach the cluster through its login proxy): +# export SLURM_HOST=localhost SLURM_ACCOUNT= \ +# SLURM_PARTITION=batch \ +# SLURM_HF_LOCAL= \ +# SLURM_JOB_DIR= \ +# NEMORUN_HOME=$PWD +# uv run launch.py --yaml examples/moonshotai/Kimi-K2.5/specdec_bench.yaml \ +# identity=$HOME/.ssh/id_ecdsa detach=True --yes + +job_name: Kimi-K2.5-NVFP4_DFLASH_specdec_bench + +pipeline: + allow_to_fail: false + skip: false + note: + + global_vars: + hf_model: /hf-local/nvidia/Kimi-K2.5-NVFP4 + # Trained+exported DFLASH draft; override: pipeline.global_vars.draft_model_dir= + draft_model_dir: /hf-local/nvidia/Kimi-K2.5-DFlash + + task_0: + script: common/specdec_bench/run.sh + args: + - --draft_model_dir <> + - --speculative_algorithm DFLASH + - --engine VLLM + - --mtbench /hf-local/HuggingFaceH4/mt_bench_prompts/raw/question.jsonl + - --tp_size 4 + - --ep_size 4 + - --concurrency 32 + - --output_length 1024 + - --trust_remote_code + - --aa_timing + - --show_progress + - --save_dir /scratchspace/specdec_bench + environment: + - HF_MODEL_CKPT: <> + - HF_LOCAL: /hf-local + slurm_config: + _factory_: "slurm_factory" + nodes: 1 + ntasks_per_node: 1 + gpus_per_node: 4 + container: vllm/vllm-openai:latest diff --git a/tools/launcher/slurm_config.py b/tools/launcher/slurm_config.py index 8ecd51f6f86..9c3c853e877 100644 --- a/tools/launcher/slurm_config.py +++ b/tools/launcher/slurm_config.py @@ -24,6 +24,8 @@ import nemo_run as run +__all__ = ["SlurmConfig", "slurm_factory"] + @dataclass class SlurmConfig: @@ -48,6 +50,11 @@ class SlurmConfig: gpus_per_node: int = 1 time: str = "04:00:00" local: bool = False + # Slurm --segment=: force the job's nodes into a single topology block. + # On a topology/block cluster (e.g. GB200 NVL72, where one block = one NVLink + # domain) set this to the node count to keep all nodes in one NVL72 so + # inter-node traffic rides NVLink. None = let the scheduler place freely. 
+ segment: Optional[int] = None @run.cli.factory @@ -68,6 +75,7 @@ def slurm_factory( srun_args: list[str] = ["--no-container-mount-home"], array: Optional[str] = None, time: str = "04:00:00", + segment: Optional[int] = None, ) -> SlurmConfig: """Generic Slurm factory — configure via environment variables or CLI overrides.""" return SlurmConfig( @@ -84,4 +92,5 @@ def slurm_factory( srun_args=srun_args, array=array, time=time, + segment=segment, )