examples/speculative_decoding/main.py (7 additions, 2 deletions)
@@ -113,7 +113,9 @@ class EagleArguments:
     eagle_config: str = field(default=None, metadata={"help": "Path to eagle_config.json"})
     eagle_decoder_type: str = field(
         default="llama",
-        metadata={"help": "The class of eagle decoder to use. Available options: llama, kimik2"},
+        metadata={
+            "help": "The class of eagle decoder to use. Available options: llama, deepseek_v3, kimik2"
+        },
     )


@@ -189,12 +191,15 @@ def train():
         mtsp.convert(model, [("medusa", config)])
     elif training_args.mode in ["eagle1", "eagle3"]:
         from modelopt.torch.speculative.config import (
+            deepseek_v3_eagle_default_config,
             default_eagle_config,
             eagle3_default_config,
             kimik2_eagle_default_config,
         )

-        if eagle_args.eagle_decoder_type == "kimik2":
+        if eagle_args.eagle_decoder_type == "deepseek_v3":
+            eagle_architecture_config = deepseek_v3_eagle_default_config
+        elif eagle_args.eagle_decoder_type == "kimik2":
             eagle_architecture_config = kimik2_eagle_default_config
         else:
             eagle_architecture_config = {
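The new branch simply swaps in a different default architecture dictionary, so the same selection can be reproduced outside main.py. A minimal sketch follows; the lookup table and the decoder_type variable are illustrative, not part of this change, while the imported names come from modelopt.torch.speculative.config as extended in the next file.

# Illustrative sketch (not part of the diff): the decoder-type dispatch from
# train(), rewritten as a lookup over the exported defaults.
from modelopt.torch.speculative.config import (
    deepseek_v3_eagle_default_config,
    kimik2_eagle_default_config,
)

EAGLE_ARCH_DEFAULTS = {
    "deepseek_v3": deepseek_v3_eagle_default_config,
    "kimik2": kimik2_eagle_default_config,
}

decoder_type = "deepseek_v3"  # in main.py this comes from EagleArguments.eagle_decoder_type
eagle_architecture_config = EAGLE_ARCH_DEFAULTS[decoder_type]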
modelopt/torch/speculative/config.py (6 additions, 2 deletions)
@@ -19,10 +19,14 @@

 from modelopt.torch.opt.config import ModeloptBaseConfig, ModeloptField

-from .eagle.default_config import default_eagle_config, default_kimik2_eagle_config
+from .eagle.default_config import (
+    default_deepseek_v3_eagle_config,
+    default_eagle_config,
+    default_kimik2_eagle_config,
+)

 kimik2_eagle_default_config = deepcopy(default_kimik2_eagle_config)
-
+deepseek_v3_eagle_default_config = deepcopy(default_deepseek_v3_eagle_config)
 eagle3_default_config = deepcopy(default_eagle_config)
 eagle_mtp_default_config = deepcopy(default_eagle_config)

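Like the existing kimik2 export, the new module-level name is a deepcopy of the packaged default, so callers can adjust it without mutating the source dictionary. A small illustrative check, assuming a hypothetical override of num_hidden_layers:

# Illustrative only (not part of the diff): the exported default is a deepcopy,
# so in-place edits do not leak back into eagle/default_config.py.
from modelopt.torch.speculative.config import deepseek_v3_eagle_default_config
from modelopt.torch.speculative.eagle.default_config import default_deepseek_v3_eagle_config

deepseek_v3_eagle_default_config["num_hidden_layers"] = 2  # hypothetical override
assert default_deepseek_v3_eagle_config["num_hidden_layers"] == 1  # source dict unchanged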
modelopt/torch/speculative/eagle/default_config.py (68 additions, 0 deletions)
@@ -112,3 +112,71 @@
"parallel_draft_heads_num_layers": 1,
"has_lm_head": False,
}

default_deepseek_v3_eagle_config = {
"attention_bias": False,
"attention_dropout": 0.0,
"auto_map": {
"AutoConfig": "configuration_deepseek.DeepseekV3Config",
"AutoModel": "modeling_deepseek.DeepseekV3Model",
"AutoModelForCausalLM": "modeling_deepseek.DeepseekV3ForCausalLM",
},
"bos_token_id": 0,
"eos_token_id": 1,
"ep_size": 1,
"first_k_dense_replace": 0,
"hidden_act": "silu",
"hidden_size": 7168,
"initializer_range": 0.02,
"intermediate_size": 18432,
"kv_lora_rank": 512,
"max_position_embeddings": 163840,
"model_type": "deepseek_v3",
"moe_intermediate_size": 2048,
"moe_layer_freq": 1,
"n_group": 8,
"n_routed_experts": 256,
"n_shared_experts": 1,
"norm_topk_prob": True,
"num_attention_heads": 128,
"num_experts_per_tok": 8,
"num_hidden_layers": 1,
"num_key_value_heads": 128,
"num_nextn_predict_layers": 1,
"q_lora_rank": 1536,
"qk_nope_head_dim": 128,
"qk_rope_head_dim": 64,
"qk_head_dim": 192,
"head_dim": 64,
"rms_norm_eps": 1e-06,
"rope_scaling": {
"beta_fast": 32,
"beta_slow": 1,
"factor": 40,
"mscale": 1.0,
"mscale_all_dim": 1.0,
"original_max_position_embeddings": 4096,
"type": "yarn",
},
"rope_interleave": True,
"rope_theta": 10000,
"routed_scaling_factor": 2.5,
"scoring_func": "sigmoid",
"tie_word_embeddings": False,
"topk_group": 4,
"topk_method": "noaux_tc",
"torch_dtype": "bfloat16",
"transformers_version": "4.33.1",
"use_cache": True,
"v_head_dim": 128,
"_attn_implementation": "eager",
"vocab_size": 129280,
"use_input_layernorm_in_first_layer": True,
"use_last_layernorm": True,
"use_aux_hidden_state": True,
"eagle_aux_hidden_state_layer_ids": [],
"use_mtp_layernorm": False,
"parallel_draft_step": 1,
"parallel_draft_heads_num_layers": 1,
"has_lm_head": False,
}
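The dictionary mirrors a DeepSeek-V3 architecture config with the EAGLE-specific fields appended at the end. The MLA head dimensions are interdependent (each query head is split into a content part and a rotary part, and qk_head_dim is their sum), so local overrides should keep them consistent. A hypothetical sanity check over the new defaults:

# Hypothetical sanity check (not part of the diff) over the default dict.
from modelopt.torch.speculative.eagle.default_config import (
    default_deepseek_v3_eagle_config as cfg,
)

assert cfg["qk_nope_head_dim"] + cfg["qk_rope_head_dim"] == cfg["qk_head_dim"]  # 128 + 64 == 192
assert cfg["num_hidden_layers"] == 1  # the EAGLE draft module is a single decoder layer
assert cfg["first_k_dense_replace"] == 0  # that single layer uses the MoE MLP rather than a dense MLP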
modelopt/torch/speculative/plugins/transformers.py (12 additions, 0 deletions)
@@ -38,6 +38,10 @@
 from torch.nn import CrossEntropyLoss
 from torch.nn.attention.flex_attention import BlockMask, create_block_mask
 from transformers import Cache, DynamicCache, PretrainedConfig, PreTrainedModel
+from transformers.models.deepseek_v3.modeling_deepseek_v3 import (
+    DeepseekV3DecoderLayer,
+    DeepseekV3RotaryEmbedding,
+)
 from transformers.models.llama.modeling_llama import (
     LlamaDecoderLayer,
     LlamaRMSNorm,
@@ -362,6 +366,12 @@ def forward(
                     config=self.config, device=hidden_states.device
                 )
             position_embeddings = self.rotary_emb(hidden_states, position_ids)
+        elif self.config.eagle_decoder_type == "deepseek_v3":
+            if not hasattr(self, "rotary_emb"):
+                self.rotary_emb = DeepseekV3RotaryEmbedding(
+                    config=self.config, device=hidden_states.device
+                )
+            position_embeddings = self.rotary_emb(hidden_states, position_ids)
         else:
             position_embeddings = None

@@ -518,6 +528,8 @@ def modify(
         if eagle_decoder_type == "llama":
             # Use default eagle config
             decoder_cls = LlamaDecoderLayer
+        elif eagle_decoder_type == "deepseek_v3":
+            decoder_cls = DeepseekV3DecoderLayer
         elif eagle_decoder_type == "kimik2":
             decoder_cls = _setup_kimi_k2_decoder()

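The plugin now imports the DeepSeek-V3 decoder layer and rotary embedding at module level, so using this plugin at all requires a transformers build that ships the deepseek_v3 model. A hedged preflight sketch; the try/except wrapper is illustrative only, since the plugin itself imports these names unconditionally:

# Illustrative preflight (not part of the diff): fail early with a clear message
# if the installed transformers has no deepseek_v3 model package.
try:
    from transformers.models.deepseek_v3.modeling_deepseek_v3 import (
        DeepseekV3DecoderLayer,
        DeepseekV3RotaryEmbedding,
    )
except ImportError as err:
    raise RuntimeError(
        "eagle_decoder_type='deepseek_v3' requires a transformers version that includes "
        "the deepseek_v3 model; upgrade transformers or use 'llama'/'kimik2' instead."
    ) from err

_ = (DeepseekV3DecoderLayer, DeepseekV3RotaryEmbedding)  # imported only to verify availability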