diff --git a/megatron/core/transformer/moe/moe_utils.py b/megatron/core/transformer/moe/moe_utils.py
index ee4bb690b7..02a2cccca5 100644
--- a/megatron/core/transformer/moe/moe_utils.py
+++ b/megatron/core/transformer/moe/moe_utils.py
@@ -327,6 +327,7 @@ def topk_softmax_with_capacity(
     pad_to_capacity: bool = False,
     drop_policy: str = "probs",
     use_pre_softmax: bool = False,
+    deterministic_mode: bool = False,
 ):
     """Apply capacity and padding to the top-k selection.
     Args:
@@ -366,7 +367,10 @@ def topk_softmax_with_capacity(
 
     if capacity_factor is None:
         # TopK without capacity
-        tokens_per_expert = torch.bincount(top_indices.view(-1), minlength=num_experts)
+        if deterministic_mode:
+            tokens_per_expert = torch.bincount(top_indices.view(-1), minlength=num_experts)
+        else:
+            tokens_per_expert = torch.histc(top_indices, bins=num_experts, min=0, max=num_experts)
         return probs, top_indices, tokens_per_expert
     else:
         # TopK with capacity
diff --git a/megatron/core/transformer/moe/router.py b/megatron/core/transformer/moe/router.py
index 8894dc1df3..3e85ec53c5 100644
--- a/megatron/core/transformer/moe/router.py
+++ b/megatron/core/transformer/moe/router.py
@@ -74,7 +74,8 @@ def routing(self, logits: torch.Tensor):
             logits (torch.Tensor): Logits tensor.
 
         Returns:
-            Tuple[torch.Tensor, torch.Tensor]: Tuple of tensors representing max probs and the indices.
+            Tuple[torch.Tensor, torch.Tensor]:
+                Tuple of tensors representing max probs and the indices.
         """
         raise NotImplementedError("Routing function not implemented.")
 
@@ -155,6 +156,7 @@ def aux_loss_load_balancing(self, logits: torch.Tensor):
             pad_to_capacity=self.config.moe_pad_expert_input_to_capacity,
             drop_policy=self.config.moe_token_drop_policy,
             use_pre_softmax=self.config.moe_router_pre_softmax,
+            deterministic_mode=self.config.deterministic_mode,
         )
 
         if self.training:
@@ -172,8 +174,10 @@ def apply_load_balancing_loss(
         """Applies auxiliary loss to the MoE layer.
 
         Args:
-            probs (torch.Tensor): The probs output by the router for each token. [num_tokens, num_experts]
-            num_local_tokens_per_expert (torch.Tensor): The number of tokens per expert. [num_experts]
+            probs (torch.Tensor):
+                The probs output by the router for each token. [num_tokens, num_experts]
+            num_local_tokens_per_expert (torch.Tensor):
+                The number of tokens per expert. [num_experts]
             activation (torch.Tensor): The activation tensor to attach the gradient function to.
 
         Returns:
@@ -279,6 +283,7 @@ def routing(self, logits: torch.Tensor):
                 pad_to_capacity=self.config.moe_pad_expert_input_to_capacity,
                 drop_policy=self.config.moe_token_drop_policy,
                 use_pre_softmax=self.config.moe_router_pre_softmax,
+                deterministic_mode=self.config.deterministic_mode,
             )
         else:
             raise ValueError(f"Unsupported MoE routing type: {self.routing_type}")
diff --git a/megatron/core/transformer/moe/token_dispatcher.py b/megatron/core/transformer/moe/token_dispatcher.py
index e23ea4ea0f..db1b1920fa 100644
--- a/megatron/core/transformer/moe/token_dispatcher.py
+++ b/megatron/core/transformer/moe/token_dispatcher.py
@@ -184,13 +184,23 @@ def token_permutation(
             self.global_local_map = None
 
         with torch.no_grad():
-            tokens_per_expert = torch.bincount(
-                local_indices.view(-1), minlength=self.config.num_moe_experts
-            )
-            if self.num_local_experts < self.config.num_moe_experts:
-                tokens_per_expert = tokens_per_expert[
-                    self.local_expert_indices[0] : self.local_expert_indices[-1] + 1
-                ]
+            # The indices of local_indices that give its sorted order along dim 0.
+            self.indices = torch.argsort(local_indices, dim=0)
+            if self.config.deterministic_mode:
+                tokens_per_expert = torch.bincount(
+                    local_indices.view(-1), minlength=self.config.num_moe_experts
+                )
+                if self.num_local_experts < self.config.num_moe_experts:
+                    tokens_per_expert = tokens_per_expert[
+                        self.local_expert_indices[0] : self.local_expert_indices[-1] + 1
+                    ]
+            else:
+                tokens_per_expert = torch.histc(
+                    local_indices,
+                    bins=self.num_local_experts,
+                    min=self.local_expert_indices[0],
+                    max=self.local_expert_indices[-1],
+                )
             tokens_per_expert = tokens_per_expert.cpu().to(torch.long)
 
         # Stage2: permute the tokens locally so that they are grouped by their expert assignment
@@ -382,7 +392,14 @@ def preprocess(self, indices: torch.Tensor) -> torch.Tensor:
 
         Returns:
             torch.Tensor: Tensor containing the number of tokens assigned to local expert.
         """
-        num_local_tokens_per_expert = torch.bincount(indices.view(-1), minlength=self.num_experts)
+        if self.config.deterministic_mode:
+            num_local_tokens_per_expert = torch.bincount(
+                indices.view(-1), minlength=self.num_experts
+            )
+        else:
+            num_local_tokens_per_expert = torch.histc(
+                indices, bins=self.num_experts, min=0, max=self.num_experts
+            )
         # num_local_tokens_per_expert: [num_experts]
         tp_rank = parallel_state.get_tensor_model_parallel_rank()
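For reference, a minimal sketch (not part of the patch) of the invariant this change relies on: for a flat tensor of integer expert indices in `[0, num_experts)`, the `torch.histc` call taken when `deterministic_mode` is off yields the same per-expert counts as the `torch.bincount` call taken when it is on. PyTorch lists `torch.histc` on CUDA tensors among its nondeterministic operations, which is presumably why the reproducible `bincount` path stays behind the flag. The toy tensor below is illustrative only.

```python
# Illustration only -- not part of the patch. Shows that the two counting
# paths gated by `deterministic_mode` agree for integer expert indices.
import torch

num_experts = 4
top_indices = torch.tensor([0, 2, 2, 3, 1, 2, 0])  # toy routing output

# Path used when deterministic_mode=True.
counts_bincount = torch.bincount(top_indices.view(-1), minlength=num_experts)

# Path used when deterministic_mode=False; cast to float so this standalone
# sketch also runs on CPU, where torch.histc requires a floating-point input.
counts_histc = torch.histc(
    top_indices.float(), bins=num_experts, min=0, max=num_experts
).long()

assert torch.equal(counts_bincount, counts_histc)  # tensor([2, 1, 3, 1])
```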