Merge branch 'revert_bincount' into 'main'
Fix a performance regression introduced by torch.bincount

Closes #263

See merge request ADLR/megatron-lm!2005
ko3n1g committed Sep 24, 2024
2 parents 075c727 + 5e23e72 commit 884b087
Showing 3 changed files with 38 additions and 12 deletions.
megatron/core/transformer/moe/moe_utils.py (6 changes: 5 additions & 1 deletion)
@@ -327,6 +327,7 @@ def topk_softmax_with_capacity(
     pad_to_capacity: bool = False,
     drop_policy: str = "probs",
     use_pre_softmax: bool = False,
+    deterministic_mode: bool = False,
 ):
     """Apply capacity and padding to the top-k selection.
     Args:
@@ -366,7 +367,10 @@ def topk_softmax_with_capacity(

     if capacity_factor is None:
         # TopK without capacity
-        tokens_per_expert = torch.bincount(top_indices.view(-1), minlength=num_experts)
+        if deterministic_mode:
+            tokens_per_expert = torch.bincount(top_indices.view(-1), minlength=num_experts)
+        else:
+            tokens_per_expert = torch.histc(top_indices, bins=num_experts, min=0, max=num_experts)
         return probs, top_indices, tokens_per_expert
     else:
         # TopK with capacity
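
A minimal sketch, not part of the commit, of why the swap above preserves the counts: for integer expert indices in [0, num_experts), torch.histc with num_experts unit-width bins returns the same histogram as torch.bincount, while taking a fixed bin count up front, which is presumably what avoids the slowdown named in the commit title. The new deterministic_mode flag keeps the bincount path for bitwise-reproducible runs. The shapes below are made up, and the .float() cast is only there so the check also runs on CPU, where torch.histc does not accept integer inputs.

import torch

num_experts = 8
top_indices = torch.randint(0, num_experts, (1024, 2))  # hypothetical top-2 routing output

# Deterministic path: exact integer counting per expert id.
counts_bincount = torch.bincount(top_indices.view(-1), minlength=num_experts)

# Fast path: one unit-width bin per expert over [0, num_experts].
counts_histc = torch.histc(top_indices.float(), bins=num_experts, min=0, max=num_experts)

assert torch.equal(counts_bincount, counts_histc.long())
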
megatron/core/transformer/moe/router.py (11 changes: 8 additions & 3 deletions)
@@ -74,7 +74,8 @@ def routing(self, logits: torch.Tensor):
             logits (torch.Tensor): Logits tensor.

         Returns:
-            Tuple[torch.Tensor, torch.Tensor]: Tuple of tensors representing max probs and the indices.
+            Tuple[torch.Tensor, torch.Tensor]:
+                Tuple of tensors representing max probs and the indices.
         """
         raise NotImplementedError("Routing function not implemented.")

@@ -155,6 +156,7 @@ def aux_loss_load_balancing(self, logits: torch.Tensor):
             pad_to_capacity=self.config.moe_pad_expert_input_to_capacity,
             drop_policy=self.config.moe_token_drop_policy,
             use_pre_softmax=self.config.moe_router_pre_softmax,
+            deterministic_mode=self.config.deterministic_mode,
         )

         if self.training:
@@ -172,8 +174,10 @@ def apply_load_balancing_loss(
         """Applies auxiliary loss to the MoE layer.

         Args:
-            probs (torch.Tensor): The probs output by the router for each token. [num_tokens, num_experts]
-            num_local_tokens_per_expert (torch.Tensor): The number of tokens per expert. [num_experts]
+            probs (torch.Tensor):
+                The probs output by the router for each token. [num_tokens, num_experts]
+            num_local_tokens_per_expert (torch.Tensor):
+                The number of tokens per expert. [num_experts]
             activation (torch.Tensor): The activation tensor to attach the gradient function to.

         Returns:
@@ -279,6 +283,7 @@ def routing(self, logits: torch.Tensor):
                 pad_to_capacity=self.config.moe_pad_expert_input_to_capacity,
                 drop_policy=self.config.moe_token_drop_policy,
                 use_pre_softmax=self.config.moe_router_pre_softmax,
+                deterministic_mode=self.config.deterministic_mode,
             )
         else:
             raise ValueError(f"Unsupported MoE routing type: {self.routing_type}")
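
An illustrative sketch of the plumbing in this file, with hypothetical names rather than the Megatron API: the router change only forwards config.deterministic_mode into topk_softmax_with_capacity, so a single config switch decides between the bincount and histc paths at every call site.

from dataclasses import dataclass

import torch


@dataclass
class RouterConfig:  # stand-in for the TransformerConfig fields used here
    num_experts: int = 8
    deterministic_mode: bool = False  # the flag threaded through this commit


def count_tokens_per_expert(indices: torch.Tensor, cfg: RouterConfig) -> torch.Tensor:
    """Mimics the branch added in moe_utils, driven by a config object."""
    if cfg.deterministic_mode:
        return torch.bincount(indices.view(-1), minlength=cfg.num_experts)
    # .float() only so the sketch also runs on CPU; the real code targets GPU tensors.
    return torch.histc(
        indices.float(), bins=cfg.num_experts, min=0, max=cfg.num_experts
    ).long()


counts = count_tokens_per_expert(torch.randint(0, 8, (256,)), RouterConfig())
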
megatron/core/transformer/moe/token_dispatcher.py (33 changes: 25 additions & 8 deletions)
@@ -184,13 +184,23 @@ def token_permutation(
                 self.global_local_map = None

         with torch.no_grad():
-            tokens_per_expert = torch.bincount(
-                local_indices.view(-1), minlength=self.config.num_moe_experts
-            )
-            if self.num_local_experts < self.config.num_moe_experts:
-                tokens_per_expert = tokens_per_expert[
-                    self.local_expert_indices[0] : self.local_expert_indices[-1] + 1
-                ]
+            # The indices of local_indices that give its sorted order along dim 0.
+            self.indices = torch.argsort(local_indices, dim=0)
+            if self.config.deterministic_mode:
+                tokens_per_expert = torch.bincount(
+                    local_indices.view(-1), minlength=self.config.num_moe_experts
+                )
+                if self.num_local_experts < self.config.num_moe_experts:
+                    tokens_per_expert = tokens_per_expert[
+                        self.local_expert_indices[0] : self.local_expert_indices[-1] + 1
+                    ]
+            else:
+                tokens_per_expert = torch.histc(
+                    local_indices,
+                    bins=self.num_local_experts,
+                    min=self.local_expert_indices[0],
+                    max=self.local_expert_indices[-1],
+                )
             tokens_per_expert = tokens_per_expert.cpu().to(torch.long)

         # Stage2: permute the tokens locally so that they are grouped by their expert assignment
@@ -382,7 +392,14 @@ def preprocess(self, indices: torch.Tensor) -> torch.Tensor:
         Returns:
             torch.Tensor: Tensor containing the number of tokens assigned to local expert.
         """
-        num_local_tokens_per_expert = torch.bincount(indices.view(-1), minlength=self.num_experts)
+        if self.config.deterministic_mode:
+            num_local_tokens_per_expert = torch.bincount(
+                indices.view(-1), minlength=self.num_experts
+            )
+        else:
+            num_local_tokens_per_expert = torch.histc(
+                indices, bins=self.num_experts, min=0, max=self.num_experts
+            )
         # num_local_tokens_per_expert: [num_experts]

         tp_rank = parallel_state.get_tensor_model_parallel_rank()
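
A hedged sketch of the token_dispatcher.py change above: restricting the histogram to this rank's expert-id range reproduces the old bincount-then-slice result, because torch.histc ignores values outside [min, max] and, for a contiguous run of k integer ids, its k equal-width bins each capture exactly one id. The expert layout below is invented for illustration, and the .float() cast again only makes the check run on CPU.

import torch

num_moe_experts = 8
local_expert_indices = [4, 5, 6, 7]  # hypothetical experts owned by this rank
num_local_experts = len(local_expert_indices)
local_indices = torch.randint(0, num_moe_experts, (4096,))  # routed expert id per token

# Old path (kept for deterministic_mode): count every expert, then slice out the local range.
full_counts = torch.bincount(local_indices, minlength=num_moe_experts)
sliced_counts = full_counts[local_expert_indices[0] : local_expert_indices[-1] + 1]

# New path: histogram over the local id range only; out-of-range ids are dropped,
# which replaces the explicit slicing step.
local_counts = torch.histc(
    local_indices.float(),
    bins=num_local_experts,
    min=local_expert_indices[0],
    max=local_expert_indices[-1],
)

assert torch.equal(sliced_counts, local_counts.long())
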
