diff --git a/tensorrt_llm/_torch/models/__init__.py b/tensorrt_llm/_torch/models/__init__.py index 263f0e162c8..d3bcc331ce4 100644 --- a/tensorrt_llm/_torch/models/__init__.py +++ b/tensorrt_llm/_torch/models/__init__.py @@ -7,6 +7,7 @@ from .modeling_exaone4 import Exaone4ForCausalLM from .modeling_gemma3 import Gemma3ForCausalLM from .modeling_gemma3vl import Gemma3VLM +from .modeling_glm import Glm4MoeForCausalLM from .modeling_gpt_oss import GptOssForCausalLM from .modeling_hunyuan_dense import HunYuanDenseV1ForCausalLM from .modeling_hunyuan_moe import HunYuanMoEV1ForCausalLM @@ -70,6 +71,7 @@ "Qwen3NextForCausalLM", "GptOssForCausalLM", "SeedOssForCausalLM", + "Glm4MoeForCausalLM", ] if transformers.__version__ >= "4.45.1": diff --git a/tensorrt_llm/_torch/models/modeling_glm.py b/tensorrt_llm/_torch/models/modeling_glm.py new file mode 100644 index 00000000000..21eaf0a9feb --- /dev/null +++ b/tensorrt_llm/_torch/models/modeling_glm.py @@ -0,0 +1,907 @@ +import math +import os +from typing import Dict, List, Optional, Tuple + +import torch +from torch import nn +from tqdm import tqdm +from transformers import PretrainedConfig + +from tensorrt_llm._ipc_utils import can_access_peer +from tensorrt_llm._utils import get_sm_version, is_sm_100f +from tensorrt_llm.functional import PositionEmbeddingType +from tensorrt_llm.models.modeling_utils import QuantConfig +from tensorrt_llm.quantization.mode import QuantAlgo +from tensorrt_llm.quantization.utils.fp8_utils import ( + resmooth_to_fp8_e8m0, transform_sf_into_required_layout) + +from ..attention_backend import AttentionMetadata +from ..attention_backend.interface import PositionalEmbeddingParams, RopeParams +from ..distributed import (AllReduce, AllReduceFusionOp, AllReduceParams, + MoEAllReduce, MoEAllReduceParams) +from ..model_config import ModelConfig +from ..modules.decoder_layer import DecoderLayer +from ..modules.embedding import Embedding +from ..modules.fused_moe import MoEWeightLoadingMode, create_moe +from ..modules.gated_mlp import GatedMLP +from ..modules.linear import Linear, TensorParallelMode +from ..modules.multi_stream_utils import maybe_execute_in_parallel +from ..modules.qk_norm_attention import QKNormRoPEAttention +from ..modules.rms_norm import RMSNorm +from ..speculative import SpecMetadata +from ..utils import AuxStreamType, EventType, Fp4QuantizedTensor +from .modeling_deepseekv3 import (DeepseekV3Gate, DeepseekV3MTPHead, + moe_reduce_add_shared_output) +from .modeling_speculative import SpecDecOneEngineForCausalLM +from .modeling_utils import (DecoderModel, EagerFusionConfig, + _load_weights_impl, register_auto_model) + + +class Glm4Attention(QKNormRoPEAttention): + + def __init__( + self, + model_config: ModelConfig[PretrainedConfig], + layer_idx: Optional[int] = None, + ): + config = model_config.pretrained_config + pos_embd_params = PositionalEmbeddingParams( + type=PositionEmbeddingType.yarn, + rope=RopeParams.from_config(config), + ) + + super().__init__( + hidden_size=config.hidden_size, + num_attention_heads=config.num_attention_heads, + num_key_value_heads=config.num_key_value_heads, + max_position_embeddings=config.max_position_embeddings, + bias=config.attention_bias, + pos_embd_params=pos_embd_params, + fuse_qk_norm_rope=False, + layer_idx=layer_idx, + dtype=config.torch_dtype, + dense_bias=False, + config=model_config, + ) + + +class Glm4MoE(nn.Module): + + def __init__(self, + *, + num_experts: int, + top_k: int, + hidden_size: int, + intermediate_size: int, + shared_expert_intermediate_size: int, + 
aux_stream_dict: Dict[AuxStreamType, torch.cuda.Stream], + dtype: Optional[torch.dtype] = None, + model_config: ModelConfig = ModelConfig(), + override_quant_config: Optional[QuantConfig] = None, + layer_idx: Optional[int] = None): + from ..distributed import AllReduce + + super().__init__() + config = model_config.pretrained_config + self.top_k = top_k + self.use_dp = model_config.mapping.enable_attention_dp + self.gate = DeepseekV3Gate( + hidden_size, + num_experts, + top_k=top_k, + n_group=config.n_group, + topk_group=config.topk_group, + routed_scaling_factor=config.routed_scaling_factor, + dtype=dtype, + fuse_routing_kernel=False, + apply_routing=False, + moe_backend=model_config.moe_backend) + self.experts = create_moe( + num_experts=num_experts, + routing_method=self.gate.routing_method, + hidden_size=hidden_size, + intermediate_size=intermediate_size, + dtype=dtype, + reduce_results= + False, # In both low‑latency and attention‑DP modes, FusedMoE skips the in‑op all‑reduce. + model_config=model_config, + override_quant_config=override_quant_config, + aux_stream_dict=aux_stream_dict, + layer_idx=layer_idx, + weight_loading_mode=MoEWeightLoadingMode.VANILLA, + ) + + self.mapping = model_config.mapping + + # FIXME: incompatible with mixed quantization mode (including excluding modules from quantization) + block_size = 1 + if model_config.quant_config and model_config.quant_config.quant_algo and model_config.quant_config.group_size is not None: + block_size = model_config.quant_config.group_size + + shared_tp_size, self.shared_output_scale = self._compute_shared_expert_tp_size( + shared_expert_intermediate_size, block_size) + + self.shared_experts = GatedMLP( + hidden_size=hidden_size, + intermediate_size=shared_expert_intermediate_size, + bias=False, + dtype=dtype, + config=model_config, + overridden_tp_size=shared_tp_size, + reduce_output=False) + + self.allreduce = AllReduce(mapping=model_config.mapping, + strategy=model_config.allreduce_strategy) + self.aux_stream = aux_stream_dict[AuxStreamType.MoeShared] + self.event_dict = { + key: torch.cuda.Event() + for key in [EventType.Main, EventType.MoeShared] + } + + def _compute_shared_expert_tp_size( + self, intermediate_size: int, + block_size: int) -> tuple[int, float | None]: + """ + In the case of GLM4, the TP size of MLP is capped by intermediate_size // block_size. + For example, when the intermediate_size is 2048 and block scaling size is 128, + TP sizes are limited to {1, 2, 4, 8, 16} because of 2048/128 = 16. + + Args: + intermediate_size (int): MLP intermediate size. + block_size (int): The quantization block scale size. For NVFP4, it's 16. + + Returns: + tuple[int, float | None]: A tuple containing (shared_tp_size, shared_output_scale). + - shared_tp_size: The computed TP size. + - shared_output_scale: The output scale factor, or None if not needed. + """ + + assert intermediate_size % block_size == 0, "intermediate_size must be divisible by block_size." + + shared_output_scale = None + # The block scale size is 128, which requires shared_expert_intermediate_size to be divisible by 128. + if self.use_dp: + # If using attention DP, the shared experts also use DP instead of TP. + shared_tp_size = 1 + else: + # Due to the restriction of block scale size (i.e., 128), the supported TP sizes only include 1, 2, 4, 8, and 16. + # The math.gcd operation ensures that shared_tp_size falls in the supported TP sizes. 
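# Editor's note: illustrative sketch only, not part of this patch. It restates the
# shared-expert TP cap described in the docstring above as standalone code, using the
# docstring's example values (intermediate_size=2048, block_size=128); the helper name
# and numbers are hypothetical.
import math

def shared_expert_tp(intermediate_size: int, block_size: int, tp_size: int,
                     use_dp: bool = False) -> tuple[int, float | None]:
    if use_dp:
        return 1, None  # attention-DP: shared experts run data-parallel, no TP split
    cap = intermediate_size // block_size      # 2048 // 128 = 16 -> TP in {1, 2, 4, 8, 16}
    shared_tp = math.gcd(cap, tp_size)         # e.g. tp_size=32 -> gcd(16, 32) = 16
    scale = None if shared_tp == tp_size else shared_tp / tp_size  # 16 / 32 = 0.5
    return shared_tp, scale

# The scale keeps the summed shared-expert output correct when the later all-reduce
# still runs over the full tp_size ranks.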
+ shared_tp_size = math.gcd( + intermediate_size // block_size, + self.mapping.tp_size, + ) + # If shared_tp_size has been overridden, the output of shared experts needs to be scaled down accordingly before all-reduce. + if shared_tp_size != self.mapping.tp_size: + shared_output_scale = shared_tp_size / self.mapping.tp_size + + return shared_tp_size, shared_output_scale + + @staticmethod + def _get_experts_quant_config(model_config, layer_idx: int) -> QuantConfig: + if getattr(model_config, "quant_config_dict", None) is None: + return model_config.quant_config + return model_config.quant_config_dict.get( + f"model.layers.{layer_idx}.mlp.experts", model_config.quant_config) + + def compute_routed_output(self, hidden_states, hidden_states_fp4, + all_rank_num_tokens, do_finalize): + # max-throughput + use_dp_padding = False + # Add DP padding on SM120 for context comm performance + # TODO: Move this model-agonostic part to MoE + if self.use_dp and self.mapping.tp_size > 1 and get_sm_version() == 120: + use_dp_padding = True + hidden_states = torch.nn.functional.pad( + hidden_states, + (0, 0, 0, max(all_rank_num_tokens) - hidden_states.shape[0])) + + router_logits = self.gate(hidden_states) + + routed_output = self.experts( + hidden_states_fp4 + if hidden_states_fp4 is not None else hidden_states, + router_logits, + do_finalize=do_finalize, + output_dtype=hidden_states.dtype, + all_rank_num_tokens=all_rank_num_tokens, + use_dp_padding=use_dp_padding, + ) + + return routed_output + + def forward( + self, + hidden_states: torch.Tensor, + hidden_states_fp4: Optional[Fp4QuantizedTensor] = None, + all_rank_num_tokens: Optional[list[int]] = None, + final_all_reduce_params: Optional[AllReduceParams] = None, + do_finalize: Optional[bool] = True, + ) -> torch.Tensor: + if not do_finalize: + assert not self.use_dp + + def _compute_shared_output(): + shared_output = self.shared_experts( + hidden_states_fp4 + if hidden_states_fp4 is not None else hidden_states) + if self.shared_output_scale is not None: + shared_output *= self.shared_output_scale + return shared_output + + def _compute_routed_output(): + routed_output = self.compute_routed_output(hidden_states, + hidden_states_fp4, + all_rank_num_tokens, + do_finalize) + return routed_output + + # NOTE: define compiled helpers at module scope to avoid defining decorators inside compiled frames + + routed_output, shared_output = maybe_execute_in_parallel( + _compute_routed_output, _compute_shared_output, + self.event_dict[EventType.Main], + self.event_dict[EventType.MoeShared], self.aux_stream) + + if not do_finalize: + return [shared_output, *routed_output] + else: + if routed_output.dim() == 3: + assert shared_output.numel( + ) * self.top_k == routed_output.numel( + ), 'unmatched tensor shape' + final_hidden_states = moe_reduce_add_shared_output( + routed_output, shared_output) + else: + assert shared_output.size() == routed_output.size( + ), 'unmatched tensor shape' + final_hidden_states = shared_output + routed_output + + if not self.use_dp and self.mapping.tp_size > 1: + final_hidden_states = self.allreduce( + final_hidden_states, + all_reduce_params=final_all_reduce_params) + + return final_hidden_states + + +class Glm4DecoderLayer(DecoderLayer): + + def __init__(self, + model_config: ModelConfig[PretrainedConfig], + layer_idx: int, + aux_stream_dict: Dict[AuxStreamType, torch.cuda.Stream], + is_separate_draft_engine: bool = False): + super().__init__() + self.model_config = model_config + self.config = model_config.pretrained_config + config = 
self.config + + self.hidden_size = config.hidden_size + self.moe_intermediate_size = config.moe_intermediate_size + self.num_experts = config.n_routed_experts + self.num_shared_experts = config.n_shared_experts + self.top_k = config.num_experts_per_tok + + self.mapping = model_config.mapping + mapping = self.mapping + layer_idx_for_attention = layer_idx + if is_separate_draft_engine: + #KVCacheManager only support 1 layer for separate draft engine + layer_idx_for_attention = layer_idx - model_config.pretrained_config.num_hidden_layers + + self.self_attn = Glm4Attention( + model_config, + layer_idx=layer_idx_for_attention, + ) + self.enable_attention_dp = mapping.enable_attention_dp + + self.mlp_tp_size = mapping.tp_size + self.is_p2p_supported = can_access_peer(mapping) + + self.fusion_config = EagerFusionConfig() + self.enable_fusion = os.environ.get("TRTLLM_GLM_EAGER_FUSION_DISABLED", + "0") == "0" + self.enable_fusion &= not self.enable_attention_dp + + # FIXME: incompatible with mixed quantization mode + quant_config = self._get_decoder_layer_quant_config( + model_config, layer_idx) + self.is_nvfp4 = quant_config.layer_quant_mode.has_nvfp4() + assert ( + quant_config.quant_algo + is not QuantAlgo.MIXED_PRECISION), "MIXED_PRECISION is ambiguous" + + has_tp = mapping.has_tp() + self.allreduce = AllReduce(mapping=model_config.mapping, + strategy=model_config.allreduce_strategy, + dtype=config.torch_dtype) + self.moe_allreduce = MoEAllReduce(self.mapping) + + if (config.n_routed_experts is not None + and layer_idx >= config.first_k_dense_replace): + + self.fusion_config.PRE_MOE_FUSION = self.enable_fusion and has_tp + self.fusion_config.POST_MOE_FUSION = self.fusion_config.PRE_MOE_FUSION + + self.mlp = Glm4MoE( + num_experts=self.num_experts, + top_k=self.top_k, + hidden_size=self.hidden_size, + intermediate_size=self.moe_intermediate_size, + shared_expert_intermediate_size=self.moe_intermediate_size * + self.num_shared_experts, + dtype=config.torch_dtype, + model_config=model_config, + override_quant_config=quant_config, + aux_stream_dict=aux_stream_dict, + layer_idx=layer_idx) + else: + block_size = 1 + if quant_config and quant_config.quant_algo and quant_config.group_size is not None: + block_size = quant_config.group_size + self.mlp_tp_size = self._compute_mlp_tp_size( + config.intermediate_size, block_size) + + has_mlp_tp = self.mlp_tp_size > 1 + self.fusion_config.PRE_MLP_FUSION = self.enable_fusion and has_mlp_tp and self.is_nvfp4 + self.fusion_config.POST_MLP_FUSION = self.enable_fusion and has_mlp_tp + + self.mlp = GatedMLP(hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + bias=False, + dtype=config.torch_dtype, + config=model_config, + overridden_tp_size=self.mlp_tp_size, + reduce_output=True) + + self.input_layernorm = RMSNorm(hidden_size=config.hidden_size, + eps=config.rms_norm_eps, + dtype=config.torch_dtype) + + self.disable_attn_allreduce = (self.fusion_config.PRE_MOE_FUSION + or self.fusion_config.PRE_MLP_FUSION + or self.mapping.tp_size == 1 + or self.enable_attention_dp) + + self.post_attention_layernorm = RMSNorm(hidden_size=config.hidden_size, + eps=config.rms_norm_eps, + dtype=config.torch_dtype) + self.layer_idx = layer_idx + self.next_layer_layernorm: RMSNorm = None + + def _get_decoder_layer_quant_config( + self, model_config: ModelConfig[PretrainedConfig], layer_idx: int): + """ + The MTP layer in the nvfp4 checkpoint is unquantized. 
Because the TRTLLM + moe_backend only supports fp8/fp4 quantization, we need to override + the quant_config for the MTP layer. + """ + quant_config = model_config.quant_config + + layer_name = f"model.layers.{layer_idx}" + if quant_config.is_module_excluded_from_quantization(layer_name): + return QuantConfig( + quant_algo=None, + kv_cache_quant_algo=quant_config.kv_cache_quant_algo, + ) + else: + return model_config.quant_config + + def _compute_mlp_tp_size(self, intermediate_size: int, + block_size: int) -> int: + """ + For GLM4, MLP TP size is limited by intermediate_size // block_size + and must also be multiples of gpus_per_node to avoid expensive inter‑node allreduce. + + Args: + intermediate_size (int): MLP intermediate size. + block_size (int): The quantization block scale size. For NVFP4, it's 16. + + Returns: + int: The computed tp_size. + """ + + assert intermediate_size % block_size == 0, "intermediate_size must be divisible by block_size." + if self.enable_attention_dp: + # If using attention DP, the MLP also uses DP instead of TP. + mlp_tp_size = 1 + else: + # The two math.gcd operations ensure that mlp_tp_size falls in the candidate TP sizes. + tp = math.gcd( + intermediate_size // block_size, + self.mapping.tp_size, + ) + + if tp > self.mapping.gpus_per_node: + mlp_tp_size = math.gcd( + tp, + self.mapping.gpus_per_node, + ) # Avoid costly inter-node TP + else: + mlp_tp_size = tp + return mlp_tp_size + + def forward( + self, + position_ids: torch.IntTensor, + hidden_states: torch.Tensor, + attn_metadata: AttentionMetadata, + residual: torch.Tensor, + spec_metadata: Optional[SpecMetadata] = None, + **kwargs, + ) -> Tuple[torch.Tensor, torch.Tensor]: + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + # Self Attention + hidden_states = self.self_attn( + position_ids=position_ids, + hidden_states=hidden_states, + attn_metadata=attn_metadata, + all_reduce_params=AllReduceParams( + enable_allreduce=not (self.disable_attn_allreduce)), + **kwargs, + ) + if isinstance(self.mlp, Glm4MoE): + if spec_metadata is not None and spec_metadata.is_layer_capture( + self.layer_idx): + self.fusion_config.POST_MOE_FUSION = False + return self.forward_MoE( + hidden_states=hidden_states, + attn_metadata=attn_metadata, + residual=residual, + spec_metadata=spec_metadata, + ) + else: + if spec_metadata is not None and spec_metadata.is_layer_capture( + self.layer_idx): + self.fusion_config.POST_MLP_FUSION = False + assert isinstance(self.mlp, GatedMLP) + return self.forward_mlp( + hidden_states=hidden_states, + residual=residual, + spec_metadata=spec_metadata, + ) + + def forward_MoE( + self, + hidden_states: torch.Tensor, + attn_metadata: AttentionMetadata, + residual: torch.Tensor, + spec_metadata: Optional[SpecMetadata] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + + def _run_MoE(hidden_states, hidden_states_fp4, do_finalize): + return self.mlp( + hidden_states, + hidden_states_fp4, + all_rank_num_tokens=attn_metadata.all_rank_num_tokens, + final_all_reduce_params=AllReduceParams( + enable_allreduce=not (self.fusion_config.POST_MOE_FUSION + or self.mapping.tp_size == 1)), + do_finalize=do_finalize, + ) + + if self.fusion_config.PRE_MOE_FUSION: + # moe_backend can be either CUTLASS or TRTLLM here + # TODO: unify the two min-latency MoE backends by enabling quant fusion + hidden_states, residual = self.allreduce( + hidden_states, + all_reduce_params=AllReduceParams( + fusion_op=AllReduceFusionOp.RESIDUAL_RMS_NORM, + residual=residual, + 
norm_weight=self.post_attention_layernorm.weight, + eps=self.post_attention_layernorm.variance_epsilon, + trigger_completion_at_end=False, + )) + else: + # No fusion + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) + + # Note: this fusion pattern is only supported for single-node TRTLLM-nvfp4 backend now + do_finalize = self.mapping.is_multi_node() or ( + not (hidden_states.shape[0] <= self.moe_allreduce.max_token + and self.fusion_config.POST_MOE_FUSION + and self.model_config.moe_backend == "TRTLLM" + and self.mlp.experts.has_nvfp4 and self.is_p2p_supported)) + + hidden_states = _run_MoE(hidden_states, + hidden_states_fp4=None, + do_finalize=do_finalize) + + if self.fusion_config.POST_MOE_FUSION: + if do_finalize: + hidden_states, residual = self.allreduce( + hidden_states, + all_reduce_params=AllReduceParams( + fusion_op=AllReduceFusionOp.RESIDUAL_RMS_NORM, + residual=residual, + norm_weight=self.next_layer_layernorm.weight, + eps=self.next_layer_layernorm.variance_epsilon, + trigger_completion_at_end=False, + )) + else: + assert len( + hidden_states) == 4, "hidden_states must have 4 elements" + + shared_output = hidden_states[0] + fc2_output = hidden_states[1] + expert_scale_factor = hidden_states[2] + expanded_idx_to_permuted_idx = hidden_states[3] + + moe_all_reduce_params = MoEAllReduceParams( + expanded_idx_to_permuted_idx=expanded_idx_to_permuted_idx, + expert_scale_factor=expert_scale_factor, + shared_expert_output=shared_output, + residual=residual, + norm_weight=self.next_layer_layernorm.weight, + eps=self.next_layer_layernorm.variance_epsilon, + is_cutlass_min_latency=False, + ) + hidden_states, residual = self.moe_allreduce( + fc2_output, all_reduce_params=moe_all_reduce_params) + else: + if spec_metadata is not None and spec_metadata.is_layer_capture( + self.layer_idx): + spec_metadata.maybe_capture_hidden_states( + self.layer_idx, hidden_states, residual) + if self.next_layer_layernorm is not None: + hidden_states, residual = self.next_layer_layernorm( + hidden_states, residual) + + return hidden_states, residual + + def forward_mlp( + self, + hidden_states: torch.Tensor, + residual: torch.Tensor, + spec_metadata: Optional[SpecMetadata] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + + if self.fusion_config.PRE_MLP_FUSION: + act_fp4, act_sf, residual = self.allreduce( + hidden_states, + all_reduce_params=AllReduceParams( + fusion_op=AllReduceFusionOp.RESIDUAL_RMS_NORM_QUANT_NVFP4, + residual=residual, + norm_weight=self.post_attention_layernorm.weight, + scale=self.mlp.gate_up_proj.input_scale, + eps=self.post_attention_layernorm.variance_epsilon, + ), + ) + hidden_states = Fp4QuantizedTensor(act_fp4, act_sf) + else: + # No fusion + # We need to add twoshot allreduce here to avoid modifying MLA logic + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) + + hidden_states = self.mlp( + hidden_states, + final_all_reduce_params=AllReduceParams(enable_allreduce=not ( + self.fusion_config.POST_MLP_FUSION or self.mlp_tp_size == 1)), + ) + + if self.fusion_config.POST_MLP_FUSION: + hidden_states, residual = self.allreduce( + hidden_states, + all_reduce_params=AllReduceParams( + fusion_op=AllReduceFusionOp.RESIDUAL_RMS_NORM, + residual=residual, + norm_weight=self.next_layer_layernorm.weight, + eps=self.next_layer_layernorm.variance_epsilon, + ), + ) + else: + if spec_metadata is not None and spec_metadata.is_layer_capture( + self.layer_idx): + spec_metadata.maybe_capture_hidden_states( + self.layer_idx, 
hidden_states, residual) + if self.next_layer_layernorm is not None: + hidden_states, residual = self.next_layer_layernorm( + hidden_states, residual) + + return hidden_states, residual + + +class Glm4MTP(Glm4DecoderLayer): + + def __init__(self, + model_config: ModelConfig[PretrainedConfig], + layer_idx: int, + aux_stream_dict: Dict[AuxStreamType, torch.cuda.Stream], + is_separate_draft_engine: bool = False): + super().__init__(model_config, layer_idx, aux_stream_dict, + is_separate_draft_engine) + config = model_config.pretrained_config + self.hidden_dim = config.hidden_size + self.moe_intermediate_size = config.moe_intermediate_size + self.num_experts = config.n_routed_experts + self.num_shared_experts = config.n_shared_experts + self.top_k = config.num_experts_per_tok + + self.aux_stream = aux_stream_dict[AuxStreamType.MoeShared] + self.event_dict = { + key: torch.cuda.Event() + for key in [EventType.Main, EventType.MoeShared] + } + + self.enorm = RMSNorm(hidden_size=config.hidden_size, + eps=config.rms_norm_eps, + dtype=config.torch_dtype) + + self.hnorm = RMSNorm(hidden_size=config.hidden_size, + eps=config.rms_norm_eps, + dtype=config.torch_dtype) + if model_config.mapping.enable_attention_dp: + self.eh_proj = Linear( + config.hidden_size * 2, + config.hidden_size, + bias=False, + dtype=config.torch_dtype, + skip_create_weights_in_init=model_config. + skip_create_weights_in_init, + ) + else: + self.eh_proj = Linear( + config.hidden_size * 2, + config.hidden_size, + bias=False, + dtype=config.torch_dtype, + tensor_parallel_mode=TensorParallelMode.ROW, + mapping=model_config.mapping, + reduce_output=True, + skip_create_weights_in_init=model_config. + skip_create_weights_in_init, + ) + + self.shared_head = DeepseekV3MTPHead(model_config) + + def forward( + self, + input_ids: torch.IntTensor, + position_ids: torch.IntTensor, + hidden_states: torch.Tensor, + embed_tokens: Embedding, + attn_metadata: AttentionMetadata, + all_rank_num_tokens: Optional[List[int]] = None, + **kwargs, + ) -> torch.Tensor: + + def norm_embeds(): + return self.enorm(embed_tokens(input_ids)) #emdedding + + def norm_hidden(): + return self.hnorm(hidden_states) + + inputs_embeds, hidden_states = maybe_execute_in_parallel( + norm_embeds, + norm_hidden, + self.event_dict[EventType.Main], + self.event_dict[EventType.MoeShared], + self.aux_stream, + ) + hidden_states = torch.concat([inputs_embeds, hidden_states], dim=-1) + # Split hidden_states columnwise based on TP + tp_size = self.model_config.mapping.tp_size + tp_rank = self.model_config.mapping.tp_rank + + if tp_size > 1 and not (self.model_config.mapping.enable_attention_dp): + hidden_states = torch.chunk(hidden_states, tp_size, dim=-1)[tp_rank] + hidden_states = self.eh_proj(hidden_states) + + # Input layer norm + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states = self.self_attn( + position_ids=position_ids, + hidden_states=hidden_states, + attn_metadata=attn_metadata, + all_reduce_params=AllReduceParams( + enable_allreduce=not (self.disable_attn_allreduce)), + **kwargs, + ) + + # MTP Layer Must have sparse MOE + if self.fusion_config.PRE_MOE_FUSION: + hidden_states, residual = self.allreduce( + hidden_states, + all_reduce_params=AllReduceParams( + fusion_op=AllReduceFusionOp.RESIDUAL_RMS_NORM, + residual=residual, + norm_weight=self.post_attention_layernorm.weight, + eps=self.post_attention_layernorm.variance_epsilon, + ), + ) + else: + hidden_states, residual = self.post_attention_layernorm( + 
hidden_states, residual) + + # MoE + hidden_states = self.mlp( + hidden_states, + all_rank_num_tokens=all_rank_num_tokens, + final_all_reduce_params=AllReduceParams( + enable_allreduce=not (self.fusion_config.POST_MOE_FUSION + or self.mapping.tp_size == 1)), + ) + + if self.fusion_config.POST_MOE_FUSION: + hidden_states, residual = self.allreduce( + hidden_states, + all_reduce_params=AllReduceParams( + fusion_op=AllReduceFusionOp.RESIDUAL_RMS_NORM, + residual=residual, + norm_weight=self.shared_head.norm.weight, + eps=self.shared_head.norm.variance_epsilon, + ), + ) + else: + hidden_states, _ = self.shared_head.norm(hidden_states, residual) + + return hidden_states + + +class Glm4Model(DecoderModel): + + def __init__(self, model_config: ModelConfig[PretrainedConfig]): + super().__init__(model_config) + config = model_config.pretrained_config + self.vocab_size = config.vocab_size + self.num_hidden_layers = config.num_hidden_layers + aux_stream_list = [torch.cuda.Stream() for _ in range(3)] + self.aux_stream_dict = { + AuxStreamType.Attention: aux_stream_list[0], + AuxStreamType.MoeShared: aux_stream_list[0], + AuxStreamType.MoeChunkingOverlap: aux_stream_list[1], + AuxStreamType.MoeBalancer: aux_stream_list[2], + } + + self.embed_tokens = Embedding( + config.vocab_size, + config.hidden_size, + dtype=config.torch_dtype, + ) + + self.layers = nn.ModuleList([ + Glm4DecoderLayer(model_config, layer_idx, self.aux_stream_dict) + for layer_idx in range(config.num_hidden_layers) + ]) + self.norm = RMSNorm(hidden_size=config.hidden_size, + eps=config.rms_norm_eps, + dtype=config.torch_dtype) + + def forward( + self, + attn_metadata: AttentionMetadata, + input_ids: Optional[torch.IntTensor] = None, + position_ids: Optional[torch.IntTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + spec_metadata: Optional[SpecMetadata] = None, + **kwargs, + ) -> torch.Tensor: + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError( + "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" + ) + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + hidden_states = inputs_embeds + residual = None + + for decoder_layer in self.layers[:self.num_hidden_layers]: + hidden_states, residual = decoder_layer( + position_ids=position_ids, + hidden_states=hidden_states, + attn_metadata=attn_metadata, + residual=residual, + spec_metadata=spec_metadata, + ) + + return hidden_states + + +@register_auto_model("Glm4MoeForCausalLM") +class Glm4MoeForCausalLM(SpecDecOneEngineForCausalLM[Glm4Model, + PretrainedConfig]): + + def __init__(self, model_config: ModelConfig[PretrainedConfig]): + super().__init__(model=Glm4Model(model_config), + model_config=model_config) + + self.model_nextn = 0 + if model_config.spec_config is not None and model_config.spec_config.spec_dec_mode.is_mtp_one_model( + ): + model_nextn = model_config.spec_config.num_nextn_predict_layers + ckpt_nextn = self.config.num_nextn_predict_layers + self.num_hidden_layers = self.config.num_hidden_layers + assert ckpt_nextn > 0, "There is not MTP modules in the checkpoint." 
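# Editor's note: illustrative sketch only, not part of this patch. It shows, with
# assumed toy sizes, how the exclude-module remapping just below maps duplicated
# in-model MTP layer indices back onto the checkpoint's MTP layer indices.
num_hidden_layers, ckpt_nextn, model_nextn = 4, 1, 3   # hypothetical values
for model_mtp_idx in range(num_hidden_layers, num_hidden_layers + model_nextn):
    ckpt_mtp_idx = (model_mtp_idx - num_hidden_layers) % ckpt_nextn + num_hidden_layers
    print(f"model.layers.{model_mtp_idx} -> model.layers.{ckpt_mtp_idx}")  # 4->4, 5->4, 6->4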
+ if ckpt_nextn == 1 and not model_config.spec_config.use_mtp_vanilla: + pass + else: + # modify the QuantConfig to support duplicated mtp layers + if model_config.quant_config.exclude_modules is not None: + extend_exclude_modules = [] + for model_mtp_idx in range( + self.num_hidden_layers, + self.num_hidden_layers + model_nextn): + ckpt_mtp_idx = (model_mtp_idx - self.num_hidden_layers + ) % ckpt_nextn + self.num_hidden_layers + model_prefix = f"model.layers.{model_mtp_idx}" + ckpt_prefix = f"model.layers.{ckpt_mtp_idx}" + for exclude_module in model_config.quant_config.exclude_modules: + if ckpt_prefix in exclude_module and model_prefix not in exclude_module: + extend_exclude_modules.append( + exclude_module.replace( + ckpt_prefix, model_prefix)) + self.model_config.quant_config.exclude_modules.extend( + extend_exclude_modules) + self.model.layers.extend(self.draft_model.mtp_layers) + self.epilogue.extend(self.draft_model.mtp_layers) + self.epilogue.append(self.spec_worker) + + def forward( + self, + attn_metadata: AttentionMetadata, + input_ids: torch.IntTensor = None, + position_ids: Optional[torch.IntTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + spec_metadata: Optional[SpecMetadata] = None, + return_context_logits: bool = False, + **kwargs, + ) -> torch.Tensor: + return super().forward(attn_metadata=attn_metadata, + input_ids=input_ids, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + spec_metadata=spec_metadata, + return_context_logits=return_context_logits, + **kwargs) + + def load_weights(self, weights: Dict): + # model.layers.91.mlp.experts.75.up_proj.weight_scale_2 + _load_weights_impl( + self, + weights, + params_map={ + r'(?!.*shared_experts)(?=.*experts?)(.*?)up_proj(.*)': + r'\1w3\2', + r'(?!.*shared_experts)(?=.*experts?)(.*?)down_proj(.*)': + r'\1w2\2', + r'(?!.*shared_experts)(?=.*experts?)(.*?)gate_proj(.*)': + r'\1w1\2', + }) + + def post_load_weights(self): + all_named_modules = dict(self.model.named_modules()) + for name, module in tqdm(all_named_modules.items(), + desc="Post loading weights"): + if len(module._parameters) <= 0 or name.startswith("draft_model"): + continue + else: + if self.model_config.quant_config.layer_quant_mode.has_fp8_block_scales( + ) and is_sm_100f() and hasattr(module, "weight_scale"): + weight, weight_scale = resmooth_to_fp8_e8m0( + module.weight, module.weight_scale) + transfromed_scale = transform_sf_into_required_layout( + weight_scale, + mn=weight.shape[0], + k=weight.shape[1], + recipe=(1, 128, 128), + is_sfa=False) + module.weight = nn.Parameter(weight, requires_grad=False) + module.weight_scale = nn.Parameter(transfromed_scale, + requires_grad=False) + + for idx, layer in enumerate( + self.model.layers[:self.config.num_hidden_layers]): + if idx == self.config.num_hidden_layers - 1: + layer.next_layer_layernorm = self.model.norm + else: + layer.next_layer_layernorm = self.model.layers[ + idx + 1].input_layernorm diff --git a/tensorrt_llm/_torch/models/modeling_speculative.py b/tensorrt_llm/_torch/models/modeling_speculative.py index 31d52791f6b..5a472e06085 100755 --- a/tensorrt_llm/_torch/models/modeling_speculative.py +++ b/tensorrt_llm/_torch/models/modeling_speculative.py @@ -351,7 +351,17 @@ def __init__( ): super().__init__() # Import here to avoid circular import - from .modeling_deepseekv3 import DeepseekV3MTP + model_type = model_config.pretrained_config.model_type + mtp_layer = None + match model_type: + case "glm4_moe": + from .modeling_glm import Glm4MTP + mtp_layer = Glm4MTP + case 
"deepseek_v3": + from .modeling_deepseekv3 import DeepseekV3MTP + mtp_layer = DeepseekV3MTP + case _: + raise ValueError(f"Model type {model_type} not supported") spec_dec_mode = model_config.spec_config.spec_dec_mode assert spec_dec_mode.is_mtp_one_model() @@ -362,8 +372,8 @@ def __init__( model_config.spec_config.num_nextn_predict_layers // mtp_num_layers) self.mtp_layers = nn.ModuleList([ - DeepseekV3MTP(model_config, layer_idx + start_layer_idx, - model.aux_stream_dict) + mtp_layer(model_config, layer_idx + start_layer_idx, + model.aux_stream_dict) for layer_idx in range(mtp_num_layers) ]) self.lm_head = lm_head diff --git a/tensorrt_llm/_torch/speculative/mtp.py b/tensorrt_llm/_torch/speculative/mtp.py index c618a240b1b..0e6ba1a235f 100644 --- a/tensorrt_llm/_torch/speculative/mtp.py +++ b/tensorrt_llm/_torch/speculative/mtp.py @@ -827,7 +827,7 @@ def sample_and_accept_draft_tokens( attn_metadata.seq_lens_cuda, dim=0, dtype=torch.long) - 1 ctx_input_ids = input_ids[:attn_metadata.num_ctx_tokens] ctx_is_think = (ctx_input_ids == - self.spec_config.BEGIN_THINKING_PHASE_TOKEN).int() + self.spec_config.begin_thinking_phase_token).int() ctx_is_think_cumsum = torch.cumsum(ctx_is_think, dim=0) ctx_last_cumsum = ctx_is_think_cumsum[ last_tokens_idx[:num_contexts]] @@ -853,8 +853,8 @@ def sample_and_accept_draft_tokens( mtp_relaxed_delta_pool, num_accepted_tokens, accepted_tokens, mtp_num_modules, batch_size, num_contexts, self.spec_config.relaxed_topk, self.spec_config.relaxed_delta, - self.spec_config.BEGIN_THINKING_PHASE_TOKEN, - self.spec_config.END_THINKING_PHASE_TOKEN) + self.spec_config.begin_thinking_phase_token, + self.spec_config.end_thinking_phase_token) # Strict acceptance else: diff --git a/tensorrt_llm/llmapi/llm_args.py b/tensorrt_llm/llmapi/llm_args.py index 89a0d8d6193..def8f5807ed 100644 --- a/tensorrt_llm/llmapi/llm_args.py +++ b/tensorrt_llm/llmapi/llm_args.py @@ -682,12 +682,11 @@ class MTPDecodingConfig(DecodingBaseConfig): # Now we need a flag when MTPDecodingConfig is updated by PyTorchModelEngine. num_nextn_predict_layers_from_model_config: int = 1 - # TODO: Hard code for DeepSeek R1 # When encounter , start thinking phase. # When encounter , end thinking phase. 
# [thinking phase] [real output] - BEGIN_THINKING_PHASE_TOKEN: int = 128798 - END_THINKING_PHASE_TOKEN: int = 128799 + begin_thinking_phase_token: int = 128798 + end_thinking_phase_token: int = 128799 @classmethod def from_dict(cls, data: dict): diff --git a/tests/integration/defs/accuracy/test_llm_api_pytorch.py b/tests/integration/defs/accuracy/test_llm_api_pytorch.py index 1ed6b7bf1be..a0a3b6f38a0 100644 --- a/tests/integration/defs/accuracy/test_llm_api_pytorch.py +++ b/tests/integration/defs/accuracy/test_llm_api_pytorch.py @@ -2269,6 +2269,58 @@ def test_fp8_blockscale_chunked_prefill(self, tp_size, pp_size, ep_size, task.evaluate(llm) +@pytest.mark.timeout(7200) +@pytest.mark.skip_less_device_memory(80000) +class TestGLM4_6(LlmapiAccuracyTestHarness): + MODEL_NAME = "zai-org/GLM-4.6" + MODEL_PATH = f"{llm_models_root()}/GLM-4.6" + + @skip_pre_blackwell + @pytest.mark.parametrize( + "tp_size,pp_size,mtp_nextn,fp8kv,cuda_graph,overlap_scheduler,chunked_prefill,max_batch_size", + [ + pytest.param(4, + 1, + 1, + True, + True, + True, + True, + 16, + marks=pytest.mark.skip_less_mpi_world_size(4)), + ], + ids=["throughput"]) + def test_nvfp4_multi_gpus(self, tp_size, pp_size, mtp_nextn, fp8kv, + cuda_graph, overlap_scheduler, chunked_prefill, + max_batch_size): + kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.70) + pytorch_config = dict( + disable_overlap_scheduler=not overlap_scheduler, + cuda_graph_config=CudaGraphConfig() if cuda_graph else None) + + if fp8kv: + kv_cache_config.dtype = "fp8" + + mtp_config = None + if mtp_nextn > 0: + mtp_config = MTPDecodingConfig(num_nextn_predict_layers=mtp_nextn) + with LLM(f"{llm_models_root()}/GLM-4.6/GLM-4.6-FP4", + max_batch_size=max_batch_size, + tensor_parallel_size=tp_size, + pipeline_parallel_size=pp_size, + kv_cache_config=kv_cache_config, + **pytorch_config, + speculative_config=mtp_config, + enable_chunked_prefill=chunked_prefill) as llm: + + assert llm.args.quant_config.quant_algo == QuantAlgo.NVFP4 + + task = MMLU(self.MODEL_NAME) + task.evaluate(llm) + task = GSM8K(self.MODEL_NAME) + task.evaluate(llm) + + @pytest.mark.timeout(7200) @pytest.mark.skip_less_device_memory(100000) class TestKimiK2(LlmapiAccuracyTestHarness): diff --git a/tests/integration/test_lists/qa/llm_function_core.txt b/tests/integration/test_lists/qa/llm_function_core.txt index 7b5f27341f5..b48f4a33ab2 100644 --- a/tests/integration/test_lists/qa/llm_function_core.txt +++ b/tests/integration/test_lists/qa/llm_function_core.txt @@ -494,6 +494,7 @@ accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus_corner_c accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_fp8_blockscale[throughput] accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_fp8_blockscale_chunked_prefill[latency] accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_fp8_blockscale_chunked_prefill[throughput] +accuracy/test_llm_api_pytorch.py::TestGLM4_6::test_nvfp4_multi_gpus[throughput] accuracy/test_llm_api_pytorch.py::TestQwen3_8B::test_fp8_block_scales[latency] accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_fp8_block_scales[latency-torch_compile=False] accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_fp8_block_scales[latency-torch_compile=True]
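# Editor's note: illustrative sketch only, not part of this patch. A minimal LLM-API
# invocation mirroring the TestGLM4_6 "throughput" configuration added above; the
# checkpoint path is a placeholder and the llmapi import locations are assumed.
from tensorrt_llm import LLM
from tensorrt_llm.llmapi import CudaGraphConfig, KvCacheConfig, MTPDecodingConfig

kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.70, dtype="fp8")
mtp_config = MTPDecodingConfig(num_nextn_predict_layers=1)

with LLM("/path/to/GLM-4.6-FP4",            # hypothetical local NVFP4 checkpoint
         max_batch_size=16,
         tensor_parallel_size=4,
         kv_cache_config=kv_cache_config,
         cuda_graph_config=CudaGraphConfig(),
         speculative_config=mtp_config,
         enable_chunked_prefill=True) as llm:
    outputs = llm.generate(["The capital of France is"])
    print(outputs[0].outputs[0].text)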