Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions examples/specdec_bench/specdec_bench/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,10 @@ def _checkpoint_provenance(model_dir):


def _is_sensitive_key(key):
# Engine configs can carry non-string dict keys (e.g. int layer ids in a
# serving_config); those are never sensitive field *names*, so skip them.
if not isinstance(key, str):
return False
klow = key.lower()
if klow in _SENSITIVE_KEY_ALLOWLIST:
return False
Expand Down
13 changes: 8 additions & 5 deletions examples/speculative_decoding/eagle_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,6 @@ def make_speculative_data_module(
train_len=None,
answer_only_loss=False,
shift_labels=True,
seed: int = 0,
) -> dict:
"""Create data module for speculative decoding training.

Expand Down Expand Up @@ -88,14 +87,16 @@ def make_speculative_data_module(
ds = load_dataset("json", data_files=data_args.data_path, split="train")
if data_args.sample_size > 0:
ds = ds.select(range(data_args.sample_size))
# Map-style dataset: each rank fetches its own DistributedSampler shard.
# Fetch concurrency comes from the DataLoader's num_workers, not a config knob;
# shuffling/order is the sampler's job (seeded by training_args.seed).
# ``server_urls`` accepts a comma-separated string for multi-server fan-out.
streaming_cfg = EagleVllmStreamingConfig(
server_url=data_args.streaming_server_url,
server_urls=data_args.streaming_server_url,
model=data_args.streaming_model_name,
shared_storage_root=data_args.streaming_shared_storage_path,
max_seq_len=train_len,
answer_only_loss=answer_only_loss,
prefetch=data_args.streaming_prefetch,
seed=seed,
)
train_dataset = EagleVllmStreamingDataset(
entries=ds,
Expand Down Expand Up @@ -138,7 +139,9 @@ def make_speculative_data_module(
raise ValueError("sample_size must be -1 (use all samples) or a positive integer")
if data_args.sample_size > 0:
dumped_files = dumped_files[: data_args.sample_size]
train_dataset = OfflineSupervisedDataset(dumped_files, answer_only_loss=answer_only_loss)
train_dataset = OfflineSupervisedDataset(
dumped_files, answer_only_loss=answer_only_loss, tokenizer=tokenizer
)
data_collator = EagleOfflineDataCollator(train_len=train_len)

return {
Expand Down
34 changes: 21 additions & 13 deletions examples/speculative_decoding/launch_train.sh
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,8 @@
# Multi-node: ./launch_train.sh --config ../../modelopt_recipes/general/speculative_decoding/eagle3.yaml --num_nodes 2 --head_node_ip <IP>
# With overrides: ./launch_train.sh --config my.yaml model.model_name_or_path=xxx training.output_dir=yyy
#
# Extra key=value args are forwarded as OmegaConf dotlist overrides to main.py.
# All training config (model, data, hyperparams, eagle, fsdp) lives in the YAML file.
# Only multi-node routing args are passed here; mixed_precision is fixed to bf16.
# Extra key=value args are forwarded as OmegaConf dotlist overrides to main.py; all
# training config lives in the YAML. mixed_precision is fixed to bf16.

set -eo pipefail

Expand All @@ -30,12 +29,14 @@ SCRIPT_DIR="$(dirname "$(readlink -f "$0")")"
CONFIG_FILE=""
NUM_NODES=1
HEAD_NODE_IP=""
MACHINE_RANK=""
EXTRA_ARGS=()
while [ $# -gt 0 ]; do
case "$1" in
--config*) if [[ "$1" != *=* ]]; then shift; fi; CONFIG_FILE="${1#*=}" ;;
--num_nodes*) if [[ "$1" != *=* ]]; then shift; fi; NUM_NODES="${1#*=}" ;;
--head_node_ip*) if [[ "$1" != *=* ]]; then shift; fi; HEAD_NODE_IP="${1#*=}" ;;
--machine_rank*) if [[ "$1" != *=* ]]; then shift; fi; MACHINE_RANK="${1#*=}" ;;
*) EXTRA_ARGS+=("$1") ;;
esac
shift
Expand All @@ -46,7 +47,6 @@ if [ -z "$CONFIG_FILE" ]; then
exit 1
fi

# GPU count detection
if [[ "$NUM_NODES" != "1" ]]; then
GPU_PER_NODE=${GPU_PER_NODE:-$(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l)}
TOTAL_GPU=$((NUM_NODES * GPU_PER_NODE))
Expand All @@ -56,20 +56,28 @@ else
echo "Total GPUs: $TOTAL_GPU (single node)"
fi

# Multi-node routing args (accelerate only; training config comes from the YAML)
MULTI_NODE_ARGS=""
MULTI_NODE_ARGS=()
if [[ "$NUM_NODES" != "1" ]]; then
MULTI_NODE_ARGS="--num_processes $TOTAL_GPU \
--num_machines $NUM_NODES \
--machine_rank $SLURM_PROCID \
--rdzv_backend c10d \
--main_process_ip $HEAD_NODE_IP \
--main_process_port 29500"
# --multi_gpu is required even at 1 GPU/node, else accelerate won't form the DDP group.
# machine_rank defaults to $SLURM_PROCID; override --machine_rank if node 0 isn't a trainer.
MULTI_NODE_ARGS=(
--multi_gpu
--num_processes "$TOTAL_GPU"
--num_machines "$NUM_NODES"
--machine_rank "${MACHINE_RANK:-$SLURM_PROCID}"
--main_process_ip "$HEAD_NODE_IP"
--main_process_port 29500
)
fi

export TOKENIZERS_PARALLELISM=False

# argv array, not `sh -c` (which would word-split overrides and run embedded substitutions).
CMD=(accelerate launch --mixed_precision bf16
"${MULTI_NODE_ARGS[@]}"
"${SCRIPT_DIR}/main.py" --config "$CONFIG_FILE" "${EXTRA_ARGS[@]}")

set -x
start_time=$(date +%s)
sh -c "accelerate launch --mixed_precision bf16 $MULTI_NODE_ARGS ${SCRIPT_DIR}/main.py --config $CONFIG_FILE ${EXTRA_ARGS[*]}"
"${CMD[@]}"
echo "Total time: $(( $(date +%s) - $start_time )) seconds"
12 changes: 4 additions & 8 deletions examples/speculative_decoding/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,6 @@ def train():
train_len=training_args.training_seq_len,
answer_only_loss=training_args.answer_only_loss,
shift_labels=not is_dflash,
seed=training_args.seed,
)

callbacks = [EagleTrainingPlot(training_args.ar_validate_steps, training_args.estimate_ar)]
Expand All @@ -277,13 +276,10 @@ def train():
and recipe.eagle.eagle_base_lora_warmup_steps > 0
):
callbacks.append(LoRAWarmupCallback(recipe.eagle.eagle_base_lora_warmup_steps))
if recipe.data.mode == "streaming":
# Skip-on-resume happens inside the dataset (no re-fetch from server);
# disable HF Trainer's own data skip so the offset isn't applied twice.
from modelopt.torch.speculative.plugins.hf_streaming_dataset import StreamingResumeCallback

training_args.ignore_data_skip = True
callbacks.append(StreamingResumeCallback())
# Leave training_args.ignore_data_skip at its default (False). The dataset is
# map-style, so HF Trainer's resume skips consumed indices at the batch-sampler
# level (accelerate.skip_first_batches) without re-fetching them, landing at the
# exact data position. Setting it True would restart the data order from the top.

trainer = EagleTrainerWithAccLog(
model=model,
Expand Down
18 changes: 17 additions & 1 deletion modelopt/recipe/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,18 @@
TrainingArguments as SpecTrainingArgs,
)

__all__ = [
"RECIPE_TYPE_TO_CLASS",
"ModelOptDFlashRecipe",
"ModelOptEagleRecipe",
"ModelOptMedusaRecipe",
"ModelOptPTQRecipe",
"ModelOptRecipeBase",
"ModelOptSpeculativeRecipeBase",
"RecipeMetadataConfig",
"RecipeType",
]


class RecipeType(str, Enum):
"""List of recipe types. See ``RECIPE_TYPE_TO_CLASS`` at the bottom for the schema mapping."""
Expand Down Expand Up @@ -178,7 +190,11 @@ class ModelOptDFlashRecipe(ModelOptSpeculativeRecipeBase):

@model_validator(mode="after")
def _derive_dflash_offline(self) -> ModelOptDFlashRecipe:
self.dflash.dflash_offline = self.data.offline_data_path is not None
# offline (dumped .pt) and streaming (hidden states over HTTP from a vLLM
# serve) both feed pre-computed base hidden states to the DFlash module, so
# both set dflash_offline. Only fully-online training runs the base model.
# Mirrors ModelOptEagleRecipe._derive_eagle_offline.
self.dflash.dflash_offline = self.data.mode != "online"
Comment thread
h-guo18 marked this conversation as resolved.
return self


Expand Down
18 changes: 16 additions & 2 deletions modelopt/torch/speculative/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,18 @@

from .eagle.default_config import default_eagle_config, default_kimik2_eagle_config

__all__ = [
"DFLASH_DEFAULT_CFG",
"EAGLE3_DEFAULT_CFG",
"EAGLE_MTP_DEFAULT_CFG",
"DFlashConfig",
"EagleConfig",
"MedusaConfig",
"eagle3_default_config",
"eagle_mtp_default_config",
"kimik2_eagle_default_config",
]

kimik2_eagle_default_config = deepcopy(default_kimik2_eagle_config)

eagle3_default_config = deepcopy(default_eagle_config)
Expand Down Expand Up @@ -68,8 +80,10 @@ class DFlashConfig(ModeloptBaseConfig):
dflash_offline: bool = ModeloptField(
default=False,
description=(
"Whether to use detached DFlash (offline training from pre-computed hidden states). "
"Derived by ModelOptDFlashRecipe from data.offline_data_path; not user-configurable."
"Whether the DFlash module consumes pre-computed hidden states (offline from "
"dumped .pt files, or streaming over HTTP from a vLLM serve) instead of running "
"the base model. Derived by ModelOptDFlashRecipe from data.mode (True unless "
"online); not user-configurable."
Comment thread
h-guo18 marked this conversation as resolved.
),
)

Expand Down
26 changes: 22 additions & 4 deletions modelopt/torch/speculative/eagle/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@
from torch.utils.data import Dataset
from transformers.trainer_pt_utils import LabelSmoother

from modelopt.torch.utils.loss_mask import get_loss_mask_recovery

IGNORE_TOKEN_ID = LabelSmoother.ignore_index


Expand Down Expand Up @@ -96,20 +98,27 @@ class OfflineSupervisedDataset(Dataset):
dumped_files (list): A list of file paths to the dumped .pt files.
answer_only_loss (bool): If True, use the ``loss_mask`` stored in each .pt
file so that only assistant-produced tokens contribute to the loss.
Raises ``ValueError`` on ``__getitem__`` if the file lacks ``loss_mask``.
If a file lacks ``loss_mask`` and ``tokenizer`` has a registered
model-specific recovery (see ``modelopt.torch.utils.loss_mask``), the
mask is rebuilt from ``input_ids``; otherwise ``__getitem__`` raises
``ValueError``.
If False (default), a uniform all-ones mask is used regardless of what
is stored in the file (backward compatible).
tokenizer: Optional tokenizer used to recover the assistant mask for dumps
that lack a stored ``loss_mask``.
"""

def __init__(
self,
dumped_files,
answer_only_loss: bool = False,
tokenizer=None,
):
"""Initialize with a list of .pt file paths."""
super().__init__()
self.dumped_files = dumped_files
self.answer_only_loss = answer_only_loss
self.tokenizer = tokenizer

def __len__(self):
return len(self.dumped_files)
Expand All @@ -121,13 +130,22 @@ def __getitem__(self, i) -> dict[str, torch.Tensor]:
labels[..., :-1] = offline_data["input_ids"][..., 1:]

if self.answer_only_loss:
if "loss_mask" not in offline_data:
recovery = get_loss_mask_recovery(self.tokenizer) if self.tokenizer else None
if "loss_mask" in offline_data:
loss_mask = offline_data["loss_mask"].to(offline_data["input_ids"].dtype)
elif recovery is not None:
# Dumps from tokenizers that cannot emit assistant masks carry no
# loss_mask; rebuild it from the token ids.
loss_mask = recovery.compute(self.tokenizer, offline_data["input_ids"]).to(
offline_data["input_ids"].dtype
)
else:
raise ValueError(
f"answer_only_loss=True requires a 'loss_mask' entry in the offline "
f".pt file, but {self.dumped_files[i]} does not have one. Re-dump "
f"with --answer-only-loss in compute_hidden_states_*.py."
f"with --answer-only-loss in compute_hidden_states_*.py, or pass a "
f"tokenizer with a registered loss-mask recovery."
)
loss_mask = offline_data["loss_mask"].to(offline_data["input_ids"].dtype)
else:
loss_mask = torch.ones_like(offline_data["input_ids"])

Expand Down
Loading
Loading