@@ -103,20 +103,27 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         # Use the original router (it returns scores and indices already softmaxed over top-k)
         router_scores, router_indices = self.router(x)  # scores: [tokens, E], indices: [tokens, k]
 
-        out = self.shared_expert(x) if self.shared_expert is not None else torch.zeros_like(x)
-
-        # Accumulate expert outputs for chosen experts only
-        for j in range(self.top_k):
-            idx = router_indices[:, j]
-            w = router_scores[torch.arange(idx.size(0), device=idx.device), idx].unsqueeze(-1)
-            unique_experts = torch.unique(idx)
-            for e in unique_experts:
-                mask = idx == e
-                out[mask] += self.experts[e](x[mask]) * w[mask]
-
-        out = out.view(B, T, H)
-        router_scores = router_scores.view(B * T, -1)  # shape doesn't matter much; it's ignored by the decoder
-        return out, router_scores
+        final_hidden_states = self.shared_expert(x) if self.shared_expert is not None else torch.zeros_like(x)
+        num_all_tokens, total_num_experts = x.size(0), self.num_experts
+        mask_weights = torch.zeros((num_all_tokens, total_num_experts), dtype=x.dtype, device=x.device)
+        topk_ids, experts_mask = router_indices, router_scores
+        topk_ids = topk_ids.to(torch.int64)
+
+        mask_weights.scatter_(-1, topk_ids, 1)
+
+        mask_weights = mask_weights[:num_all_tokens, :total_num_experts]
+        mask_weights = mask_weights.transpose(0, 1)
+        experts_mask = experts_mask[:num_all_tokens, :total_num_experts]
+        experts_mask = experts_mask.transpose(0, 1)
+        num_experts = total_num_experts
+        for expert_index in range(num_experts):
+            mask_weight = mask_weights[expert_index].unsqueeze(1)
+            current_state_static = x * mask_weight
+            expert = self.experts[expert_index]
+            expert_output = expert(current_state_static)
+            expert_output = expert_output * experts_mask[expert_index].unsqueeze(1)
+            final_hidden_states += expert_output
+        return final_hidden_states.view(B, T, H), router_scores.view(B * T, -1)
 
 
 def get_replacement_info(config):
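
For reference, here is a minimal, self-contained sketch of the masked-dense dispatch the new forward() uses, checked against the gather-style loop it replaces. The toy experts, fabricated router output, and sizes below are illustrative assumptions, not the repo's actual classes:

import torch

torch.manual_seed(0)
torch.set_grad_enabled(False)  # inference-only sketch

# Hypothetical sizes and toy experts; the real module's experts and router differ.
tokens, hidden, num_experts, top_k = 6, 4, 3, 2
experts = [torch.nn.Linear(hidden, hidden) for _ in range(num_experts)]
x = torch.randn(tokens, hidden)

# Fabricated router output in the layout the diff assumes: router_scores is a dense
# [tokens, num_experts] matrix that is zero outside the top-k slots; router_indices is [tokens, top_k].
logits = torch.randn(tokens, num_experts)
topk_vals, router_indices = logits.topk(top_k, dim=-1)
router_scores = torch.zeros(tokens, num_experts).scatter_(-1, router_indices, topk_vals.softmax(dim=-1))

# Masked-dense dispatch (new forward): every expert sees every token; non-routed tokens are
# zeroed on the way in and weighted by zero on the way out, so they contribute nothing.
mask = torch.zeros(tokens, num_experts).scatter_(-1, router_indices, 1).transpose(0, 1)
weights = router_scores.transpose(0, 1)
dense_out = torch.zeros_like(x)
for e in range(num_experts):
    dense_out += experts[e](x * mask[e].unsqueeze(1)) * weights[e].unsqueeze(1)

# Gather-style dispatch (the loop the diff removes): route only the selected tokens to each expert.
gather_out = torch.zeros_like(x)
for j in range(top_k):
    idx = router_indices[:, j]
    w = router_scores[torch.arange(tokens), idx].unsqueeze(-1)
    for e in idx.unique():
        sel = idx == e
        gather_out[sel] += experts[e](x[sel]) * w[sel]

print(torch.allclose(dense_out, gather_out, atol=1e-6))  # True: the two dispatch styles agree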