Skip to content

Commit 6d3d87d

Browse files
authored
Reduce mem usage of GPT-OSS (#1013)
1 parent c4d030d commit 6d3d87d

File tree

1 file changed

+21
-14
lines changed

1 file changed

+21
-14
lines changed

auto_round/modelling/gpt_oss.py

Lines changed: 21 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -103,20 +103,27 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
103103
# Use the original router (it returns scores and indices already softmaxed over top-k)
104104
router_scores, router_indices = self.router(x) # scores: [tokens, E], indices: [tokens, k]
105105

106-
out = self.shared_expert(x) if self.shared_expert is not None else torch.zeros_like(x)
107-
108-
# Accumulate expert outputs for chosen experts only
109-
for j in range(self.top_k):
110-
idx = router_indices[:, j]
111-
w = router_scores[torch.arange(idx.size(0), device=idx.device), idx].unsqueeze(-1)
112-
unique_experts = torch.unique(idx)
113-
for e in unique_experts:
114-
mask = idx == e
115-
out[mask] += self.experts[e](x[mask]) * w[mask]
116-
117-
out = out.view(B, T, H)
118-
router_scores = router_scores.view(B * T, -1) # shape doesn't matter much; it’s ignored by the decoder
119-
return out, router_scores
106+
final_hidden_states = self.shared_expert(x) if self.shared_expert is not None else torch.zeros_like(x)
107+
num_all_tokens, total_num_experts = x.size(0), self.num_experts
108+
mask_weights = torch.zeros((num_all_tokens, total_num_experts), dtype=x.dtype, device=x.device)
109+
topk_ids, experts_mask = router_indices, router_scores
110+
topk_ids = topk_ids.to(torch.int64)
111+
112+
mask_weights.scatter_(-1, topk_ids, 1)
113+
114+
mask_weights = mask_weights[:num_all_tokens, :total_num_experts]
115+
mask_weights = mask_weights.transpose(0, 1)
116+
experts_mask = experts_mask[:num_all_tokens, :total_num_experts]
117+
experts_mask = experts_mask.transpose(0, 1)
118+
num_experts = total_num_experts
119+
for expert_index in range(num_experts):
120+
mask_weight = mask_weights[expert_index].unsqueeze(1)
121+
current_state_static = x * mask_weight
122+
expert = self.experts[expert_index]
123+
expert_output = expert(current_state_static)
124+
expert_output = expert_output * experts_mask[expert_index].unsqueeze(1)
125+
final_hidden_states += expert_output
126+
return final_hidden_states.view(B, T, H), router_scores.view(B * T, -1)
120127

121128

122129
def get_replacement_info(config):

0 commit comments

Comments
 (0)