Skip to content

Commit

Permalink
fix glm4-9b overflow (#12455)
Browse files Browse the repository at this point in the history
  • Loading branch information
MeouSker77 authored Nov 27, 2024
1 parent 281c9b0 commit 6f3441b
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 0 deletions.
6 changes: 6 additions & 0 deletions python/llm/src/ipex_llm/transformers/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -1477,6 +1477,12 @@ def _optimize_post(model, lightweight_bmm=False):
convert_forward(model, module.ChatGLMModel, chatglm4_model_forward)
convert_forward(model, module.GLMTransformer, chatglm4_encoder_forward)
convert_forward(model, module.MLP, mlp_forward)

if model.config.num_layers == 40:
# workaround glm4-9b fp16 overflow
from ipex_llm.transformers.models.chatglm4 import chatglm4_block_forward
convert_forward(model, module.GLMBlock, chatglm4_block_forward)

elif "mpt" in model.config.model_type:
if model.config.architectures is not None:
modeling_module_name = model.__class__.__module__
Expand Down
66 changes: 66 additions & 0 deletions python/llm/src/ipex_llm/transformers/models/chatglm4.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,3 +363,69 @@ def chatglm4_encoder_forward(
hidden_states = self.final_layernorm(hidden_states)

return hidden_states, presents, all_hidden_states, all_self_attentions


def chatglm4_block_forward(
self,
hidden_states,
attention_mask,
rotary_pos_emb,
kv_cache=None,
use_cache=True,
):
# hidden_states: [s, b, h]

# Layer norm at the beginning of the transformer layer.
layernorm_output = self.input_layernorm(hidden_states)
# Self attention.
attention_output, kv_cache = self.self_attention(
layernorm_output,
attention_mask,
rotary_pos_emb,
kv_cache=kv_cache,
use_cache=use_cache
)

# Residual connection.
if self.apply_residual_connection_post_layernorm:
residual = layernorm_output
else:
residual = hidden_states

layernorm_input = torch.nn.functional.dropout(attention_output, p=self.hidden_dropout,
training=self.training)
layernorm_input = residual + layernorm_input

# Layer norm post the self attention.
layernorm_output = self.post_attention_layernorm(layernorm_input)

# ipex-llm changes start: workaround fp16 overflow
scale = 10
if self.layer_number == 39 and layernorm_output.device.type == 'xpu':
gate = self.mlp.gate_proj(layernorm_output)
up = self.mlp.up_proj(layernorm_output)
down = self.mlp.activation_fn(gate) / scale * up
mlp_output = self.mlp.dense_4h_to_h(down)
else:
# MLP.
mlp_output = self.mlp(layernorm_output)
# ipex-llm changes end

# Second residual connection.
if self.apply_residual_connection_post_layernorm:
residual = layernorm_output
else:
residual = layernorm_input

output = torch.nn.functional.dropout(mlp_output, p=self.hidden_dropout,
training=self.training)

# ipex-llm changes start: workaround fp16 overflow
if self.layer_number == 39 and layernorm_output.device.type == 'xpu':
output = residual + output * scale
output = torch.nan_to_num(output)
else:
output = residual + output
# ipex-llm changes end

return output, kv_cache

0 comments on commit 6f3441b

Please sign in to comment.