diff --git a/examples/speculative_decoding/main.py b/examples/speculative_decoding/main.py
index cd1af9563..7399ee949 100644
--- a/examples/speculative_decoding/main.py
+++ b/examples/speculative_decoding/main.py
@@ -113,7 +113,9 @@ class EagleArguments:
     eagle_config: str = field(default=None, metadata={"help": "Path to eagle_config.json"})
     eagle_decoder_type: str = field(
         default="llama",
-        metadata={"help": "The class of eagle decoder to use. Available options: llama, kimik2"},
+        metadata={
+            "help": "The class of eagle decoder to use. Available options: llama, deepseek_v3, kimik2"
+        },
     )

@@ -189,12 +191,15 @@ def train():
         mtsp.convert(model, [("medusa", config)])
     elif training_args.mode in ["eagle1", "eagle3"]:
         from modelopt.torch.speculative.config import (
+            deepseek_v3_eagle_default_config,
             default_eagle_config,
             eagle3_default_config,
             kimik2_eagle_default_config,
         )

-        if eagle_args.eagle_decoder_type == "kimik2":
+        if eagle_args.eagle_decoder_type == "deepseek_v3":
+            eagle_architecture_config = deepseek_v3_eagle_default_config
+        elif eagle_args.eagle_decoder_type == "kimik2":
             eagle_architecture_config = kimik2_eagle_default_config
         else:
             eagle_architecture_config = {
diff --git a/modelopt/torch/speculative/config.py b/modelopt/torch/speculative/config.py
index 41987d4e4..6dc23cf2e 100644
--- a/modelopt/torch/speculative/config.py
+++ b/modelopt/torch/speculative/config.py
@@ -19,10 +19,14 @@

 from modelopt.torch.opt.config import ModeloptBaseConfig, ModeloptField

-from .eagle.default_config import default_eagle_config, default_kimik2_eagle_config
+from .eagle.default_config import (
+    default_deepseek_v3_eagle_config,
+    default_eagle_config,
+    default_kimik2_eagle_config,
+)

 kimik2_eagle_default_config = deepcopy(default_kimik2_eagle_config)
-
+deepseek_v3_eagle_default_config = deepcopy(default_deepseek_v3_eagle_config)
 eagle3_default_config = deepcopy(default_eagle_config)

 eagle_mtp_default_config = deepcopy(default_eagle_config)
diff --git a/modelopt/torch/speculative/eagle/default_config.py b/modelopt/torch/speculative/eagle/default_config.py
index f8c4924c1..2990de297 100644
--- a/modelopt/torch/speculative/eagle/default_config.py
+++ b/modelopt/torch/speculative/eagle/default_config.py
@@ -112,3 +112,71 @@
     "parallel_draft_heads_num_layers": 1,
     "has_lm_head": False,
 }
+
+default_deepseek_v3_eagle_config = {
+    "attention_bias": False,
+    "attention_dropout": 0.0,
+    "auto_map": {
+        "AutoConfig": "configuration_deepseek.DeepseekV3Config",
+        "AutoModel": "modeling_deepseek.DeepseekV3Model",
+        "AutoModelForCausalLM": "modeling_deepseek.DeepseekV3ForCausalLM",
+    },
+    "bos_token_id": 0,
+    "eos_token_id": 1,
+    "ep_size": 1,
+    "first_k_dense_replace": 0,
+    "hidden_act": "silu",
+    "hidden_size": 7168,
+    "initializer_range": 0.02,
+    "intermediate_size": 18432,
+    "kv_lora_rank": 512,
+    "max_position_embeddings": 163840,
+    "model_type": "deepseek_v3",
+    "moe_intermediate_size": 2048,
+    "moe_layer_freq": 1,
+    "n_group": 8,
+    "n_routed_experts": 256,
+    "n_shared_experts": 1,
+    "norm_topk_prob": True,
+    "num_attention_heads": 128,
+    "num_experts_per_tok": 8,
+    "num_hidden_layers": 1,
+    "num_key_value_heads": 128,
+    "num_nextn_predict_layers": 1,
+    "q_lora_rank": 1536,
+    "qk_nope_head_dim": 128,
+    "qk_rope_head_dim": 64,
+    "qk_head_dim": 192,
+    "head_dim": 64,
+    "rms_norm_eps": 1e-06,
+    "rope_scaling": {
+        "beta_fast": 32,
+        "beta_slow": 1,
+        "factor": 40,
+        "mscale": 1.0,
+        "mscale_all_dim": 1.0,
+        "original_max_position_embeddings": 4096,
+        "type": "yarn",
+    },
+    "rope_interleave": True,
+    "rope_theta": 10000,
+    "routed_scaling_factor": 2.5,
+    "scoring_func": "sigmoid",
+    "tie_word_embeddings": False,
+    "topk_group": 4,
+    "topk_method": "noaux_tc",
+    "torch_dtype": "bfloat16",
+    "transformers_version": "4.33.1",
+    "use_cache": True,
+    "v_head_dim": 128,
+    "_attn_implementation": "eager",
+    "vocab_size": 129280,
+    "use_input_layernorm_in_first_layer": True,
+    "use_last_layernorm": True,
+    "use_aux_hidden_state": True,
+    "eagle_aux_hidden_state_layer_ids": [],
+    "use_mtp_layernorm": False,
+    "parallel_draft_step": 1,
+    "parallel_draft_heads_num_layers": 1,
+    "has_lm_head": False,
+}
diff --git a/modelopt/torch/speculative/plugins/transformers.py b/modelopt/torch/speculative/plugins/transformers.py
index 39df7b9b7..7082bfe61 100644
--- a/modelopt/torch/speculative/plugins/transformers.py
+++ b/modelopt/torch/speculative/plugins/transformers.py
@@ -38,6 +38,10 @@
 from torch.nn import CrossEntropyLoss
 from torch.nn.attention.flex_attention import BlockMask, create_block_mask
 from transformers import Cache, DynamicCache, PretrainedConfig, PreTrainedModel
+from transformers.models.deepseek_v3.modeling_deepseek_v3 import (
+    DeepseekV3DecoderLayer,
+    DeepseekV3RotaryEmbedding,
+)
 from transformers.models.llama.modeling_llama import (
     LlamaDecoderLayer,
     LlamaRMSNorm,
@@ -362,6 +366,12 @@ def forward(
                     config=self.config, device=hidden_states.device
                 )
             position_embeddings = self.rotary_emb(hidden_states, position_ids)
+        elif self.config.eagle_decoder_type == "deepseek_v3":
+            if not hasattr(self, "rotary_emb"):
+                self.rotary_emb = DeepseekV3RotaryEmbedding(
+                    config=self.config, device=hidden_states.device
+                )
+            position_embeddings = self.rotary_emb(hidden_states, position_ids)
         else:
             position_embeddings = None

@@ -518,6 +528,8 @@ def modify(
         if eagle_decoder_type == "llama":
             # Use default eagle config
             decoder_cls = LlamaDecoderLayer
+        elif eagle_decoder_type == "deepseek_v3":
+            decoder_cls = DeepseekV3DecoderLayer
         elif eagle_decoder_type == "kimik2":
             decoder_cls = _setup_kimi_k2_decoder()