From 8bbf26abfab5db45267a920ebadcafea34c07e11 Mon Sep 17 00:00:00 2001 From: h-guo18 <67671475+h-guo18@users.noreply.github.com> Date: Sun, 29 Mar 2026 22:08:03 +0000 Subject: [PATCH 01/21] yamlize eagle configs Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com> --- .../configs/kimi_k25_offline.yaml | 56 ++++ .../configs/llama32_1b_online.yaml | 56 ++++ .../speculative_decoding/eagle_config.json | 2 - .../speculative_decoding/fsdp_config.json | 1 - examples/speculative_decoding/launch_train.sh | 280 ++---------------- examples/speculative_decoding/main.py | 92 +++--- .../train_eagle3_and_export.sh | 104 +++++-- .../common/eagle3/offline_training.sh | 1 - 8 files changed, 268 insertions(+), 324 deletions(-) create mode 100644 examples/speculative_decoding/configs/kimi_k25_offline.yaml create mode 100644 examples/speculative_decoding/configs/llama32_1b_online.yaml delete mode 100644 examples/speculative_decoding/eagle_config.json delete mode 100644 examples/speculative_decoding/fsdp_config.json diff --git a/examples/speculative_decoding/configs/kimi_k25_offline.yaml b/examples/speculative_decoding/configs/kimi_k25_offline.yaml new file mode 100644 index 0000000000..08e3757fd6 --- /dev/null +++ b/examples/speculative_decoding/configs/kimi_k25_offline.yaml @@ -0,0 +1,56 @@ +model: + use_fake_base_for_offline: true + trust_remote_code: true + +data: + offline_data_path: /home/haoguo/modelopt/examples/speculative_decoding/data/kimi_k2_5_offline + data_path: + draft_vocab_cache: + vlm_processor: + vlm_img_dir: + +training: + output_dir: ckpts/k25-test + num_train_epochs: 50 + per_device_train_batch_size: 1 + per_device_eval_batch_size: 1 + learning_rate: 1.0e-6 + weight_decay: 0.0 + warmup_steps: 100 + lr_scheduler_type: linear + gradient_accumulation_steps: 1 + training_seq_len: 64 + ar_validate_steps: 100000 + logging_steps: 1 + save_steps: 8192 + save_strategy: steps + do_eval: false + eval_accumulation_steps: 1 + dataloader_drop_last: true + bf16: true 
+ tf32: true + remove_unused_columns: false + disable_tqdm: false + estimate_ar: false + mode: eagle3 + cp_size: 1 + dp_shard_size: 0 # 0 = auto: total_gpu / cp_size + fsdp: "" # set to "full_shard" to enable FSDP2 + fsdp_config: + fsdp_version: 2 + +eagle: + # Fields map directly to EagleConfig (modelopt/torch/speculative/config.py). + # eagle_offline is derived from data.offline_data_path; do not set here. + eagle_decoder_type: kimik2 + eagle_ttt_steps: 3 + eagle_mix_hidden_states: false + eagle_use_torch_compile: true + eagle_self_logit_distillation: true + eagle_freeze_base_model: true + eagle_loss_decay_factor: 0.9 + eagle_hidden_state_distillation: false + eagle_reuse_base_decoder: false + eagle_report_acc: true + eagle_enable_nvtx: false + eagle_architecture_config: {} diff --git a/examples/speculative_decoding/configs/llama32_1b_online.yaml b/examples/speculative_decoding/configs/llama32_1b_online.yaml new file mode 100644 index 0000000000..dd9e8ecfb2 --- /dev/null +++ b/examples/speculative_decoding/configs/llama32_1b_online.yaml @@ -0,0 +1,56 @@ +model: + use_fake_base_for_offline: false + trust_remote_code: false + +data: + data_path: /home/haoguo/modelopt/examples/speculative_decoding/input_conversations/daring-anteater.jsonl + offline_data_path: + draft_vocab_cache: + vlm_processor: + vlm_img_dir: + +training: + output_dir: ckpts/llama-3.2-1b-online + num_train_epochs: 1 + per_device_train_batch_size: 1 + per_device_eval_batch_size: 1 + learning_rate: 1.0e-4 + weight_decay: 0.0 + warmup_steps: 100 + lr_scheduler_type: linear + gradient_accumulation_steps: 1 + training_seq_len: 512 + ar_validate_steps: 100000 + logging_steps: 100 + save_steps: 8192 + save_strategy: steps + do_eval: false + eval_accumulation_steps: 1 + dataloader_drop_last: true + bf16: true + tf32: true + remove_unused_columns: false + disable_tqdm: false + estimate_ar: false + mode: eagle3 + cp_size: 1 + dp_shard_size: 0 # 0 = auto: total_gpu / cp_size + fsdp: "" # set to "full_shard" 
to enable FSDP2 + fsdp_config: + fsdp_version: 2 + +eagle: + # Fields map directly to EagleConfig (modelopt/torch/speculative/config.py). + # eagle_offline is derived from data.offline_data_path; do not set here. + eagle_decoder_type: llama + eagle_ttt_steps: 3 + eagle_mix_hidden_states: false + eagle_use_torch_compile: true + eagle_self_logit_distillation: true + eagle_freeze_base_model: true + eagle_loss_decay_factor: 0.9 + eagle_hidden_state_distillation: false + eagle_reuse_base_decoder: false + eagle_report_acc: true + eagle_enable_nvtx: false + eagle_architecture_config: {} diff --git a/examples/speculative_decoding/eagle_config.json b/examples/speculative_decoding/eagle_config.json deleted file mode 100644 index 2c63c08510..0000000000 --- a/examples/speculative_decoding/eagle_config.json +++ /dev/null @@ -1,2 +0,0 @@ -{ -} diff --git a/examples/speculative_decoding/fsdp_config.json b/examples/speculative_decoding/fsdp_config.json deleted file mode 100644 index 6d934182fe..0000000000 --- a/examples/speculative_decoding/fsdp_config.json +++ /dev/null @@ -1 +0,0 @@ -{"fsdp_version":2} \ No newline at end of file diff --git a/examples/speculative_decoding/launch_train.sh b/examples/speculative_decoding/launch_train.sh index c15b97bdaa..6d248784d2 100755 --- a/examples/speculative_decoding/launch_train.sh +++ b/examples/speculative_decoding/launch_train.sh @@ -14,277 +14,61 @@ # See the License for the specific language governing permissions and # limitations under the License. +# Usage: +# Single GPU: ./launch_train.sh --config configs/my_experiment.yaml +# Multi-node: ./launch_train.sh --config configs/my_experiment.yaml --num_nodes 2 --head_node_ip +# +# All training config (model, data, hyperparams, eagle, fsdp) lives in the YAML file. +# Only multi-node routing args are passed here; mixed_precision is fixed to bf16. 
+ set -eo pipefail +SCRIPT_DIR="$(dirname "$(readlink -f "$0")")" + +CONFIG_FILE="" +MODEL="" +NUM_NODES=1 +HEAD_NODE_IP="" while [ $# -gt 0 ]; do case "$1" in - --training_seq_len*) - if [[ "$1" != *=* ]]; then shift; fi - TRAINING_SEQ_LEN="${1#*=}" - ;; - --model*) - if [[ "$1" != *=* ]]; then shift; fi - MODEL="${1#*=}" - ;; - --data*) - if [[ "$1" != *=* ]]; then shift; fi - DATA="${1#*=}" - ;; - --offline-data*) - if [[ "$1" != *=* ]]; then shift; fi - OFFLINE_DATA_PATH="${1#*=}" - ;; - --mode*) - if [[ "$1" != *=* ]]; then shift; fi - MODE="${1#*=}" - ;; - --eagle_decoder_type*) - if [[ "$1" != *=* ]]; then shift; fi - EAGLE_DECODER_TYPE="${1#*=}" - ;; - --output_dir*) - if [[ "$1" != *=* ]]; then shift; fi - OUTPUT_DIR="${1#*=}" - ;; - --num_epochs*) - if [[ "$1" != *=* ]]; then shift; fi - NUM_EPOCHS="${1#*=}" - ;; - --save_steps*) - if [[ "$1" != *=* ]]; then shift; fi - SAVE_STEPS="${1#*=}" - ;; - --lr*) - if [[ "$1" != *=* ]]; then shift; fi - LR="${1#*=}" - ;; - --train_bs*) - if [[ "$1" != *=* ]]; then shift; fi - TRAIN_BS="${1#*=}" - ;; - --eagle_config*) - if [[ "$1" != *=* ]]; then shift; fi - EAGLE_CONFIG="${1#*=}" - ;; - --disable_tqdm*) - if [[ "$1" != *=* ]]; then shift; fi - DISABLE_TQDM="${1#*=}" - ;; - --vlm_processor*) - if [[ "$1" != *=* ]]; then shift; fi - VLM_PROCESSOR="${1#*=}" - ;; - --vlm_img_dir*) - if [[ "$1" != *=* ]]; then shift; fi - VLM_IMG_DIR="${1#*=}" - ;; - --estimate_ar*) - if [[ "$1" != *=* ]]; then shift; fi - ESTIMATE_AR="${1#*=}" - ;; - --ar_validate_steps*) - if [[ "$1" != *=* ]]; then shift; fi - AR_VALIDATE_STEPS="${1#*=}" - ;; - --num_ttt_steps*) - if [[ "$1" != *=* ]]; then shift; fi - NUM_TTT_STEPS="${1#*=}" - ;; - --cp_size*) - if [[ "$1" != *=* ]]; then shift; fi - CP_SIZE="${1#*=}" - ;; - --dp_size*) - if [[ "$1" != *=* ]]; then shift; fi - DP_SHARD_SIZE="${1#*=}" - ;; - --log_steps*) - if [[ "$1" != *=* ]]; then shift; fi - LOG_STEPS="${1#*=}" - ;; - --draft_vocab_cache*) - if [[ "$1" != *=* ]]; then shift; fi 
- DRAFT_VOCAB_CACHE="${1#*=}" - ;; - --num_nodes*) - if [[ "$1" != *=* ]]; then shift; fi - NUM_NODES="${1#*=}" - ;; - --head_node_ip*) - if [[ "$1" != *=* ]]; then shift; fi - HEAD_NODE_IP="${1#*=}" - ;; - --mix_hidden_states*) - if [[ "$1" != *=* ]]; then shift; fi - MIX_HIDDEN_STATES="${1#*=}" - ;; - --disable_torch_compile*) - if [[ "$1" != *=* ]]; then shift; fi - DISABLE_TORCH_COMPILE="${1#*=}" - ;; - --use_fake_base_for_offline*) - if [[ "$1" != *=* ]]; then shift; fi - USE_FAKE_BASE_FOR_OFFLINE="${1#*=}" - ;; - --trust_remote_code*) - if [[ "$1" != *=* ]]; then shift; fi - TRUST_REMOTE_CODE="${1#*=}" - ;; - --fsdp*) - if [[ "$1" != *=* ]]; then shift; fi - FSDP="${1#*=}" - ;; - *) - >&2 printf "Error: Invalid argument ${1#*=}\n" - exit 1 - ;; + --config*) if [[ "$1" != *=* ]]; then shift; fi; CONFIG_FILE="${1#*=}" ;; + --model*) if [[ "$1" != *=* ]]; then shift; fi; MODEL="${1#*=}" ;; + --num_nodes*) if [[ "$1" != *=* ]]; then shift; fi; NUM_NODES="${1#*=}" ;; + --head_node_ip*) if [[ "$1" != *=* ]]; then shift; fi; HEAD_NODE_IP="${1#*=}" ;; + *) >&2 echo "Error: Unknown argument '$1'"; exit 1 ;; esac shift done -set -x +if [ -z "$CONFIG_FILE" ] || [ -z "$MODEL" ]; then + >&2 echo "Usage: ./launch_train.sh --config --model [--num_nodes N] [--head_node_ip IP]" + exit 1 +fi -SCRIPT_DIR="$(dirname "$(readlink -f "$0")")" -NUM_NODES=${NUM_NODES:-1} -if [[ "$NUM_NODES" != 1 ]]; then - #Multi Node Training +# GPU count detection +if [[ "$NUM_NODES" != "1" ]]; then GPU_PER_NODE=${GPU_PER_NODE:-$(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l)} TOTAL_GPU=$((NUM_NODES * GPU_PER_NODE)) echo "Total GPUs: $TOTAL_GPU (NUM_NODES: $NUM_NODES, GPU_PER_NODE: $GPU_PER_NODE)" else - #Single Node Training, GPU can be specified by $CUDA_VISIBLE_DEVICES - TOTAL_GPU=$(python -c "import torch; print(torch.cuda.device_count())") - echo "Total GPUs: $TOTAL_GPU (Single Node Training)" -fi -# Calculate save_steps -DEFAULT_SAVE_STEPS=$((8192 / TOTAL_GPU)) - 
-MODEL=${MODEL:-"TinyLlama/TinyLlama-1.1B-Chat-v1.0"} -MODE=${MODE:-"eagle3"} -EAGLE_DECODER_TYPE=${EAGLE_DECODER_TYPE:-"llama"} -# Set default OUTPUT_DIR to ckpts/{modelname}, where {modelname} is the last part of the model path -MODEL_BASENAME=$(basename "$MODEL") -OUTPUT_DIR=${OUTPUT_DIR:-"ckpts/${MODEL_BASENAME}-$(date +%Y%m%d_%H%M)"} -NUM_EPOCHS=${NUM_EPOCHS:-1} -SAVE_STEPS=${SAVE_STEPS:-$DEFAULT_SAVE_STEPS} -LR=${LR:-"1e-4"} -TRAIN_BS=${TRAIN_BS:-1} -TRAINING_SEQ_LEN=${TRAINING_SEQ_LEN:-2048} -DATA=${DATA:-""} -OFFLINE_DATA_PATH=${OFFLINE_DATA_PATH:-""} -DISABLE_TQDM=${DISABLE_TQDM:-False} -VLM_PROCESSOR=${VLM_PROCESSOR:-} -VLM_IMG_DIR=${VLM_IMG_DIR:-} -AR_VALIDATE_STEPS=${AR_VALIDATE_STEPS:-1000} -ESTIMATE_AR=${ESTIMATE_AR:-False} -CP_SIZE=${CP_SIZE:-1} -DP_SHARD_SIZE=${DP_SHARD_SIZE:-$((TOTAL_GPU/CP_SIZE))} -LOG_STEPS=${LOG_STEPS:-100} -DRAFT_VOCAB_CACHE=${DRAFT_VOCAB_CACHE:-""} -MIX_HIDDEN_STATES=${MIX_HIDDEN_STATES:-"False"} -DISABLE_TORCH_COMPILE=${DISABLE_TORCH_COMPILE:-"False"} -NUM_TTT_STEPS=${NUM_TTT_STEPS:-3} - -USE_FAKE_BASE_FOR_OFFLINE=${USE_FAKE_BASE_FOR_OFFLINE:-"False"} -TRUST_REMOTE_CODE=${TRUST_REMOTE_CODE:-"False"} -FSDP=${FSDP:-"False"} - -if [[ "$MODE" == "eagle3" ]]; then - if [[ -n "$EAGLE_CONFIG" ]]; then - SPECULATIVE_ARGS="--eagle_config $EAGLE_CONFIG" - else - SPECULATIVE_ARGS="" - fi -else - echo "Only eagle3 supported for now!" - exit 1 -fi - -if [[ "$OFFLINE_DATA_PATH" != "" ]]; then - if [[ ! -d "$OFFLINE_DATA_PATH" ]]; then - echo "Offline data path $OFFLINE_DATA_PATH does not exist or is not a directory." 
- exit 1 - else - DATA_ARGS="--offline-data-path $OFFLINE_DATA_PATH --ar_validate_steps -1" - fi -else - DATA_ARGS="--data_path $DATA" -fi - - -if [[ "$VLM_PROCESSOR" != "" ]]; then - VLM_ARGS="--vlm_processor $VLM_PROCESSOR --vlm_img_dir $VLM_IMG_DIR" -else - VLM_ARGS="" + TOTAL_GPU=$(python3 -c "import torch; print(torch.cuda.device_count())") + echo "Total GPUs: $TOTAL_GPU (single node)" fi -if [[ "$TOTAL_GPU" -gt 1 && "$FSDP" == "True" ]]; then - #Use FSDP2 when multi GPU available - FSDP_ARGS="--fsdp 'full_shard' --fsdp_config ${SCRIPT_DIR}/fsdp_config.json" -else - #Otherwise, single GPU training - FSDP_ARGS="" -fi - - -if [[ "$DRAFT_VOCAB_CACHE" != "" ]]; then - DRAFT_VOCAB_CACHE_ARGS="--draft_vocab_cache $DRAFT_VOCAB_CACHE" -else - DRAFT_VOCAB_CACHE_ARGS="" -fi - -if [[ "$NUM_NODES" != 1 ]]; then +# Multi-node routing args (accelerate only; training config comes from the YAML) +MULTI_NODE_ARGS="" +if [[ "$NUM_NODES" != "1" ]]; then MULTI_NODE_ARGS="--num_processes $TOTAL_GPU \ --num_machines $NUM_NODES \ --machine_rank $SLURM_PROCID \ --rdzv_backend c10d \ --main_process_ip $HEAD_NODE_IP \ --main_process_port 29500" -else - MULTI_NODE_ARGS="" fi -# Disable tokenizers parallelism to avoid warning export TOKENIZERS_PARALLELISM=False -CMD="accelerate launch $MULTI_NODE_ARGS --mixed_precision bf16 ${SCRIPT_DIR}/main.py \ - --mode $MODE \ - --eagle_decoder_type $EAGLE_DECODER_TYPE \ - --model_name_or_path $MODEL \ - --training_seq_len $TRAINING_SEQ_LEN \ - --dataloader_drop_last True \ - --bf16 True \ - --output_dir $OUTPUT_DIR \ - --num_train_epochs $NUM_EPOCHS \ - --per_device_train_batch_size $TRAIN_BS \ - --per_device_eval_batch_size $TRAIN_BS \ - --gradient_accumulation_steps 1 \ - --do_eval False \ - --eval_accumulation_steps 1 \ - --save_strategy steps \ - --save_steps $SAVE_STEPS \ - --learning_rate $LR \ - --weight_decay 0.0 \ - --warmup_steps 100 \ - --lr_scheduler_type linear \ - --logging_steps $LOG_STEPS \ - --tf32 True \ - $DATA_ARGS \ - 
--disable_tqdm $DISABLE_TQDM \ - --estimate_ar $ESTIMATE_AR \ - --ar_validate_steps $AR_VALIDATE_STEPS \ - --mix_hidden_states $MIX_HIDDEN_STATES \ - --disable_torch_compile $DISABLE_TORCH_COMPILE \ - --use_fake_base_for_offline $USE_FAKE_BASE_FOR_OFFLINE \ - --trust_remote_code $TRUST_REMOTE_CODE \ - $DRAFT_VOCAB_CACHE_ARGS \ - $VLM_ARGS \ - $SPECULATIVE_ARGS \ - $FSDP_ARGS \ - --cp_size $CP_SIZE \ - --dp_shard_size $DP_SHARD_SIZE \ - --num_ttt_steps $NUM_TTT_STEPS \ -" +set -x start_time=$(date +%s) -sh -c "$CMD" -echo "Total time taken: $(( $(date +%s) - $start_time )) seconds" +sh -c "accelerate launch --mixed_precision bf16 $MULTI_NODE_ARGS ${SCRIPT_DIR}/main.py --config $CONFIG_FILE --model $MODEL" +echo "Total time: $(( $(date +%s) - $start_time )) seconds" diff --git a/examples/speculative_decoding/main.py b/examples/speculative_decoding/main.py index 3369d399c2..19ff9a02e4 100644 --- a/examples/speculative_decoding/main.py +++ b/examples/speculative_decoding/main.py @@ -29,13 +29,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -import json +import argparse import os from dataclasses import dataclass, field from typing import Literal import torch import transformers +import yaml from accelerate import ParallelismConfig from eagle_utils import ( EagleTrainerWithAccLog, @@ -123,42 +124,64 @@ class MedusaArguments: medusa_num_layers: int | None = field(default=1) -@dataclass -class EagleArguments: - eagle_config: str = field(default=None, metadata={"help": "Path to eagle_config.json"}) - eagle_decoder_type: str = field( - default="llama", - metadata={"help": "The class of eagle decoder to use. 
Available options: llama, kimik2"}, - ) - mix_hidden_states: bool = field( - default=False, - metadata={"help": "Whether to mix hidden states from previous TTT step."}, - ) - disable_torch_compile: bool = field( - default=False, - metadata={"help": "Disable torch.compile on eagle forward/loss methods."}, - ) - num_ttt_steps: int = field( - default=3, - metadata={"help": "Number of train-time-test steps to use during training."}, - ) +def _parse_cli() -> tuple[str, str]: + """Parse required --config and --model from argv; ignore all other args.""" + p = argparse.ArgumentParser(add_help=False) + p.add_argument("--config", required=True, help="Path to the YAML config file.") + p.add_argument("--model", required=True, help="Model name or path (overrides yaml).") + args, _ = p.parse_known_args() + return args.config, args.model + + +def _load_config(config_path: str) -> tuple[dict, dict]: + """Load training config from a YAML file with sections: model, data, training, eagle. + + Returns: + hf_cfg: Flat dict from model/data/training sections, for HfArgumentParser.parse_dict() + eagle_cfg: Eagle section dict (EagleConfig fields), passed directly to mtsp.convert() + """ + with open(config_path) as f: + cfg = yaml.safe_load(f) + + # Eagle section maps directly to EagleConfig fields — no field enumeration needed. + # eagle_architecture_config is a nested dict and is included as-is. 
+ eagle_cfg = cfg.get("eagle", {}) + + hf_cfg = { + **cfg.get("model", {}), + **cfg.get("data", {}), + **cfg.get("training", {}), + } + + # dp_shard_size sentinel: 0 means auto-compute as total_gpu / cp_size + if hf_cfg.get("dp_shard_size", 1) == 0: + cp_size = hf_cfg.get("cp_size", 1) + hf_cfg["dp_shard_size"] = torch.cuda.device_count() // cp_size + + return hf_cfg, eagle_cfg def train(): + config_path, model_path = _parse_cli() + hf_cfg, eagle_cfg = _load_config(config_path) + hf_cfg["model_name_or_path"] = model_path + parser = transformers.HfArgumentParser( ( ModelArguments, DataArguments, TrainingArguments, MedusaArguments, - EagleArguments, ) ) - model_args, data_args, training_args, medusa_args, eagle_args = ( - parser.parse_args_into_dataclasses() + model_args, data_args, training_args, medusa_args = parser.parse_dict( + hf_cfg, allow_extra_keys=True ) + if not data_args.data_path and not data_args.offline_data_path: - raise ValueError("Either data_path or offline_data_path must be provided.") + raise ValueError( + "Either data.data_path or data.offline_data_path must be set in the config." + ) if training_args.cp_size > 1 or training_args.dp_shard_size > 1: training_args.parallelism_config = ParallelismConfig( cp_size=training_args.cp_size, dp_shard_size=training_args.dp_shard_size @@ -167,7 +190,7 @@ def train(): patch_ring_attention_for_ttt() # Specific patch to accelerate 1.12.0. 
Removable after move to 1.13.0 training_args.parallelism_config.sp_backend = None - print_rank_0(f"arguments: {model_args}, {training_args}, {medusa_args}, {eagle_args}") + print_rank_0(f"arguments: {model_args}, {training_args}, {medusa_args}, eagle_cfg={eagle_cfg}") # Detect checkpoint to resume from last_checkpoint = ( @@ -213,22 +236,11 @@ def train(): } mtsp.convert(model, [("medusa", config)]) elif training_args.mode == "eagle3": - custom_config = ( - json.load(open(eagle_args.eagle_config)) if eagle_args.eagle_config else {} - ) - - config = { - "eagle_decoder_type": eagle_args.eagle_decoder_type, - "eagle_offline": use_offline_training, - "eagle_mix_hidden_states": eagle_args.mix_hidden_states, - "eagle_use_torch_compile": not eagle_args.disable_torch_compile, - "eagle_ttt_steps": eagle_args.num_ttt_steps, - "eagle_architecture_config": custom_config, - } - - mtsp.convert(model, [("eagle", config)]) + # eagle_cfg maps directly to EagleConfig fields; eagle_offline is derived here. 
+ eagle_cfg["eagle_offline"] = use_offline_training + mtsp.convert(model, [("eagle", eagle_cfg)]) - # read draft vocab cache + # Load draft vocab cache if the draft model uses a compressed vocabulary if model.eagle_config.draft_vocab_size < model.eagle_config.vocab_size: if not os.path.isfile(data_args.draft_vocab_cache): raise FileNotFoundError( diff --git a/examples/speculative_decoding/train_eagle3_and_export.sh b/examples/speculative_decoding/train_eagle3_and_export.sh index 0f5fef9354..b4c00af349 100755 --- a/examples/speculative_decoding/train_eagle3_and_export.sh +++ b/examples/speculative_decoding/train_eagle3_and_export.sh @@ -17,50 +17,90 @@ set -eo pipefail -# Set default values for BASE_MODEL and DATA BASE_MODEL=meta-llama/Llama-3.2-1B-Instruct DATA=input_conversations/train.jsonl -# Parse input arguments --base_model and --data while [[ $# -gt 0 ]]; do - key="$1" - case $key in - --base_model) - BASE_MODEL="$2" - shift; shift - ;; - --data) - DATA="$2" - shift; shift - ;; - --offline_data) - OFFLINE_DATA_PATH="$2" - shift; shift - ;; - *) - echo "Unknown argument: $1" - exit 1 - ;; + case $1 in + --base_model) BASE_MODEL="$2"; shift; shift ;; + --data) DATA="$2"; shift; shift ;; + --offline_data) OFFLINE_DATA_PATH="$2"; shift; shift ;; + *) echo "Unknown argument: $1"; exit 1 ;; esac done -if [[ "$OFFLINE_DATA_PATH" != "" ]]; then - OFFLINE_DATA_ARGS="--offline-data $OFFLINE_DATA_PATH" +MODEL_BASENAME=$(basename "$BASE_MODEL") +OUTPUT_DIR=ckpts/${MODEL_BASENAME}-$(date +%Y%m%d_%H%M) +mkdir -p "$OUTPUT_DIR" + +# Build offline_data_path YAML value +if [[ -n "$OFFLINE_DATA_PATH" ]]; then + OFFLINE_DATA_YAML="\"$OFFLINE_DATA_PATH\"" else - OFFLINE_DATA_ARGS="" + OFFLINE_DATA_YAML="null" fi -MODEL_BASENAME=$(basename "$BASE_MODEL") +# Write config to output dir so it's preserved alongside the checkpoint +YAML_FILE="$OUTPUT_DIR/train_config.yaml" +cat > "$YAML_FILE" << EOF +model: + use_fake_base_for_offline: false + trust_remote_code: false + +data: + 
data_path: "$DATA" + offline_data_path: $OFFLINE_DATA_YAML + draft_vocab_cache: null + vlm_processor: null + vlm_img_dir: null + +training: + output_dir: "$OUTPUT_DIR" + num_train_epochs: 2 + per_device_train_batch_size: 1 + per_device_eval_batch_size: 1 + learning_rate: 1.0e-4 + weight_decay: 0.0 + warmup_steps: 100 + lr_scheduler_type: linear + gradient_accumulation_steps: 1 + training_seq_len: 2048 + ar_validate_steps: 100000 + logging_steps: 100 + save_steps: 8192 + save_strategy: steps + do_eval: false + eval_accumulation_steps: 1 + dataloader_drop_last: true + bf16: true + tf32: true + remove_unused_columns: false + disable_tqdm: false + estimate_ar: false + mode: eagle3 + cp_size: 1 + dp_shard_size: 0 + fsdp: "" + fsdp_config: + fsdp_version: 2 + +eagle: + eagle_decoder_type: llama + eagle_ttt_steps: 3 + eagle_mix_hidden_states: false + eagle_use_torch_compile: true + eagle_self_logit_distillation: true + eagle_freeze_base_model: true + eagle_loss_decay_factor: 0.9 + eagle_hidden_state_distillation: false + eagle_reuse_base_decoder: false + eagle_report_acc: true + eagle_enable_nvtx: false + eagle_architecture_config: {} +EOF echo "==== [1/3] Training draft model ====" -OUTPUT_DIR=ckpts/${MODEL_BASENAME}-$(date +%Y%m%d_%H%M) -mkdir -p "$(dirname "$OUTPUT_DIR")" -./launch_train.sh --model $BASE_MODEL \ - --output_dir $OUTPUT_DIR \ - $OFFLINE_DATA_ARGS \ - --data $DATA \ - --num_epochs 2 \ - --eagle_config eagle_config.json +./launch_train.sh --config "$YAML_FILE" --model "$BASE_MODEL" echo "==== [2/3] Evaluating ModelOpt checkpoint on MT-Bench ====" python scripts/ar_validate.py --model_path $OUTPUT_DIR diff --git a/tools/launcher/common/eagle3/offline_training.sh b/tools/launcher/common/eagle3/offline_training.sh index 09384a499b..630b9e8f70 100644 --- a/tools/launcher/common/eagle3/offline_training.sh +++ b/tools/launcher/common/eagle3/offline_training.sh @@ -27,7 +27,6 @@ export PATH=$PATH:/workspace/.local/bin trap 'error_handler $0 $LINENO' ERR # ERROR 
HANDLER bash modules/Model-Optimizer/examples/speculative_decoding/launch_train.sh \ - --model ${HF_MODEL_CKPT} \ ${@} python modules/Model-Optimizer/examples/speculative_decoding/scripts/export_hf_checkpoint.py \ From 1c5706626ab3aff3dce8266eee96093b26b88bb4 Mon Sep 17 00:00:00 2001 From: h-guo18 <67671475+h-guo18@users.noreply.github.com> Date: Mon, 30 Mar 2026 04:08:23 +0000 Subject: [PATCH 02/21] separate base config Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com> --- .../configs/kimi_k25_offline.yaml | 53 ++------------- .../configs/llama32_1b_online.yaml | 53 ++------------- examples/speculative_decoding/main.py | 65 +++++++++++-------- 3 files changed, 49 insertions(+), 122 deletions(-) diff --git a/examples/speculative_decoding/configs/kimi_k25_offline.yaml b/examples/speculative_decoding/configs/kimi_k25_offline.yaml index 08e3757fd6..1b1d7214d0 100644 --- a/examples/speculative_decoding/configs/kimi_k25_offline.yaml +++ b/examples/speculative_decoding/configs/kimi_k25_offline.yaml @@ -1,56 +1,17 @@ +__base__: _base_eagle3.yaml + model: + model_name_or_path: moonshotai/Kimi-K2.5 use_fake_base_for_offline: true trust_remote_code: true data: - offline_data_path: /home/haoguo/modelopt/examples/speculative_decoding/data/kimi_k2_5_offline - data_path: - draft_vocab_cache: - vlm_processor: - vlm_img_dir: + offline_data_path: training: - output_dir: ckpts/k25-test - num_train_epochs: 50 - per_device_train_batch_size: 1 - per_device_eval_batch_size: 1 - learning_rate: 1.0e-6 - weight_decay: 0.0 - warmup_steps: 100 - lr_scheduler_type: linear - gradient_accumulation_steps: 1 - training_seq_len: 64 - ar_validate_steps: 100000 - logging_steps: 1 - save_steps: 8192 - save_strategy: steps - do_eval: false - eval_accumulation_steps: 1 - dataloader_drop_last: true - bf16: true - tf32: true - remove_unused_columns: false - disable_tqdm: false - estimate_ar: false - mode: eagle3 - cp_size: 1 - dp_shard_size: 0 # 0 = auto: total_gpu / cp_size - fsdp: "" # 
set to "full_shard" to enable FSDP2 - fsdp_config: - fsdp_version: 2 + output_dir: ckpts/kimi-k25-eagle3 + num_train_epochs: 1 + training_seq_len: 4096 eagle: - # Fields map directly to EagleConfig (modelopt/torch/speculative/config.py). - # eagle_offline is derived from data.offline_data_path; do not set here. eagle_decoder_type: kimik2 - eagle_ttt_steps: 3 - eagle_mix_hidden_states: false - eagle_use_torch_compile: true - eagle_self_logit_distillation: true - eagle_freeze_base_model: true - eagle_loss_decay_factor: 0.9 - eagle_hidden_state_distillation: false - eagle_reuse_base_decoder: false - eagle_report_acc: true - eagle_enable_nvtx: false - eagle_architecture_config: {} diff --git a/examples/speculative_decoding/configs/llama32_1b_online.yaml b/examples/speculative_decoding/configs/llama32_1b_online.yaml index dd9e8ecfb2..62e559c630 100644 --- a/examples/speculative_decoding/configs/llama32_1b_online.yaml +++ b/examples/speculative_decoding/configs/llama32_1b_online.yaml @@ -1,56 +1,13 @@ +__base__: _base_eagle3.yaml + + model: - use_fake_base_for_offline: false + model_name_or_path: meta-llama/Llama-3.2-1B trust_remote_code: false data: - data_path: /home/haoguo/modelopt/examples/speculative_decoding/input_conversations/daring-anteater.jsonl - offline_data_path: - draft_vocab_cache: - vlm_processor: - vlm_img_dir: + data_path: input_conversations/daring-anteater.jsonl training: output_dir: ckpts/llama-3.2-1b-online - num_train_epochs: 1 - per_device_train_batch_size: 1 - per_device_eval_batch_size: 1 - learning_rate: 1.0e-4 - weight_decay: 0.0 - warmup_steps: 100 - lr_scheduler_type: linear - gradient_accumulation_steps: 1 training_seq_len: 512 - ar_validate_steps: 100000 - logging_steps: 100 - save_steps: 8192 - save_strategy: steps - do_eval: false - eval_accumulation_steps: 1 - dataloader_drop_last: true - bf16: true - tf32: true - remove_unused_columns: false - disable_tqdm: false - estimate_ar: false - mode: eagle3 - cp_size: 1 - dp_shard_size: 0 # 0 = 
auto: total_gpu / cp_size - fsdp: "" # set to "full_shard" to enable FSDP2 - fsdp_config: - fsdp_version: 2 - -eagle: - # Fields map directly to EagleConfig (modelopt/torch/speculative/config.py). - # eagle_offline is derived from data.offline_data_path; do not set here. - eagle_decoder_type: llama - eagle_ttt_steps: 3 - eagle_mix_hidden_states: false - eagle_use_torch_compile: true - eagle_self_logit_distillation: true - eagle_freeze_base_model: true - eagle_loss_decay_factor: 0.9 - eagle_hidden_state_distillation: false - eagle_reuse_base_decoder: false - eagle_report_acc: true - eagle_enable_nvtx: false - eagle_architecture_config: {} diff --git a/examples/speculative_decoding/main.py b/examples/speculative_decoding/main.py index 19ff9a02e4..70a56a888c 100644 --- a/examples/speculative_decoding/main.py +++ b/examples/speculative_decoding/main.py @@ -32,11 +32,11 @@ import argparse import os from dataclasses import dataclass, field +from pathlib import Path from typing import Literal import torch import transformers -import yaml from accelerate import ParallelismConfig from eagle_utils import ( EagleTrainerWithAccLog, @@ -44,6 +44,7 @@ make_eagle_supervised_data_module, patch_ring_attention_for_ttt, ) +from omegaconf import OmegaConf from transformers.trainer_utils import get_last_checkpoint import modelopt.torch.opt as mto @@ -57,12 +58,18 @@ @dataclass class ModelArguments: - model_name_or_path: str | None = field(default="TinyLlama/TinyLlama-1.1B-Chat-v1.0") + model_name_or_path: str | None = field( + default="TinyLlama/TinyLlama-1.1B-Chat-v1.0", + metadata={"help": "HuggingFace model ID or local path to the base model."}, + ) use_fake_base_for_offline: bool = field( - default=False, metadata={"help": "Whether to use fake base for offline training."} + default=False, + metadata={ + "help": "Load model architecture without real base weights. Offline training only." 
+ }, ) trust_remote_code: bool = field( - default=False, metadata={"help": "Whether to trust remote code."} + default=False, metadata={"help": "Trust remote code when loading model."} ) @@ -70,23 +77,18 @@ class ModelArguments: class DataArguments: data_path: str = field( default=None, - metadata={"help": "Path to the training data."}, + metadata={"help": "Path to the online training data."}, ) - eval_data_path: str = field(default=None, metadata={"help": "Path to the evaluation data."}) offline_data_path: str = field( default=None, metadata={ - "help": """Path to the offline training data. Providing this flag sets - `eagle_offline` in the EagleConfig and enables offline training. - The directory should contain many `.pt` files, each containing a pre-processed - data sample. `data_path` should still point to the original conversations file. - """ + "help": "Path to offline training data directory (.pt files). This argument enables offline mode.", }, ) lazy_preprocess: bool = True draft_vocab_cache: str | None = field( default=None, - metadata={"help": "Path to d2t.pt cache file."}, + metadata={"help": "Path to draft vocabulary cache file."}, ) vlm_img_dir: str = field(default=None, metadata={"help": "Path to the VLM image directory."}) vlm_processor: str = field(default=None, metadata={"help": "Path to the VLM processor."}) @@ -94,26 +96,19 @@ class DataArguments: @dataclass class TrainingArguments(transformers.TrainingArguments): - cache_dir: str | None = field(default=None) training_seq_len: int = field( default=2048, metadata={ "help": ( - "Maximum sequence length. Sequences will be right padded (and possibly truncated)." + "Training sequence length. Sequences will be right padded or truncated to this length." 
) }, ) - dataloader_drop_last: bool = field(default=True) - bf16: bool = field(default=True) mode: Literal["eagle3", "medusa"] = "eagle3" estimate_ar: bool = field( - default=False, metadata={"help": "Whether to estimate AR during training for logging."} - ) - ar_validate_steps: int = field(default=1000, metadata={"help": "Steps between AR validation."}) - disable_tqdm: bool = field(default=False, metadata={"help": "Disable tqdm progress bar."}) - remove_unused_columns: bool = field( - default=False, metadata={"help": "Set to False to keep extra args for VLM."} + default=False, metadata={"help": "Whether to estimate AR using training accuracy to log."} ) + ar_validate_steps: int = field(default=1000, metadata={"help": "AR validation ."}) cp_size: int = field(default=1, metadata={"help": "Context parallelism size."}) dp_shard_size: int = field(default=1, metadata={"help": "Data parallelism shard size."}) @@ -124,11 +119,15 @@ class MedusaArguments: medusa_num_layers: int | None = field(default=1) -def _parse_cli() -> tuple[str, str]: - """Parse required --config and --model from argv; ignore all other args.""" +def _parse_cli() -> tuple[str, str | None]: + """Parse --config (required) and --model (optional) from argv; ignore all other args.""" p = argparse.ArgumentParser(add_help=False) p.add_argument("--config", required=True, help="Path to the YAML config file.") - p.add_argument("--model", required=True, help="Model name or path (overrides yaml).") + p.add_argument( + "--model", + default=None, + help="Model name or path (overrides yaml model.model_name_or_path).", + ) args, _ = p.parse_known_args() return args.config, args.model @@ -136,12 +135,21 @@ def _parse_cli() -> tuple[str, str]: def _load_config(config_path: str) -> tuple[dict, dict]: """Load training config from a YAML file with sections: model, data, training, eagle. + Supports ``__base__: relative/path.yaml`` for config inheritance via OmegaConf merge. 
+ Returns: hf_cfg: Flat dict from model/data/training sections, for HfArgumentParser.parse_dict() eagle_cfg: Eagle section dict (EagleConfig fields), passed directly to mtsp.convert() """ - with open(config_path) as f: - cfg = yaml.safe_load(f) + override = OmegaConf.load(config_path) + base_path = OmegaConf.select(override, "__base__", default=None) + if base_path is not None: + base = OmegaConf.load(Path(config_path).parent / base_path) + override = OmegaConf.masked_copy(override, [k for k in override if k != "__base__"]) + merged = OmegaConf.merge(base, override) + else: + merged = override + cfg = OmegaConf.to_container(merged, resolve=True) # Eagle section maps directly to EagleConfig fields — no field enumeration needed. # eagle_architecture_config is a nested dict and is included as-is. @@ -164,7 +172,8 @@ def _load_config(config_path: str) -> tuple[dict, dict]: def train(): config_path, model_path = _parse_cli() hf_cfg, eagle_cfg = _load_config(config_path) - hf_cfg["model_name_or_path"] = model_path + if model_path is not None: + hf_cfg["model_name_or_path"] = model_path parser = transformers.HfArgumentParser( ( From dcda29273d5aceceff5bba049a1068b7b84cb4cd Mon Sep 17 00:00:00 2001 From: h-guo18 <67671475+h-guo18@users.noreply.github.com> Date: Mon, 30 Mar 2026 04:18:08 +0000 Subject: [PATCH 03/21] polish Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com> --- examples/speculative_decoding/launch_train.sh | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/examples/speculative_decoding/launch_train.sh b/examples/speculative_decoding/launch_train.sh index 6d248784d2..c080d13ff2 100755 --- a/examples/speculative_decoding/launch_train.sh +++ b/examples/speculative_decoding/launch_train.sh @@ -18,6 +18,7 @@ # Single GPU: ./launch_train.sh --config configs/my_experiment.yaml # Multi-node: ./launch_train.sh --config configs/my_experiment.yaml --num_nodes 2 --head_node_ip # +# --model overrides model.model_name_or_path in 
the YAML; if omitted, the YAML value is used. # All training config (model, data, hyperparams, eagle, fsdp) lives in the YAML file. # Only multi-node routing args are passed here; mixed_precision is fixed to bf16. @@ -40,8 +41,8 @@ while [ $# -gt 0 ]; do shift done -if [ -z "$CONFIG_FILE" ] || [ -z "$MODEL" ]; then - >&2 echo "Usage: ./launch_train.sh --config --model [--num_nodes N] [--head_node_ip IP]" +if [ -z "$CONFIG_FILE" ]; then + >&2 echo "Usage: ./launch_train.sh --config [--model ] [--num_nodes N] [--head_node_ip IP]" exit 1 fi @@ -66,9 +67,14 @@ if [[ "$NUM_NODES" != "1" ]]; then --main_process_port 29500" fi +MODEL_ARG="" +if [ -n "$MODEL" ]; then + MODEL_ARG="--model $MODEL" +fi + export TOKENIZERS_PARALLELISM=False set -x start_time=$(date +%s) -sh -c "accelerate launch --mixed_precision bf16 $MULTI_NODE_ARGS ${SCRIPT_DIR}/main.py --config $CONFIG_FILE --model $MODEL" +sh -c "accelerate launch --mixed_precision bf16 $MULTI_NODE_ARGS ${SCRIPT_DIR}/main.py --config $CONFIG_FILE $MODEL_ARG" echo "Total time: $(( $(date +%s) - $start_time )) seconds" From 1875722a89956a102aa8f4eab4bc05fe2f92093b Mon Sep 17 00:00:00 2001 From: h-guo18 <67671475+h-guo18@users.noreply.github.com> Date: Mon, 30 Mar 2026 04:38:29 +0000 Subject: [PATCH 04/21] add file Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com> --- .../configs/_base_eagle3.yaml | 46 +++++++++++++++++++ 1 file changed, 46 insertions(+) create mode 100644 examples/speculative_decoding/configs/_base_eagle3.yaml diff --git a/examples/speculative_decoding/configs/_base_eagle3.yaml b/examples/speculative_decoding/configs/_base_eagle3.yaml new file mode 100644 index 0000000000..de5efc8f53 --- /dev/null +++ b/examples/speculative_decoding/configs/_base_eagle3.yaml @@ -0,0 +1,46 @@ +# Base config for EAGLE3 training. Override per-model fields in a child YAML via: +# __base__: _base_eagle3.yaml + +# training: maps to TrainingArguments (main.py) — HF TrainingArguments + a few extra fields. 
+training: + # --- commonly modified --- + mode: eagle3 + output_dir: + num_train_epochs: 1 + per_device_train_batch_size: 1 + learning_rate: 1.0e-4 + warmup_steps: 1000 + training_seq_len: 2048 + logging_steps: 100 + save_steps: 8192 + cp_size: 1 + dp_shard_size: 0 # 0 = auto: total_gpu / cp_size + disable_tqdm: false + estimate_ar: false + ar_validate_steps: -1 + + # --- rarely modified --- + do_eval: false + lr_scheduler_type: linear + save_strategy: steps + weight_decay: 0.0 + dataloader_drop_last: true + bf16: true + tf32: true + remove_unused_columns: false + +# eagle: maps to EagleConfig (modelopt/torch/speculative/config.py). +eagle: + # eagle_offline is derived from data.offline_data_path; do not set here. + eagle_decoder_type: llama + eagle_ttt_steps: 3 + eagle_mix_hidden_states: false + eagle_use_torch_compile: true + eagle_self_logit_distillation: true + eagle_freeze_base_model: true + eagle_loss_decay_factor: 0.9 + eagle_hidden_state_distillation: false + eagle_reuse_base_decoder: false + eagle_report_acc: true + eagle_enable_nvtx: false + eagle_architecture_config: {} From 530256c5dbcd0fcabf0aaf67ce20ebc25b6ac741 Mon Sep 17 00:00:00 2001 From: h-guo18 <67671475+h-guo18@users.noreply.github.com> Date: Mon, 30 Mar 2026 05:04:45 +0000 Subject: [PATCH 05/21] move yaml files to recipe lib Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com> --- examples/speculative_decoding/launch_train.sh | 4 ++-- .../speculative_decoding}/_base_eagle3.yaml | 0 .../speculative_decoding}/kimi_k25_offline.yaml | 0 .../speculative_decoding}/llama32_1b_online.yaml | 0 4 files changed, 2 insertions(+), 2 deletions(-) rename {examples/speculative_decoding/configs => modelopt_recipes/speculative_decoding}/_base_eagle3.yaml (100%) rename {examples/speculative_decoding/configs => modelopt_recipes/speculative_decoding}/kimi_k25_offline.yaml (100%) rename {examples/speculative_decoding/configs => modelopt_recipes/speculative_decoding}/llama32_1b_online.yaml (100%) diff 
--git a/examples/speculative_decoding/launch_train.sh b/examples/speculative_decoding/launch_train.sh index c080d13ff2..a633783aa6 100755 --- a/examples/speculative_decoding/launch_train.sh +++ b/examples/speculative_decoding/launch_train.sh @@ -15,8 +15,8 @@ # limitations under the License. # Usage: -# Single GPU: ./launch_train.sh --config configs/my_experiment.yaml -# Multi-node: ./launch_train.sh --config configs/my_experiment.yaml --num_nodes 2 --head_node_ip +# Single GPU: ./launch_train.sh --config ../../modelopt_recipes/speculative_decoding/my_experiment.yaml +# Multi-node: ./launch_train.sh --config ../../modelopt_recipes/speculative_decoding/my_experiment.yaml --num_nodes 2 --head_node_ip # # --model overrides model.model_name_or_path in the YAML; if omitted, the YAML value is used. # All training config (model, data, hyperparams, eagle, fsdp) lives in the YAML file. diff --git a/examples/speculative_decoding/configs/_base_eagle3.yaml b/modelopt_recipes/speculative_decoding/_base_eagle3.yaml similarity index 100% rename from examples/speculative_decoding/configs/_base_eagle3.yaml rename to modelopt_recipes/speculative_decoding/_base_eagle3.yaml diff --git a/examples/speculative_decoding/configs/kimi_k25_offline.yaml b/modelopt_recipes/speculative_decoding/kimi_k25_offline.yaml similarity index 100% rename from examples/speculative_decoding/configs/kimi_k25_offline.yaml rename to modelopt_recipes/speculative_decoding/kimi_k25_offline.yaml diff --git a/examples/speculative_decoding/configs/llama32_1b_online.yaml b/modelopt_recipes/speculative_decoding/llama32_1b_online.yaml similarity index 100% rename from examples/speculative_decoding/configs/llama32_1b_online.yaml rename to modelopt_recipes/speculative_decoding/llama32_1b_online.yaml From c57de6267ac9eb3e12b7860a7237182c84e8dbf3 Mon Sep 17 00:00:00 2001 From: h-guo18 <67671475+h-guo18@users.noreply.github.com> Date: Mon, 30 Mar 2026 05:43:45 +0000 Subject: [PATCH 06/21] update readme Signed-off-by: 
h-guo18 <67671475+h-guo18@users.noreply.github.com> --- examples/speculative_decoding/README.md | 88 +++++-------------- .../llama32_1b_online.yaml | 13 --- 2 files changed, 23 insertions(+), 78 deletions(-) delete mode 100644 modelopt_recipes/speculative_decoding/llama32_1b_online.yaml diff --git a/examples/speculative_decoding/README.md b/examples/speculative_decoding/README.md index 2a29f644e6..3cc043268c 100644 --- a/examples/speculative_decoding/README.md +++ b/examples/speculative_decoding/README.md @@ -73,14 +73,12 @@ This one-line command runs a minimal example workflow of training and exporting For small base models that fit in GPU memory, we can collocate them with draft models and train with the following command: ```bash -./launch_train.sh --model $BASE_MODEL \ - --output_dir $OUTPUT_DIR \ - --data input_conversations/train.jsonl \ - --num_epochs $NUM_EPOCH \ - --eagle_config eagle_config.json +./launch_train.sh --config ../../modelopt_recipes/speculative_decoding/llama3_eagle_online.yaml ``` -FSDP2 is used by default. To enable context parallelism for long-context training, specify `--cp_size n`. +All training settings (model path, data path, hyperparameters, eagle config) are specified in the YAML file. The `--model` flag can optionally override `model.model_name_or_path` in the YAML. + +To enable context parallelism for long-context training, set `cp_size` in the YAML. The saved modelopt checkpoint is similar in architecture to HF models. It can be further optimized through **ModelOpt**, e.g., PTQ and QAT. 
## Training Draft Model with Offline Base Model @@ -113,15 +111,10 @@ python collect_hidden_states/compute_hidden_states_hf.py \ ### Train Draft Model with Dumped Hidden States -Once we finish dumping hidden states, launch offline training with an extra `--offline-data` argument: +Once we finish dumping hidden states, set `data.offline_data_path` in the YAML to the hidden states directory and launch: ```bash -./launch_train.sh --model $BASE_MODEL \ - --output_dir $OUTPUT_DIR \ - --data $DATA \ - --num_epochs $NUM_EPOCH \ - --eagle_config eagle_config.json \ - --offline-data $HIDDEN_STATES_DIR +./launch_train.sh --config ../../modelopt_recipes/speculative_decoding/llama3_eagle_offline.yaml ``` ## Model Validation @@ -244,13 +237,13 @@ For large scale data generation, please see [SLURM prepare data](SLURM_prepare_d ### Configuring Draft Model -For EAGLE‑1 and EAGLE‑3 we provide a [default model architecture config](https://github.com/NVIDIA/Model-Optimizer/blob/main/modelopt/torch/speculative/config.py#L37) in ModelOpt. You can override default settings by providing an additional JSON dict. E.g. To use 2-layer eagle with 8192 intermediate size for MLP, set `eagle_config.json` to: +For EAGLE‑1 and EAGLE‑3 we provide a [default model architecture config](https://github.com/NVIDIA/Model-Optimizer/blob/main/modelopt/torch/speculative/config.py#L37) in ModelOpt. You can override default settings via `eagle.eagle_architecture_config` in the YAML. E.g. to use a 2-layer EAGLE head with 8192 intermediate size: -```json -{ - "num_hidden_layers": 2, - "intermediate_size":8192 -} +```yaml +eagle: + eagle_architecture_config: + num_hidden_layers: 2 + intermediate_size: 8192 ``` ### Draft Vocabulary Compression @@ -263,61 +256,26 @@ python scripts/calibrate_draft_vocab.py --model meta-llama/Llama-3.2-1B-Instruct This will produce a `d2t.pt` file in `save_dir`, which is the mapping from draft token to target token. 
During inference, draft tokens can be mapped back to target tokens by `target_token = draft_token + d2t[draft_token]`. -Then, simply set `{"draft_vocab_size":32000}` in `eagle_config.json` and include `--draft_vocab_cache ` when running `./launch_train.sh`. The draft model will use this provided vocab table during training and export. +Then, set `eagle.eagle_architecture_config.draft_vocab_size: 32000` and `data.draft_vocab_cache: /path/to/d2t.pt` in your YAML. The draft model will use this provided vocab table during training and export. ### Interact with `modelopt.torch.speculative` -`main.py` provides an example for converting a HF base model for speculative decoding and training it. It consists of a few simple steps: -First, load the base model and tokenizer from Hugging Face: - -```python -model = transformers.AutoModelForCausalLM.from_pretrained( - "" -) -``` - -Then, load default eagle config and make necessary overwrites: +`main.py` provides a complete example for converting a HF base model for speculative decoding and training it. 
The core steps are loading the base model, converting it with an eagle config dict, and training with HF Trainer: ```python -# Load default config -config = { - "eagle1": EAGLE1_DEFAULT_CFG, - "eagle3": EAGLE3_DEFAULT_CFG, -}[training_args.mode]["config"] - -# overwrite config with custom config -config["eagle_architecture_config"].update({"": ""}) - -# Mandatory: hidden size, vocab size and max position embeddings must match base model -config["eagle_architecture_config"].update( - { - "hidden_size": model.config.hidden_size, - "vocab_size": model.config.vocab_size, - "max_position_embeddings": model.config.max_position_embeddings, - } -) -``` +import modelopt.torch.speculative as mtsp -Then, we convert model to a speculative decoding model: +# Convert base model in-place to an EAGLE speculative decoding model +eagle_cfg = {"eagle_decoder_type": "llama", ...} # fields from EagleConfig +mtsp.convert(model, [("eagle", eagle_cfg)]) -```python -mtsp.convert(model, [("eagle", config)]) +# Train with HF Trainer as usual +trainer = transformers.Trainer(model=model, ...) +trainer.train() +trainer.save_model("") ``` -This will modify the model in-place with eagle training forward, making it compatible with HF trainer: - -```python -# Create a trainer -trainer = transformers.Trainer(model=model, tokenizer=tokenizer, args=training_args, **data_module) -trainer._move_model_to_device(model, trainer.args.device) - -# Enable HF checkpointing so that the saved model will contain the speculative decoding module -mto.enable_huggingface_checkpointing() - -trainer.train(resume_from_checkpoint=checkpoint) -trainer.save_state() -trainer.save_model("") -``` +See `main.py` for the full example including tokenizer setup, dataset loading, and checkpoint handling. 
## Support Matrix diff --git a/modelopt_recipes/speculative_decoding/llama32_1b_online.yaml b/modelopt_recipes/speculative_decoding/llama32_1b_online.yaml deleted file mode 100644 index 62e559c630..0000000000 --- a/modelopt_recipes/speculative_decoding/llama32_1b_online.yaml +++ /dev/null @@ -1,13 +0,0 @@ -__base__: _base_eagle3.yaml - - -model: - model_name_or_path: meta-llama/Llama-3.2-1B - trust_remote_code: false - -data: - data_path: input_conversations/daring-anteater.jsonl - -training: - output_dir: ckpts/llama-3.2-1b-online - training_seq_len: 512 From 4c2451cf9c69b59c29ad693e7a941e47ae980e2a Mon Sep 17 00:00:00 2001 From: h-guo18 <67671475+h-guo18@users.noreply.github.com> Date: Mon, 30 Mar 2026 05:44:39 +0000 Subject: [PATCH 07/21] add files Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com> --- .../speculative_decoding/llama3_eagle_offline.yaml | 13 +++++++++++++ .../speculative_decoding/llama3_eagle_online.yaml | 13 +++++++++++++ 2 files changed, 26 insertions(+) create mode 100644 modelopt_recipes/speculative_decoding/llama3_eagle_offline.yaml create mode 100644 modelopt_recipes/speculative_decoding/llama3_eagle_online.yaml diff --git a/modelopt_recipes/speculative_decoding/llama3_eagle_offline.yaml b/modelopt_recipes/speculative_decoding/llama3_eagle_offline.yaml new file mode 100644 index 0000000000..bc0f9438ba --- /dev/null +++ b/modelopt_recipes/speculative_decoding/llama3_eagle_offline.yaml @@ -0,0 +1,13 @@ +__base__: _base_eagle3.yaml + + +model: + model_name_or_path: meta-llama/Llama-3.2-1B + trust_remote_code: true + +data: + offline_data_path: + +training: + output_dir: ckpts/llama-3.2-1b-offline + training_seq_len: 4096 diff --git a/modelopt_recipes/speculative_decoding/llama3_eagle_online.yaml b/modelopt_recipes/speculative_decoding/llama3_eagle_online.yaml new file mode 100644 index 0000000000..cc32904cb0 --- /dev/null +++ b/modelopt_recipes/speculative_decoding/llama3_eagle_online.yaml @@ -0,0 +1,13 @@ +__base__: 
_base_eagle3.yaml + + +model: + model_name_or_path: meta-llama/Llama-3.2-1B + trust_remote_code: true + +data: + data_path: input_conversations/train.jsonl + +training: + output_dir: ckpts/llama-3.2-1b-online + training_seq_len: 4096 From b6b9917c9a1f4030fe8f4307697b2e389429c85f Mon Sep 17 00:00:00 2001 From: h-guo18 <67671475+h-guo18@users.noreply.github.com> Date: Mon, 30 Mar 2026 05:46:22 +0000 Subject: [PATCH 08/21] rename Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com> --- .../{kimi_k25_offline.yaml => kimi_k25_eagle_offline.yaml} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename modelopt_recipes/speculative_decoding/{kimi_k25_offline.yaml => kimi_k25_eagle_offline.yaml} (100%) diff --git a/modelopt_recipes/speculative_decoding/kimi_k25_offline.yaml b/modelopt_recipes/speculative_decoding/kimi_k25_eagle_offline.yaml similarity index 100% rename from modelopt_recipes/speculative_decoding/kimi_k25_offline.yaml rename to modelopt_recipes/speculative_decoding/kimi_k25_eagle_offline.yaml From 01d056977c84631a37b33110a8df1f1c389f5b37 Mon Sep 17 00:00:00 2001 From: h-guo18 <67671475+h-guo18@users.noreply.github.com> Date: Mon, 30 Mar 2026 06:31:40 +0000 Subject: [PATCH 09/21] update test Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com> --- .../speculative_decoding/test_eagle.py | 175 +++++++++++------- 1 file changed, 103 insertions(+), 72 deletions(-) diff --git a/tests/examples/speculative_decoding/test_eagle.py b/tests/examples/speculative_decoding/test_eagle.py index 271241bcb0..4e3e7a62ff 100644 --- a/tests/examples/speculative_decoding/test_eagle.py +++ b/tests/examples/speculative_decoding/test_eagle.py @@ -13,16 +13,18 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import json import os from pathlib import Path import pytest import safetensors.torch import torch -from _test_utils.examples.run_command import run_example_command +import yaml +from _test_utils.examples.run_command import MODELOPT_ROOT, run_example_command from packaging.version import Version +_BASE_EAGLE_YAML = str(MODELOPT_ROOT / "modelopt_recipes/speculative_decoding/_base_eagle3.yaml") + from modelopt.torch.export.plugins.hf_spec_export import LLAMA_EAGLE_SINGLE_LAYER @@ -64,6 +66,14 @@ def generate_offline_pt_data( return output_dir +def _write_eagle_yaml(path: Path, cfg: dict) -> Path: + """Write a YAML training config to *path* and return it.""" + path = Path(path) + with open(path, "w") as f: + yaml.safe_dump(cfg, f, default_flow_style=False) + return path + + @pytest.fixture(scope="module") def eagle_output_dir(tmp_path_factory): """Eagle output directory shared in this module.""" @@ -100,7 +110,7 @@ def test_calibrate_draft_vocab(tiny_llama_path, tiny_daring_anteater_path, draft # fmt: off -@pytest.mark.parametrize(("cp_size", "mix_hidden_states"), [(1, "false"), (2, "false"), (1, "true"), (2, "true")]) +@pytest.mark.parametrize(("cp_size", "mix_hidden_states"), [(1, False), (2, False), (1, True), (2, True)]) def test_llama_eagle3(tiny_llama_path, tiny_daring_anteater_path, tmp_path, eagle_output_dir, @@ -112,8 +122,8 @@ def test_llama_eagle3(tiny_llama_path, pytest.skip("cp_size=2 requires at least 2 GPUs, but only {} found.".format(available_gpus)) if cp_size == 2 and not Version(torch.__version__) >= Version("2.10.0"): pytest.skip("cp_size=2 requires torch 2.10.0") - # Create an ultra-tiny EAGLE config for testing to reduce memory usage - tiny_eagle_config = { + + tiny_eagle_arch_config = { "max_position_embeddings": 128, "num_hidden_layers": 1, "intermediate_size": 64, @@ -121,43 +131,50 @@ def test_llama_eagle3(tiny_llama_path, "num_key_value_heads": 2, "head_dim": 64, } - - # Write the tiny config to a temporary file - config_file = tmp_path / 
f"tiny_eagle_config_cp{cp_size}.json" - with open(config_file, "w") as f: - json.dump(tiny_eagle_config, f) + cfg = { + "__base__": _BASE_EAGLE_YAML, + "model": {"model_name_or_path": str(tiny_llama_path)}, + "data": {"data_path": str(tiny_daring_anteater_path)}, + "training": { + "output_dir": str(eagle_output_dir / f"eagle-tinyllama-cp{cp_size}"), + "num_train_epochs": 0.25, + "learning_rate": 1e-5, + "training_seq_len": 128, + "cp_size": cp_size, + "per_device_train_batch_size": 1, + }, + "eagle": { + "eagle_mix_hidden_states": mix_hidden_states, + "eagle_architecture_config": tiny_eagle_arch_config, + }, + } + yaml_file = _write_eagle_yaml(tmp_path / f"cfg_cp{cp_size}.yaml", cfg) run_example_command( - [ - "./launch_train.sh", - "--model", tiny_llama_path, - "--data", tiny_daring_anteater_path, - "--num_epochs", "0.25", - "--lr", "1e-5", - "--mode", "eagle3", - "--eagle_config", str(config_file), - "--output_dir", eagle_output_dir / f"eagle-tinyllama-cp{cp_size}", - "--training_seq_len", "128", # Match max_position_embeddings - "--cp_size", str(cp_size), - "--mix_hidden_states", mix_hidden_states, - ], + ["./launch_train.sh", "--config", str(yaml_file)], "speculative_decoding", ) -def test_resume_training(tiny_daring_anteater_path, eagle_output_dir): +def test_resume_training(tiny_daring_anteater_path, eagle_output_dir, tmp_path): """Test resume training of Eagle3.""" + checkpoint_dir = eagle_output_dir / "eagle-tinyllama-cp1" + cfg = { + "__base__": _BASE_EAGLE_YAML, + "model": {"model_name_or_path": str(checkpoint_dir)}, + "data": {"data_path": str(tiny_daring_anteater_path)}, + "training": { + "output_dir": str(checkpoint_dir), + "num_train_epochs": 0.5, + "learning_rate": 1e-5, + "training_seq_len": 128, + "per_device_train_batch_size": 1, + }, + "eagle": {}, + } + yaml_file = _write_eagle_yaml(tmp_path / "resume_cfg.yaml", cfg) run_example_command( - [ - "./launch_train.sh", - "--model", eagle_output_dir / "eagle-tinyllama-cp1", - "--data", 
tiny_daring_anteater_path, - "--num_epochs", "0.5", - "--lr", "1e-5", - "--mode", "eagle3", - "--output_dir", eagle_output_dir / "eagle-tinyllama-cp1", - "--training_seq_len", "128", # Match max_position_embeddings - ], + ["./launch_train.sh", "--config", str(yaml_file)], "speculative_decoding", ) @@ -239,7 +256,7 @@ def test_offline_eagle3_training( num_aux_layers=min(cfg.num_hidden_layers, 3), ) - tiny_eagle_config = { + tiny_eagle_arch_config = { "max_position_embeddings": 128, "num_hidden_layers": 1, "intermediate_size": 64, @@ -247,27 +264,33 @@ def test_offline_eagle3_training( "num_key_value_heads": 2, "head_dim": 64, } - config_file = tmp_path / "tiny_eagle_config_offline.json" - with open(config_file, "w") as f: - json.dump(tiny_eagle_config, f) - - cmd = [ - "./launch_train.sh", - "--model", model_path, - "--data", tiny_daring_anteater_path, - "--offline-data", offline_data_dir, - "--num_epochs", "0.1", - "--lr", "1e-5", - "--mode", "eagle3", - "--eagle_config", str(config_file), - "--output_dir", output_subdir, - "--training_seq_len", "64", - "--trust_remote_code", "True", - "--fsdp", "False", - ] - if use_fake_base: - cmd += ["--use_fake_base_for_offline", "true"] - run_example_command(cmd, "speculative_decoding") + training_cfg = { + "__base__": _BASE_EAGLE_YAML, + "model": { + "model_name_or_path": str(model_path), + "trust_remote_code": True, + "use_fake_base_for_offline": use_fake_base, + }, + "data": { + "data_path": str(tiny_daring_anteater_path), + "offline_data_path": str(offline_data_dir), + }, + "training": { + "output_dir": str(output_subdir), + "num_train_epochs": 0.1, + "learning_rate": 1e-5, + "training_seq_len": 64, + "per_device_train_batch_size": 1, + }, + "eagle": { + "eagle_architecture_config": tiny_eagle_arch_config, + }, + } + yaml_file = _write_eagle_yaml(tmp_path / f"offline_cfg_{model_id}.yaml", training_cfg) + run_example_command( + ["./launch_train.sh", "--config", str(yaml_file)], + "speculative_decoding", + ) assert 
os.path.exists(output_subdir / "config.json") @@ -277,9 +300,9 @@ def test_offline_resume_training_kimi(tiny_daring_anteater_path, tmp_path, eagle Depends on test_offline_eagle3_training["kimi-k2.5"] having run first. Exercises AutoModelForCausalLM.from_pretrained with model_type='fake_base_model'. """ - import transformers - checkpoint_dir = eagle_output_dir / "eagle-Kimi-K2.5-offline" + + import transformers config = transformers.AutoConfig.from_pretrained(checkpoint_dir, trust_remote_code=True) offline_data_dir = generate_offline_pt_data( @@ -289,20 +312,28 @@ def test_offline_resume_training_kimi(tiny_daring_anteater_path, tmp_path, eagle num_aux_layers=min(config.num_hidden_layers, 3), ) + training_cfg = { + "__base__": _BASE_EAGLE_YAML, + "model": { + "model_name_or_path": str(checkpoint_dir), + "trust_remote_code": True, + "use_fake_base_for_offline": True, + }, + "data": { + "data_path": str(tiny_daring_anteater_path), + "offline_data_path": str(offline_data_dir), + }, + "training": { + "output_dir": str(checkpoint_dir), + "num_train_epochs": 0.2, + "learning_rate": 1e-5, + "training_seq_len": 64, + "per_device_train_batch_size": 1, + }, + "eagle": {}, + } + yaml_file = _write_eagle_yaml(tmp_path / "resume_kimi_cfg.yaml", training_cfg) run_example_command( - [ - "./launch_train.sh", - "--model", checkpoint_dir, - "--data", tiny_daring_anteater_path, - "--offline-data", offline_data_dir, - "--num_epochs", "0.2", - "--lr", "1e-5", - "--mode", "eagle3", - "--output_dir", checkpoint_dir, - "--training_seq_len", "64", - "--trust_remote_code", "True", - "--fsdp", "False", - "--use_fake_base_for_offline", "true", - ], + ["./launch_train.sh", "--config", str(yaml_file)], "speculative_decoding", ) From e52bc6c08642ba101e7c0c38f7eb8a675d33b32b Mon Sep 17 00:00:00 2001 From: h-guo18 <67671475+h-guo18@users.noreply.github.com> Date: Mon, 30 Mar 2026 06:36:41 +0000 Subject: [PATCH 10/21] revert irrelevant Signed-off-by: h-guo18 
<67671475+h-guo18@users.noreply.github.com> --- tools/launcher/common/eagle3/offline_training.sh | 1 + 1 file changed, 1 insertion(+) diff --git a/tools/launcher/common/eagle3/offline_training.sh b/tools/launcher/common/eagle3/offline_training.sh index 630b9e8f70..09384a499b 100644 --- a/tools/launcher/common/eagle3/offline_training.sh +++ b/tools/launcher/common/eagle3/offline_training.sh @@ -27,6 +27,7 @@ export PATH=$PATH:/workspace/.local/bin trap 'error_handler $0 $LINENO' ERR # ERROR HANDLER bash modules/Model-Optimizer/examples/speculative_decoding/launch_train.sh \ + --model ${HF_MODEL_CKPT} \ ${@} python modules/Model-Optimizer/examples/speculative_decoding/scripts/export_hf_checkpoint.py \ From 87561805cd188dab1b4b4c6f11487821e3f3880c Mon Sep 17 00:00:00 2001 From: h-guo18 <67671475+h-guo18@users.noreply.github.com> Date: Thu, 2 Apr 2026 01:14:35 +0000 Subject: [PATCH 11/21] address comments Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com> --- examples/speculative_decoding/main.py | 8 +++++--- .../speculative_decoding/llama3_eagle_offline.yaml | 2 +- .../speculative_decoding/llama3_eagle_online.yaml | 2 +- tests/examples/speculative_decoding/test_eagle.py | 4 ++-- 4 files changed, 9 insertions(+), 7 deletions(-) diff --git a/examples/speculative_decoding/main.py b/examples/speculative_decoding/main.py index 70a56a888c..c782755cb2 100644 --- a/examples/speculative_decoding/main.py +++ b/examples/speculative_decoding/main.py @@ -108,7 +108,7 @@ class TrainingArguments(transformers.TrainingArguments): estimate_ar: bool = field( default=False, metadata={"help": "Whether to estimate AR using training accuracy to log."} ) - ar_validate_steps: int = field(default=1000, metadata={"help": "AR validation ."}) + ar_validate_steps: int = field(default=1000, metadata={"help": "AR validation interval."}) cp_size: int = field(default=1, metadata={"help": "Context parallelism size."}) dp_shard_size: int = field(default=1, metadata={"help": "Data 
parallelism shard size."}) @@ -128,7 +128,9 @@ def _parse_cli() -> tuple[str, str | None]: default=None, help="Model name or path (overrides yaml model.model_name_or_path).", ) - args, _ = p.parse_known_args() + args, unknown = p.parse_known_args() + if unknown: + print_rank_0(f"Unrecognized arguments will be ignored: {unknown}. ") return args.config, args.model @@ -255,7 +257,7 @@ def train(): raise FileNotFoundError( f"Draft vocab cache provided but not found: {data_args.draft_vocab_cache}" ) - model.eagle_module.d2t = torch.load(data_args.draft_vocab_cache) + model.eagle_module.d2t = torch.load(data_args.draft_vocab_cache, weights_only=True) print_rank_0(f"Loaded draft vocab cache from {data_args.draft_vocab_cache}.") else: raise Exception(f"{training_args.mode} is not supported!") diff --git a/modelopt_recipes/speculative_decoding/llama3_eagle_offline.yaml b/modelopt_recipes/speculative_decoding/llama3_eagle_offline.yaml index bc0f9438ba..77b2015347 100644 --- a/modelopt_recipes/speculative_decoding/llama3_eagle_offline.yaml +++ b/modelopt_recipes/speculative_decoding/llama3_eagle_offline.yaml @@ -3,7 +3,7 @@ __base__: _base_eagle3.yaml model: model_name_or_path: meta-llama/Llama-3.2-1B - trust_remote_code: true + trust_remote_code: false data: offline_data_path: diff --git a/modelopt_recipes/speculative_decoding/llama3_eagle_online.yaml b/modelopt_recipes/speculative_decoding/llama3_eagle_online.yaml index cc32904cb0..ffcaf491cc 100644 --- a/modelopt_recipes/speculative_decoding/llama3_eagle_online.yaml +++ b/modelopt_recipes/speculative_decoding/llama3_eagle_online.yaml @@ -3,7 +3,7 @@ __base__: _base_eagle3.yaml model: model_name_or_path: meta-llama/Llama-3.2-1B - trust_remote_code: true + trust_remote_code: false data: data_path: input_conversations/train.jsonl diff --git a/tests/examples/speculative_decoding/test_eagle.py b/tests/examples/speculative_decoding/test_eagle.py index 4e3e7a62ff..b20bae0d1f 100644 --- 
a/tests/examples/speculative_decoding/test_eagle.py +++ b/tests/examples/speculative_decoding/test_eagle.py @@ -136,7 +136,7 @@ def test_llama_eagle3(tiny_llama_path, "model": {"model_name_or_path": str(tiny_llama_path)}, "data": {"data_path": str(tiny_daring_anteater_path)}, "training": { - "output_dir": str(eagle_output_dir / f"eagle-tinyllama-cp{cp_size}"), + "output_dir": str(eagle_output_dir / f"eagle-tinyllama-cp{cp_size}-mix{mix_hidden_states}"), "num_train_epochs": 0.25, "learning_rate": 1e-5, "training_seq_len": 128, @@ -158,7 +158,7 @@ def test_llama_eagle3(tiny_llama_path, def test_resume_training(tiny_daring_anteater_path, eagle_output_dir, tmp_path): """Test resume training of Eagle3.""" - checkpoint_dir = eagle_output_dir / "eagle-tinyllama-cp1" + checkpoint_dir = eagle_output_dir / "eagle-tinyllama-cp1-mixFalse" cfg = { "__base__": _BASE_EAGLE_YAML, "model": {"model_name_or_path": str(checkpoint_dir)}, From 3f126c7aece330d7d25ea4d191e98b4b47ac8b7b Mon Sep 17 00:00:00 2001 From: h-guo18 <67671475+h-guo18@users.noreply.github.com> Date: Thu, 2 Apr 2026 23:01:00 +0000 Subject: [PATCH 12/21] clean up yaml level Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com> --- examples/speculative_decoding/README.md | 12 +++-- examples/speculative_decoding/launch_train.sh | 17 +++---- examples/speculative_decoding/main.py | 50 ++++++++----------- .../train_eagle3_and_export.sh | 1 - .../speculative_decoding/_base_eagle3.yaml | 4 +- .../kimi_k25_eagle_offline.yaml | 17 ------- .../llama3_eagle_offline.yaml | 13 ----- .../llama3_eagle_online.yaml | 13 ----- 8 files changed, 35 insertions(+), 92 deletions(-) delete mode 100644 modelopt_recipes/speculative_decoding/kimi_k25_eagle_offline.yaml delete mode 100644 modelopt_recipes/speculative_decoding/llama3_eagle_offline.yaml delete mode 100644 modelopt_recipes/speculative_decoding/llama3_eagle_online.yaml diff --git a/examples/speculative_decoding/README.md b/examples/speculative_decoding/README.md index 
3cc043268c..71d4b22cdf 100644 --- a/examples/speculative_decoding/README.md +++ b/examples/speculative_decoding/README.md @@ -73,12 +73,12 @@ This one-line command runs a minimal example workflow of training and exporting For small base models that fit in GPU memory, we can collocate them with draft models and train with the following command: ```bash -./launch_train.sh --config ../../modelopt_recipes/speculative_decoding/llama3_eagle_online.yaml +bash scripts/train_llama3_online.sh ``` -All training settings (model path, data path, hyperparameters, eagle config) are specified in the YAML file. The `--model` flag can optionally override `model.model_name_or_path` in the YAML. +All training settings are specified in the base YAML (`modelopt_recipes/speculative_decoding/_base_eagle3.yaml`) and overridden via OmegaConf dotlist arguments in the shell script. See `scripts/train_llama3_online.sh` for the full example. -To enable context parallelism for long-context training, set `cp_size` in the YAML. +To enable context parallelism for long-context training, add `training.cp_size=n` to the dotlist overrides. The saved modelopt checkpoint is similar in architecture to HF models. It can be further optimized through **ModelOpt**, e.g., PTQ and QAT. ## Training Draft Model with Offline Base Model @@ -111,12 +111,14 @@ python collect_hidden_states/compute_hidden_states_hf.py \ ### Train Draft Model with Dumped Hidden States -Once we finish dumping hidden states, set `data.offline_data_path` in the YAML to the hidden states directory and launch: +Once we finish dumping hidden states, launch offline training with the hidden states directory: ```bash -./launch_train.sh --config ../../modelopt_recipes/speculative_decoding/llama3_eagle_offline.yaml +bash scripts/train_llama3_offline.sh ``` +Edit `scripts/train_llama3_offline.sh` to set `OFFLINE_DATA` to your hidden states directory. 
+ ## Model Validation For online training checkpoints, we can run in-framework evaluation on MT-bench: diff --git a/examples/speculative_decoding/launch_train.sh b/examples/speculative_decoding/launch_train.sh index a633783aa6..3e4b46eb88 100755 --- a/examples/speculative_decoding/launch_train.sh +++ b/examples/speculative_decoding/launch_train.sh @@ -17,8 +17,9 @@ # Usage: # Single GPU: ./launch_train.sh --config ../../modelopt_recipes/speculative_decoding/my_experiment.yaml # Multi-node: ./launch_train.sh --config ../../modelopt_recipes/speculative_decoding/my_experiment.yaml --num_nodes 2 --head_node_ip +# With overrides: ./launch_train.sh --config my.yaml model.model_name_or_path=xxx training.output_dir=yyy # -# --model overrides model.model_name_or_path in the YAML; if omitted, the YAML value is used. +# Extra key=value args are forwarded as OmegaConf dotlist overrides to main.py. # All training config (model, data, hyperparams, eagle, fsdp) lives in the YAML file. # Only multi-node routing args are passed here; mixed_precision is fixed to bf16. 
@@ -27,22 +28,21 @@ set -eo pipefail SCRIPT_DIR="$(dirname "$(readlink -f "$0")")" CONFIG_FILE="" -MODEL="" NUM_NODES=1 HEAD_NODE_IP="" +EXTRA_ARGS=() while [ $# -gt 0 ]; do case "$1" in --config*) if [[ "$1" != *=* ]]; then shift; fi; CONFIG_FILE="${1#*=}" ;; - --model*) if [[ "$1" != *=* ]]; then shift; fi; MODEL="${1#*=}" ;; --num_nodes*) if [[ "$1" != *=* ]]; then shift; fi; NUM_NODES="${1#*=}" ;; --head_node_ip*) if [[ "$1" != *=* ]]; then shift; fi; HEAD_NODE_IP="${1#*=}" ;; - *) >&2 echo "Error: Unknown argument '$1'"; exit 1 ;; + *) EXTRA_ARGS+=("$1") ;; esac shift done if [ -z "$CONFIG_FILE" ]; then - >&2 echo "Usage: ./launch_train.sh --config [--model ] [--num_nodes N] [--head_node_ip IP]" + >&2 echo "Usage: ./launch_train.sh --config [--num_nodes N] [--head_node_ip IP] [key=value ...]" exit 1 fi @@ -67,14 +67,9 @@ if [[ "$NUM_NODES" != "1" ]]; then --main_process_port 29500" fi -MODEL_ARG="" -if [ -n "$MODEL" ]; then - MODEL_ARG="--model $MODEL" -fi - export TOKENIZERS_PARALLELISM=False set -x start_time=$(date +%s) -sh -c "accelerate launch --mixed_precision bf16 $MULTI_NODE_ARGS ${SCRIPT_DIR}/main.py --config $CONFIG_FILE $MODEL_ARG" +sh -c "accelerate launch --mixed_precision bf16 $MULTI_NODE_ARGS ${SCRIPT_DIR}/main.py --config $CONFIG_FILE ${EXTRA_ARGS[*]}" echo "Total time: $(( $(date +%s) - $start_time )) seconds" diff --git a/examples/speculative_decoding/main.py b/examples/speculative_decoding/main.py index c782755cb2..694aa3303f 100644 --- a/examples/speculative_decoding/main.py +++ b/examples/speculative_decoding/main.py @@ -32,7 +32,6 @@ import argparse import os from dataclasses import dataclass, field -from pathlib import Path from typing import Literal import torch @@ -110,7 +109,10 @@ class TrainingArguments(transformers.TrainingArguments): ) ar_validate_steps: int = field(default=1000, metadata={"help": "AR validation interval."}) cp_size: int = field(default=1, metadata={"help": "Context parallelism size."}) - dp_shard_size: int = 
field(default=1, metadata={"help": "Data parallelism shard size."}) + dp_shard_size: int | None = field( + default=None, + metadata={"help": "Data parallelism shard size. None = auto (total_gpu / cp_size)."}, + ) @dataclass @@ -119,38 +121,31 @@ class MedusaArguments: medusa_num_layers: int | None = field(default=1) -def _parse_cli() -> tuple[str, str | None]: - """Parse --config (required) and --model (optional) from argv; ignore all other args.""" +def _parse_cli() -> tuple[str, list[str]]: + """Parse --config (required) from argv; return remaining args as config overrides. + + Extra arguments use OmegaConf dotlist syntax, e.g. + ``model.model_name_or_path=meta-llama/Llama-3.2-1B training.output_dir=ckpts/test``. + """ p = argparse.ArgumentParser(add_help=False) p.add_argument("--config", required=True, help="Path to the YAML config file.") - p.add_argument( - "--model", - default=None, - help="Model name or path (overrides yaml model.model_name_or_path).", - ) - args, unknown = p.parse_known_args() - if unknown: - print_rank_0(f"Unrecognized arguments will be ignored: {unknown}. ") - return args.config, args.model + args, overrides = p.parse_known_args() + return args.config, overrides -def _load_config(config_path: str) -> tuple[dict, dict]: +def _load_config(config_path: str, overrides: list[str] = ()) -> tuple[dict, dict]: """Load training config from a YAML file with sections: model, data, training, eagle. - Supports ``__base__: relative/path.yaml`` for config inheritance via OmegaConf merge. + *overrides* are OmegaConf dotlist entries (e.g. ``["model.model_name_or_path=xxx"]``) + applied on top of the YAML. 
Returns: hf_cfg: Flat dict from model/data/training sections, for HfArgumentParser.parse_dict() eagle_cfg: Eagle section dict (EagleConfig fields), passed directly to mtsp.convert() """ - override = OmegaConf.load(config_path) - base_path = OmegaConf.select(override, "__base__", default=None) - if base_path is not None: - base = OmegaConf.load(Path(config_path).parent / base_path) - override = OmegaConf.masked_copy(override, [k for k in override if k != "__base__"]) - merged = OmegaConf.merge(base, override) - else: - merged = override + merged = OmegaConf.load(config_path) + if overrides: + merged = OmegaConf.merge(merged, OmegaConf.from_dotlist(list(overrides))) cfg = OmegaConf.to_container(merged, resolve=True) # Eagle section maps directly to EagleConfig fields — no field enumeration needed. @@ -163,8 +158,7 @@ def _load_config(config_path: str) -> tuple[dict, dict]: **cfg.get("training", {}), } - # dp_shard_size sentinel: 0 means auto-compute as total_gpu / cp_size - if hf_cfg.get("dp_shard_size", 1) == 0: + if hf_cfg.get("dp_shard_size") is None: cp_size = hf_cfg.get("cp_size", 1) hf_cfg["dp_shard_size"] = torch.cuda.device_count() // cp_size @@ -172,10 +166,8 @@ def _load_config(config_path: str) -> tuple[dict, dict]: def train(): - config_path, model_path = _parse_cli() - hf_cfg, eagle_cfg = _load_config(config_path) - if model_path is not None: - hf_cfg["model_name_or_path"] = model_path + config_path, overrides = _parse_cli() + hf_cfg, eagle_cfg = _load_config(config_path, overrides) parser = transformers.HfArgumentParser( ( diff --git a/examples/speculative_decoding/train_eagle3_and_export.sh b/examples/speculative_decoding/train_eagle3_and_export.sh index b4c00af349..6aa25903db 100755 --- a/examples/speculative_decoding/train_eagle3_and_export.sh +++ b/examples/speculative_decoding/train_eagle3_and_export.sh @@ -79,7 +79,6 @@ training: estimate_ar: false mode: eagle3 cp_size: 1 - dp_shard_size: 0 fsdp: "" fsdp_config: fsdp_version: 2 diff --git 
a/modelopt_recipes/speculative_decoding/_base_eagle3.yaml b/modelopt_recipes/speculative_decoding/_base_eagle3.yaml index de5efc8f53..6f038da117 100644 --- a/modelopt_recipes/speculative_decoding/_base_eagle3.yaml +++ b/modelopt_recipes/speculative_decoding/_base_eagle3.yaml @@ -1,5 +1,4 @@ -# Base config for EAGLE3 training. Override per-model fields in a child YAML via: -# __base__: _base_eagle3.yaml +# Base config for EAGLE3 training. Override fields via OmegaConf dotlist on the CLI. # training: maps to TrainingArguments (main.py) — HF TrainingArguments + a few extra fields. training: @@ -14,7 +13,6 @@ training: logging_steps: 100 save_steps: 8192 cp_size: 1 - dp_shard_size: 0 # 0 = auto: total_gpu / cp_size disable_tqdm: false estimate_ar: false ar_validate_steps: -1 diff --git a/modelopt_recipes/speculative_decoding/kimi_k25_eagle_offline.yaml b/modelopt_recipes/speculative_decoding/kimi_k25_eagle_offline.yaml deleted file mode 100644 index 1b1d7214d0..0000000000 --- a/modelopt_recipes/speculative_decoding/kimi_k25_eagle_offline.yaml +++ /dev/null @@ -1,17 +0,0 @@ -__base__: _base_eagle3.yaml - -model: - model_name_or_path: moonshotai/Kimi-K2.5 - use_fake_base_for_offline: true - trust_remote_code: true - -data: - offline_data_path: - -training: - output_dir: ckpts/kimi-k25-eagle3 - num_train_epochs: 1 - training_seq_len: 4096 - -eagle: - eagle_decoder_type: kimik2 diff --git a/modelopt_recipes/speculative_decoding/llama3_eagle_offline.yaml b/modelopt_recipes/speculative_decoding/llama3_eagle_offline.yaml deleted file mode 100644 index 77b2015347..0000000000 --- a/modelopt_recipes/speculative_decoding/llama3_eagle_offline.yaml +++ /dev/null @@ -1,13 +0,0 @@ -__base__: _base_eagle3.yaml - - -model: - model_name_or_path: meta-llama/Llama-3.2-1B - trust_remote_code: false - -data: - offline_data_path: - -training: - output_dir: ckpts/llama-3.2-1b-offline - training_seq_len: 4096 diff --git a/modelopt_recipes/speculative_decoding/llama3_eagle_online.yaml 
b/modelopt_recipes/speculative_decoding/llama3_eagle_online.yaml deleted file mode 100644 index ffcaf491cc..0000000000 --- a/modelopt_recipes/speculative_decoding/llama3_eagle_online.yaml +++ /dev/null @@ -1,13 +0,0 @@ -__base__: _base_eagle3.yaml - - -model: - model_name_or_path: meta-llama/Llama-3.2-1B - trust_remote_code: false - -data: - data_path: input_conversations/train.jsonl - -training: - output_dir: ckpts/llama-3.2-1b-online - training_seq_len: 4096 From 617818286f7d45c5fccd628e00daf1118dadfcb5 Mon Sep 17 00:00:00 2001 From: h-guo18 <67671475+h-guo18@users.noreply.github.com> Date: Thu, 2 Apr 2026 23:05:18 +0000 Subject: [PATCH 13/21] update readme Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com> --- examples/speculative_decoding/README.md | 20 +++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/examples/speculative_decoding/README.md b/examples/speculative_decoding/README.md index 71d4b22cdf..fb423eb169 100644 --- a/examples/speculative_decoding/README.md +++ b/examples/speculative_decoding/README.md @@ -73,12 +73,16 @@ This one-line command runs a minimal example workflow of training and exporting For small base models that fit in GPU memory, we can collocate them with draft models and train with the following command: ```bash -bash scripts/train_llama3_online.sh +./launch_train.sh \ + --config ../../modelopt_recipes/speculative_decoding/_base_eagle3.yaml \ + model.model_name_or_path=meta-llama/Llama-3.2-1B \ + data.data_path=input_conversations/train.jsonl \ + training.output_dir=ckpts/llama-3.2-1b-online ``` -All training settings are specified in the base YAML (`modelopt_recipes/speculative_decoding/_base_eagle3.yaml`) and overridden via OmegaConf dotlist arguments in the shell script. See `scripts/train_llama3_online.sh` for the full example. +All default training settings live in `_base_eagle3.yaml`; override any field via OmegaConf dotlist arguments on the command line. 
-To enable context parallelism for long-context training, add `training.cp_size=` to the dotlist overrides. +To enable context parallelism for long-context training, add `training.cp_size=` to the overrides. The saved modelopt checkpoint is similar in architecture to HF models. It can be further optimized through **ModelOpt**, e.g., PTQ and QAT. ## Training Draft Model with Offline Base Model @@ -111,14 +115,16 @@ python collect_hidden_states/compute_hidden_states_hf.py \ ### Train Draft Model with Dumped Hidden States -Once we finish dumping hidden states, launch offline training with the hidden states directory: +Once we finish dumping hidden states, launch offline training pointing to the hidden states directory: ```bash -bash scripts/train_llama3_offline.sh +./launch_train.sh \ + --config ../../modelopt_recipes/speculative_decoding/_base_eagle3.yaml \ + model.model_name_or_path=meta-llama/Llama-3.2-1B \ + data.offline_data_path=$HIDDEN_STATES_DIR \ + training.output_dir=ckpts/llama-3.2-1b-offline ``` -Edit `scripts/train_llama3_offline.sh` to set `OFFLINE_DATA` to your hidden states directory. 
- ## Model Validation For online training checkpoints, we can run in-framework evaluation on MT-bench: From 059598c87281b41b4f885729733d6854e2dc3d79 Mon Sep 17 00:00:00 2001 From: h-guo18 <67671475+h-guo18@users.noreply.github.com> Date: Thu, 2 Apr 2026 23:09:30 +0000 Subject: [PATCH 14/21] udpate simplified flow Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com> --- .../train_eagle3_and_export.sh | 72 +++---------------- 1 file changed, 10 insertions(+), 62 deletions(-) diff --git a/examples/speculative_decoding/train_eagle3_and_export.sh b/examples/speculative_decoding/train_eagle3_and_export.sh index 6aa25903db..ff97132617 100755 --- a/examples/speculative_decoding/train_eagle3_and_export.sh +++ b/examples/speculative_decoding/train_eagle3_and_export.sh @@ -33,73 +33,21 @@ MODEL_BASENAME=$(basename "$BASE_MODEL") OUTPUT_DIR=ckpts/${MODEL_BASENAME}-$(date +%Y%m%d_%H%M) mkdir -p "$OUTPUT_DIR" -# Build offline_data_path YAML value +BASE_CFG="$(dirname "$(readlink -f "$0")")/../../modelopt_recipes/speculative_decoding/_base_eagle3.yaml" + +# Build dotlist overrides +OVERRIDES=( + model.model_name_or_path="$BASE_MODEL" + training.output_dir="$OUTPUT_DIR" +) if [[ -n "$OFFLINE_DATA_PATH" ]]; then - OFFLINE_DATA_YAML="\"$OFFLINE_DATA_PATH\"" + OVERRIDES+=( data.offline_data_path="$OFFLINE_DATA_PATH" ) else - OFFLINE_DATA_YAML="null" + OVERRIDES+=( data.data_path="$DATA" ) fi -# Write config to output dir so it's preserved alongside the checkpoint -YAML_FILE="$OUTPUT_DIR/train_config.yaml" -cat > "$YAML_FILE" << EOF -model: - use_fake_base_for_offline: false - trust_remote_code: false - -data: - data_path: "$DATA" - offline_data_path: $OFFLINE_DATA_YAML - draft_vocab_cache: null - vlm_processor: null - vlm_img_dir: null - -training: - output_dir: "$OUTPUT_DIR" - num_train_epochs: 2 - per_device_train_batch_size: 1 - per_device_eval_batch_size: 1 - learning_rate: 1.0e-4 - weight_decay: 0.0 - warmup_steps: 100 - lr_scheduler_type: linear - 
gradient_accumulation_steps: 1 - training_seq_len: 2048 - ar_validate_steps: 100000 - logging_steps: 100 - save_steps: 8192 - save_strategy: steps - do_eval: false - eval_accumulation_steps: 1 - dataloader_drop_last: true - bf16: true - tf32: true - remove_unused_columns: false - disable_tqdm: false - estimate_ar: false - mode: eagle3 - cp_size: 1 - fsdp: "" - fsdp_config: - fsdp_version: 2 - -eagle: - eagle_decoder_type: llama - eagle_ttt_steps: 3 - eagle_mix_hidden_states: false - eagle_use_torch_compile: true - eagle_self_logit_distillation: true - eagle_freeze_base_model: true - eagle_loss_decay_factor: 0.9 - eagle_hidden_state_distillation: false - eagle_reuse_base_decoder: false - eagle_report_acc: true - eagle_enable_nvtx: false - eagle_architecture_config: {} -EOF - echo "==== [1/3] Training draft model ====" -./launch_train.sh --config "$YAML_FILE" --model "$BASE_MODEL" +./launch_train.sh --config "$BASE_CFG" "${OVERRIDES[@]}" echo "==== [2/3] Evaluating ModelOpt checkpoint on MT-Bench ====" python scripts/ar_validate.py --model_path $OUTPUT_DIR From 101ab0c81a33305ff053e67fb7b68142678c8457 Mon Sep 17 00:00:00 2001 From: h-guo18 <67671475+h-guo18@users.noreply.github.com> Date: Thu, 2 Apr 2026 23:24:25 +0000 Subject: [PATCH 15/21] polish Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com> --- .../speculative_decoding/_base_eagle3.yaml | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/modelopt_recipes/speculative_decoding/_base_eagle3.yaml b/modelopt_recipes/speculative_decoding/_base_eagle3.yaml index 6f038da117..36f01cca47 100644 --- a/modelopt_recipes/speculative_decoding/_base_eagle3.yaml +++ b/modelopt_recipes/speculative_decoding/_base_eagle3.yaml @@ -1,6 +1,17 @@ # Base config for EAGLE3 training. Override fields via OmegaConf dotlist on the CLI. -# training: maps to TrainingArguments (main.py) — HF TrainingArguments + a few extra fields. 
+# maps to ModelArguments (main.py) +model: + model_name_or_path: meta-llama/Llama-3.2-1B + trust_remote_code: false + use_fake_base_for_offline: false + +# maps to DataArguments (main.py) +data: + data_path: input_conversations/train.jsonl + draft_vocab_cache: + +# maps to TrainingArguments (main.py) training: # --- commonly modified --- mode: eagle3 @@ -27,7 +38,7 @@ training: tf32: true remove_unused_columns: false -# eagle: maps to EagleConfig (modelopt/torch/speculative/config.py). +# maps to EagleConfig (modelopt/torch/speculative/config.py). eagle: # eagle_offline is derived from data.offline_data_path; do not set here. eagle_decoder_type: llama From 9104756b6cffc81ca39eb00b94bc260b9602e059 Mon Sep 17 00:00:00 2001 From: h-guo18 <67671475+h-guo18@users.noreply.github.com> Date: Thu, 2 Apr 2026 23:26:12 +0000 Subject: [PATCH 16/21] polish Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com> --- modelopt_recipes/speculative_decoding/_base_eagle3.yaml | 1 - 1 file changed, 1 deletion(-) diff --git a/modelopt_recipes/speculative_decoding/_base_eagle3.yaml b/modelopt_recipes/speculative_decoding/_base_eagle3.yaml index 36f01cca47..888f03c482 100644 --- a/modelopt_recipes/speculative_decoding/_base_eagle3.yaml +++ b/modelopt_recipes/speculative_decoding/_base_eagle3.yaml @@ -2,7 +2,6 @@ # maps to ModelArguments (main.py) model: - model_name_or_path: meta-llama/Llama-3.2-1B trust_remote_code: false use_fake_base_for_offline: false From c0330da082cb8feba82ef83424bd648739b49d39 Mon Sep 17 00:00:00 2001 From: h-guo18 <67671475+h-guo18@users.noreply.github.com> Date: Thu, 2 Apr 2026 23:33:15 +0000 Subject: [PATCH 17/21] update tests Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com> --- modelopt_recipes/speculative_decoding/_base_eagle3.yaml | 1 + tests/examples/speculative_decoding/test_eagle.py | 8 +------- 2 files changed, 2 insertions(+), 7 deletions(-) diff --git a/modelopt_recipes/speculative_decoding/_base_eagle3.yaml 
b/modelopt_recipes/speculative_decoding/_base_eagle3.yaml index 888f03c482..0d2c8066e2 100644 --- a/modelopt_recipes/speculative_decoding/_base_eagle3.yaml +++ b/modelopt_recipes/speculative_decoding/_base_eagle3.yaml @@ -51,4 +51,5 @@ eagle: eagle_reuse_base_decoder: false eagle_report_acc: true eagle_enable_nvtx: false + # entries here override the defaults in modelopt/torch/speculative/eagle/default_config.py eagle_architecture_config: {} diff --git a/tests/examples/speculative_decoding/test_eagle.py b/tests/examples/speculative_decoding/test_eagle.py index b20bae0d1f..426f6a05e2 100644 --- a/tests/examples/speculative_decoding/test_eagle.py +++ b/tests/examples/speculative_decoding/test_eagle.py @@ -20,11 +20,9 @@ import safetensors.torch import torch import yaml -from _test_utils.examples.run_command import MODELOPT_ROOT, run_example_command +from _test_utils.examples.run_command import run_example_command from packaging.version import Version -_BASE_EAGLE_YAML = str(MODELOPT_ROOT / "modelopt_recipes/speculative_decoding/_base_eagle3.yaml") - from modelopt.torch.export.plugins.hf_spec_export import LLAMA_EAGLE_SINGLE_LAYER @@ -132,7 +130,6 @@ def test_llama_eagle3(tiny_llama_path, "head_dim": 64, } cfg = { - "__base__": _BASE_EAGLE_YAML, "model": {"model_name_or_path": str(tiny_llama_path)}, "data": {"data_path": str(tiny_daring_anteater_path)}, "training": { @@ -160,7 +157,6 @@ def test_resume_training(tiny_daring_anteater_path, eagle_output_dir, tmp_path): """Test resume training of Eagle3.""" checkpoint_dir = eagle_output_dir / "eagle-tinyllama-cp1-mixFalse" cfg = { - "__base__": _BASE_EAGLE_YAML, "model": {"model_name_or_path": str(checkpoint_dir)}, "data": {"data_path": str(tiny_daring_anteater_path)}, "training": { @@ -265,7 +261,6 @@ def test_offline_eagle3_training( "head_dim": 64, } training_cfg = { - "__base__": _BASE_EAGLE_YAML, "model": { "model_name_or_path": str(model_path), "trust_remote_code": True, @@ -313,7 +308,6 @@ def 
test_offline_resume_training_kimi(tiny_daring_anteater_path, tmp_path, eagle ) training_cfg = { - "__base__": _BASE_EAGLE_YAML, "model": { "model_name_or_path": str(checkpoint_dir), "trust_remote_code": True, From 6e4cc176554844df805f1b9cb79bad02ff3609ea Mon Sep 17 00:00:00 2001 From: h-guo18 <67671475+h-guo18@users.noreply.github.com> Date: Fri, 3 Apr 2026 00:01:55 +0000 Subject: [PATCH 18/21] update launcher, move yaml Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com> --- examples/speculative_decoding/README.md | 6 +++--- examples/speculative_decoding/launch_train.sh | 4 ++-- .../train_eagle3_and_export.sh | 2 +- .../speculative_decoding/eagle3.yaml} | 0 .../common/eagle3/offline_training.sh | 1 - .../Qwen/Qwen3-8B/hf_offline_eagle3.yaml | 21 +++++++------------ 6 files changed, 13 insertions(+), 21 deletions(-) rename modelopt_recipes/{speculative_decoding/_base_eagle3.yaml => general/speculative_decoding/eagle3.yaml} (100%) diff --git a/examples/speculative_decoding/README.md b/examples/speculative_decoding/README.md index fb423eb169..b87bb1f768 100644 --- a/examples/speculative_decoding/README.md +++ b/examples/speculative_decoding/README.md @@ -74,13 +74,13 @@ For small base models that fit in GPU memory, we can collocate them with draft m ```bash ./launch_train.sh \ - --config ../../modelopt_recipes/speculative_decoding/_base_eagle3.yaml \ + --config ../../modelopt_recipes/general/speculative_decoding/eagle3.yaml \ model.model_name_or_path=meta-llama/Llama-3.2-1B \ data.data_path=input_conversations/train.jsonl \ training.output_dir=ckpts/llama-3.2-1b-online ``` -All default training settings live in `_base_eagle3.yaml`; override any field via OmegaConf dotlist arguments on the command line. +All default training settings live in `eagle3.yaml`; override any field via OmegaConf dotlist arguments on the command line. To enable context parallelism for long-context training, add `training.cp_size=` to the overrides. 
The saved modelopt checkpoint is similar in architecture to HF models. It can be further optimized through **ModelOpt**, e.g., PTQ and QAT. @@ -119,7 +119,7 @@ Once we finish dumping hidden states, launch offline training pointing to the hi ```bash ./launch_train.sh \ - --config ../../modelopt_recipes/speculative_decoding/_base_eagle3.yaml \ + --config ../../modelopt_recipes/general/speculative_decoding/eagle3.yaml \ model.model_name_or_path=meta-llama/Llama-3.2-1B \ data.offline_data_path=$HIDDEN_STATES_DIR \ training.output_dir=ckpts/llama-3.2-1b-offline diff --git a/examples/speculative_decoding/launch_train.sh b/examples/speculative_decoding/launch_train.sh index 3e4b46eb88..41d71d1417 100755 --- a/examples/speculative_decoding/launch_train.sh +++ b/examples/speculative_decoding/launch_train.sh @@ -15,8 +15,8 @@ # limitations under the License. # Usage: -# Single GPU: ./launch_train.sh --config ../../modelopt_recipes/speculative_decoding/my_experiment.yaml -# Multi-node: ./launch_train.sh --config ../../modelopt_recipes/speculative_decoding/my_experiment.yaml --num_nodes 2 --head_node_ip +# Single GPU: ./launch_train.sh --config ../../modelopt_recipes/general/speculative_decoding/eagle3.yaml model.model_name_or_path=xxx +# Multi-node: ./launch_train.sh --config ../../modelopt_recipes/general/speculative_decoding/eagle3.yaml --num_nodes 2 --head_node_ip # With overrides: ./launch_train.sh --config my.yaml model.model_name_or_path=xxx training.output_dir=yyy # # Extra key=value args are forwarded as OmegaConf dotlist overrides to main.py. 
diff --git a/examples/speculative_decoding/train_eagle3_and_export.sh b/examples/speculative_decoding/train_eagle3_and_export.sh index ff97132617..92ccb6d513 100755 --- a/examples/speculative_decoding/train_eagle3_and_export.sh +++ b/examples/speculative_decoding/train_eagle3_and_export.sh @@ -33,7 +33,7 @@ MODEL_BASENAME=$(basename "$BASE_MODEL") OUTPUT_DIR=ckpts/${MODEL_BASENAME}-$(date +%Y%m%d_%H%M) mkdir -p "$OUTPUT_DIR" -BASE_CFG="$(dirname "$(readlink -f "$0")")/../../modelopt_recipes/speculative_decoding/_base_eagle3.yaml" +BASE_CFG="$(dirname "$(readlink -f "$0")")/../../modelopt_recipes/general/speculative_decoding/eagle3.yaml" # Build dotlist overrides OVERRIDES=( diff --git a/modelopt_recipes/speculative_decoding/_base_eagle3.yaml b/modelopt_recipes/general/speculative_decoding/eagle3.yaml similarity index 100% rename from modelopt_recipes/speculative_decoding/_base_eagle3.yaml rename to modelopt_recipes/general/speculative_decoding/eagle3.yaml diff --git a/tools/launcher/common/eagle3/offline_training.sh b/tools/launcher/common/eagle3/offline_training.sh index 09384a499b..630b9e8f70 100644 --- a/tools/launcher/common/eagle3/offline_training.sh +++ b/tools/launcher/common/eagle3/offline_training.sh @@ -27,7 +27,6 @@ export PATH=$PATH:/workspace/.local/bin trap 'error_handler $0 $LINENO' ERR # ERROR HANDLER bash modules/Model-Optimizer/examples/speculative_decoding/launch_train.sh \ - --model ${HF_MODEL_CKPT} \ ${@} python modules/Model-Optimizer/examples/speculative_decoding/scripts/export_hf_checkpoint.py \ diff --git a/tools/launcher/examples/Qwen/Qwen3-8B/hf_offline_eagle3.yaml b/tools/launcher/examples/Qwen/Qwen3-8B/hf_offline_eagle3.yaml index 934ab2928e..071cfd03a1 100644 --- a/tools/launcher/examples/Qwen/Qwen3-8B/hf_offline_eagle3.yaml +++ b/tools/launcher/examples/Qwen/Qwen3-8B/hf_offline_eagle3.yaml @@ -67,20 +67,13 @@ pipeline: task_2: script: common/eagle3/offline_training.sh args: - - --offline-data /scratchspace/offline_hidden_states - - 
--data_path None - - --mode eagle3 - - --num_epochs 1 - - --lr 3e-4 - - --save_steps 500000 - - --output_dir /scratchspace/eagle3 - - --train_bs 8 - - --training_seq_len 4096 - - --eagle_config modules/Model-Optimizer/examples/speculative_decoding/eagle_config.json - - --disable_tqdm True - - --ar_validate_steps 500000 - environment: - - HF_MODEL_CKPT: <> + - --config modules/Model-Optimizer/modelopt_recipes/general/speculative_decoding/eagle3.yaml + - model.model_name_or_path=<> + - data.offline_data_path=/scratchspace/offline_hidden_states + - training.output_dir=/scratchspace/eagle3 + - training.training_seq_len=4096 + - training.disable_tqdm=true + - training.ar_validate_steps=500000 slurm_config: _factory_: "slurm_factory" nodes: 1 From 9b8cafdec1ed58616fc6cacb90219c2c77a702aa Mon Sep 17 00:00:00 2001 From: h-guo18 <67671475+h-guo18@users.noreply.github.com> Date: Sat, 4 Apr 2026 03:03:01 +0000 Subject: [PATCH 19/21] add online eagle example in launcher Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com> --- examples/speculative_decoding/eagle_utils.py | 2 +- tools/launcher/core.py | 2 +- .../Qwen/Qwen3-8B/hf_online_eagle3.yaml | 77 +++++++++++++++++++ tools/launcher/launch.py | 3 +- 4 files changed, 81 insertions(+), 3 deletions(-) create mode 100644 tools/launcher/examples/Qwen/Qwen3-8B/hf_online_eagle3.yaml diff --git a/examples/speculative_decoding/eagle_utils.py b/examples/speculative_decoding/eagle_utils.py index 1bc7df9810..7e40e29b07 100644 --- a/examples/speculative_decoding/eagle_utils.py +++ b/examples/speculative_decoding/eagle_utils.py @@ -194,7 +194,7 @@ class EagleTrainingPlot(TrainerCallback): def __init__(self, ar_validate_steps: int = 1000, estimate_ar: bool = False): self.ar_validate_steps = ar_validate_steps - if wandb and is_master(): + if hasattr(wandb, "init") and is_master(): wandb.init() self.estimate_ar = estimate_ar diff --git a/tools/launcher/core.py b/tools/launcher/core.py index 40e6c94419..3a0aa8ae97 100644 --- 
a/tools/launcher/core.py +++ b/tools/launcher/core.py @@ -295,7 +295,7 @@ def build_docker_executor( container_mounts += [f"{hf_local}:/hf-local"] scratch_dst = "/scratchspace" - scratch_src = os.path.join(job_dir, experiment_title, experiment_id, task_name) + scratch_src = os.path.join(job_dir, experiment_title, experiment_id) os.makedirs(scratch_src, exist_ok=True) modelopt_dst = slurm_config.modelopt_install_path if modelopt_src_path is None: diff --git a/tools/launcher/examples/Qwen/Qwen3-8B/hf_online_eagle3.yaml b/tools/launcher/examples/Qwen/Qwen3-8B/hf_online_eagle3.yaml new file mode 100644 index 0000000000..024916c789 --- /dev/null +++ b/tools/launcher/examples/Qwen/Qwen3-8B/hf_online_eagle3.yaml @@ -0,0 +1,77 @@ +# EAGLE3 online speculative decoding pipeline for Qwen3-8B. +# +# 3-step pipeline: +# task_0: Data synthesis — prepare the input conversations (make_dataset.sh) +# task_1: Online training — train the EAGLE3 draft head on those conversations +# task_2: Benchmark — evaluate speculative decoding speedup via VLLM +# (no hidden-state dump step: online training reads data.data_path directly) +# +# All tasks share /scratchspace to pass artifacts between steps. 
+# +# Usage: +# uv run launch.py --yaml examples/Qwen/Qwen3-8B/hf_online_eagle3.yaml --yes +# uv run slurm.py --yaml modules/Model-Optimizer/tools/launcher/examples/Qwen/Qwen3-8B/hf_online_eagle3.yaml --yes + +job_name: Qwen3-8B_EAGLE3_online +pipeline: + allow_to_fail: false + skip: false + note: + + global_vars: + hf_model: /hf-local/Qwen/Qwen3-8B + + task_0: + script: common/eagle3/make_dataset.sh + args: + - -f modules/Model-Optimizer/examples/speculative_decoding/prepare_input_conversations/example_data_config.yaml + - --full-conversations + slurm_config: + _factory_: "slurm_factory" + nodes: 1 + ntasks_per_node: 1 + gpus_per_node: 1 + container: nvcr.io/nvidia/tensorrt-llm/release:1.3.0rc10 + + task_1: + script: common/eagle3/offline_training.sh + args: + - --config modules/Model-Optimizer/modelopt_recipes/general/speculative_decoding/eagle3.yaml + - model.model_name_or_path=<> + - data.data_path=/scratchspace/data/train.jsonl + - training.output_dir=/scratchspace/eagle3 + - training.training_seq_len=4096 + - training.disable_tqdm=true + - training.ar_validate_steps=500000 + environment: + - HOME: /tmp + - TORCHINDUCTOR_CACHE_DIR: /tmp/torch_cache + slurm_config: + _factory_: "slurm_factory" + nodes: 1 + ntasks_per_node: 1 + gpus_per_node: 8 + container: nvcr.io/nvidia/tensorrt-llm/release:1.3.0rc10 + + task_2: + script: common/specdec_bench/quick_check.sh + args: + - --draft_model_dir /scratchspace/export + - --draft_length 3 + - --output_length 4096 + - --engine VLLM + - --tp_size 4 + - --ep_size 1 + - --speculative_algorithm EAGLE3 + - --mtbench /hf-local/HuggingFaceH4/mt_bench_prompts/raw/question.jsonl + - --concurrency 1 + environment: + - HF_MODEL_CKPT: <> + - HOME: /tmp + - TORCHINDUCTOR_CACHE_DIR: /tmp/torch_cache + slurm_config: + _factory_: "slurm_factory" + nodes: 1 + ntasks_per_node: 1 + gpus_per_node: 1 + container: vllm/vllm-openai:latest diff --git a/tools/launcher/launch.py b/tools/launcher/launch.py index 6572a447f5..db8da9f7c7 100644 --- 
a/tools/launcher/launch.py +++ b/tools/launcher/launch.py @@ -61,10 +61,11 @@ "modules/Megatron-LM/examples/*", "modules/Megatron-LM/*.py", "modules/Model-Optimizer/modelopt/*", + "modules/Model-Optimizer/modelopt_recipes/*", "modules/Model-Optimizer/examples/*", "common/*", ], - relative_path=[LAUNCHER_DIR] * 6, + relative_path=[LAUNCHER_DIR] * 7, ) MODELOPT_SRC_PATH = os.path.join(LAUNCHER_DIR, "modules/Model-Optimizer/modelopt") From 041b012ce156cc7afec6bbff2c98262400962d95 Mon Sep 17 00:00:00 2001 From: h-guo18 <67671475+h-guo18@users.noreply.github.com> Date: Sat, 4 Apr 2026 15:10:12 -0700 Subject: [PATCH 20/21] fix: rope init ddp bug Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com> --- modelopt/torch/speculative/plugins/transformers.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/modelopt/torch/speculative/plugins/transformers.py b/modelopt/torch/speculative/plugins/transformers.py index 8561a390fc..f82cb97b0c 100644 --- a/modelopt/torch/speculative/plugins/transformers.py +++ b/modelopt/torch/speculative/plugins/transformers.py @@ -287,9 +287,9 @@ def __init__(self, config, decoder_layer_cls, bias=False): num_layers=self.config.parallel_draft_heads_num_layers, ) - def _maybe_init_rope(self): + def _maybe_init_rope(self, device): if self.config.eagle_decoder_type == "llama" and not hasattr(self, "rotary_emb"): - self.rotary_emb = LlamaRotaryEmbedding(config=self.config) + self.rotary_emb = LlamaRotaryEmbedding(config=self.config, device=device) def _expand_first_attn_in_dim(self, first_layer_attn): """Modify qkv projection in first layer to accept 2h hidden size.""" @@ -937,7 +937,7 @@ def forward( base_outputs, ) - self.eagle_module._maybe_init_rope() + self.eagle_module._maybe_init_rope(device=eagle_input_hiddens.device) # ====Run eagle forward with extra training-time-test steps==== for ttt_step in range(self.eagle_ttt_steps): @@ -1070,7 +1070,7 @@ def pseudo_speculative_generate( else: 
eagle_input_hidden_states = base_model_hidden_states - self.eagle_module._maybe_init_rope() + self.eagle_module._maybe_init_rope(device=eagle_input_hidden_states.device) draft_tokens = [] for step in range(steps): b, seq_length = eagle_ids.shape From 8c04a851fff3260d820ae7407aa7e49dc229f0e3 Mon Sep 17 00:00:00 2001 From: h-guo18 <67671475+h-guo18@users.noreply.github.com> Date: Sun, 5 Apr 2026 04:12:01 +0000 Subject: [PATCH 21/21] debug launcher Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com> --- examples/speculative_decoding/eagle_utils.py | 4 ++- .../torch/speculative/plugins/transformers.py | 1 + tools/launcher/common/eagle3/make_dataset.sh | 32 +++++++++++++++++++ .../{offline_training.sh => train_eagle.sh} | 0 .../Qwen/Qwen3-8B/hf_offline_eagle3.yaml | 2 +- .../Qwen/Qwen3-8B/hf_online_eagle3.yaml | 5 +-- 6 files changed, 40 insertions(+), 4 deletions(-) create mode 100755 tools/launcher/common/eagle3/make_dataset.sh rename tools/launcher/common/eagle3/{offline_training.sh => train_eagle.sh} (100%) diff --git a/examples/speculative_decoding/eagle_utils.py b/examples/speculative_decoding/eagle_utils.py index 7e40e29b07..99c8ef4e03 100644 --- a/examples/speculative_decoding/eagle_utils.py +++ b/examples/speculative_decoding/eagle_utils.py @@ -44,7 +44,9 @@ try: import wandb -except ImportError: + + wandb.log # Verify wandb is functional (not a stub module). 
+except (ImportError, AttributeError): wandb = None IGNORE_TOKEN_ID = LabelSmoother.ignore_index diff --git a/modelopt/torch/speculative/plugins/transformers.py b/modelopt/torch/speculative/plugins/transformers.py index f82cb97b0c..0443bc48f9 100644 --- a/modelopt/torch/speculative/plugins/transformers.py +++ b/modelopt/torch/speculative/plugins/transformers.py @@ -288,6 +288,7 @@ def __init__(self, config, decoder_layer_cls, bias=False): ) def _maybe_init_rope(self, device): + # Lazily init rope since rope buffers are not saved with state dict if self.config.eagle_decoder_type == "llama" and not hasattr(self, "rotary_emb"): self.rotary_emb = LlamaRotaryEmbedding(config=self.config, device=device) diff --git a/tools/launcher/common/eagle3/make_dataset.sh b/tools/launcher/common/eagle3/make_dataset.sh new file mode 100755 index 0000000000..b107ba59e5 --- /dev/null +++ b/tools/launcher/common/eagle3/make_dataset.sh @@ -0,0 +1,32 @@ +#!/bin/bash + +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +SCRIPT_DIR="$(dirname "$(readlink -f "$0")")" # directory containing this script, with symlinks resolved + +source "${SCRIPT_DIR}/../service_utils.sh" # quoted so a checkout path containing spaces is not word-split + +################################################################################################### + +trap 'error_handler $0 $LINENO' ERR # ERROR HANDLER + +python modules/Model-Optimizer/examples/speculative_decoding/prepare_input_conversations/make_dataset.py \ + ${@} # NOTE(review): left unquoted, so arguments are re-split on whitespace; confirm the launcher relies on this + +mkdir -p /scratchspace/data +mv input_conversations/train.jsonl /scratchspace/data/train.jsonl + + diff --git a/tools/launcher/common/eagle3/offline_training.sh b/tools/launcher/common/eagle3/train_eagle.sh similarity index 100% rename from tools/launcher/common/eagle3/offline_training.sh rename to tools/launcher/common/eagle3/train_eagle.sh diff --git a/tools/launcher/examples/Qwen/Qwen3-8B/hf_offline_eagle3.yaml b/tools/launcher/examples/Qwen/Qwen3-8B/hf_offline_eagle3.yaml index 071cfd03a1..ae2c1e957c 100644 --- a/tools/launcher/examples/Qwen/Qwen3-8B/hf_offline_eagle3.yaml +++ b/tools/launcher/examples/Qwen/Qwen3-8B/hf_offline_eagle3.yaml @@ -65,7 +65,7 @@ pipeline: # Step 3: Train EAGLE3 draft head (offline, single task) task_2: - script: common/eagle3/offline_training.sh + script: common/eagle3/train_eagle.sh args: - --config modules/Model-Optimizer/modelopt_recipes/general/speculative_decoding/eagle3.yaml - model.model_name_or_path=<> diff --git a/tools/launcher/examples/Qwen/Qwen3-8B/hf_online_eagle3.yaml b/tools/launcher/examples/Qwen/Qwen3-8B/hf_online_eagle3.yaml index 024916c789..5b55e9a222 100644 --- a/tools/launcher/examples/Qwen/Qwen3-8B/hf_online_eagle3.yaml +++ b/tools/launcher/examples/Qwen/Qwen3-8B/hf_online_eagle3.yaml @@ -34,7 +34,7 @@ pipeline: container: nvcr.io/nvidia/tensorrt-llm/release:1.3.0rc10 task_1: - script: common/eagle3/offline_training.sh + script: common/eagle3/train_eagle.sh args: - --config modules/Model-Optimizer/modelopt_recipes/general/speculative_decoding/eagle3.yaml - model.model_name_or_path=<> @@ -43,6 +43,7 @@ pipeline: - training.training_seq_len=4096 -
training.disable_tqdm=true - training.ar_validate_steps=500000 + - training.num_train_epochs=1 environment: - HOME: /tmp - TORCHINDUCTOR_CACHE_DIR: /tmp/torch_cache @@ -60,7 +61,7 @@ pipeline: - --draft_length 3 - --output_length 4096 - --engine VLLM - - --tp_size 4 + - --tp_size 1 - --ep_size 1 - --speculative_algorithm EAGLE3 - --mtbench /hf-local/HuggingFaceH4/mt_bench_prompts/raw/question.jsonl