add mlp for gemma2 (#11678)
MeouSker77 authored Jul 29, 2024
1 parent 1da1f1d commit c020039
Showing 3 changed files with 27 additions and 2 deletions.
4 changes: 3 additions & 1 deletion python/llm/src/ipex_llm/transformers/convert.py
@@ -1513,11 +1513,13 @@ def _optimize_post(model, lightweight_bmm=False):
         from ipex_llm.transformers.models.gemma import gemma_rms_norm_forward
         from ipex_llm.transformers.models.gemma2 import gemma2_attention_forward
         from ipex_llm.transformers.models.gemma2 import gemma2_model_forward
+        from ipex_llm.transformers.models.gemma2 import gemma2_mlp_forward
         from transformers.models.gemma2.modeling_gemma2 import Gemma2RMSNorm, Gemma2Attention
-        from transformers.models.gemma2.modeling_gemma2 import Gemma2Model
+        from transformers.models.gemma2.modeling_gemma2 import Gemma2Model, Gemma2MLP
         convert_forward(model, Gemma2RMSNorm, gemma_rms_norm_forward)
         convert_forward(model, Gemma2Attention, gemma2_attention_forward)
         convert_forward(model, Gemma2Model, gemma2_model_forward)
+        convert_forward(model, Gemma2MLP, gemma2_mlp_forward)
     elif model.config.model_type == "Yi":
         modeling_module_name = model.__class__.__module__
         module = importlib.import_module(modeling_module_name)
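For readers unfamiliar with ipex_llm's conversion machinery: convert_forward appears to rebind the supplied function as the forward method of every submodule of the given class, which is how gemma2_mlp_forward ends up replacing Gemma2MLP.forward. The sketch below is a minimal, hypothetical simplification of that idea, not the actual implementation in convert.py; the name convert_forward_sketch is made up for illustration.

# Hypothetical simplification of a convert_forward-style patch:
# rebind `new_forward` as the bound `forward` method of every submodule
# that is an instance of `target_class`.
import types
import torch

def convert_forward_sketch(model: torch.nn.Module, target_class, new_forward):
    for module in model.modules():
        if isinstance(module, target_class):
            module.forward = types.MethodType(new_forward, module)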
18 changes: 18 additions & 0 deletions python/llm/src/ipex_llm/transformers/models/common.py
@@ -41,3 +41,21 @@ def merge_qkv_base(module: torch.nn.Module, attention_class):
         ])
         module.qkv_proj = qkv_proj
         del module.q_proj, module.k_proj, module.v_proj
+
+
+def fuse_mlp_base(module: torch.nn.Module, act: int, x: torch.Tensor):
+    from ipex_llm.transformers.models.utils import mlp_fusion_check
+    x_2d = x.view(-1, x.size(-1))
+    qtype = getattr(module.gate_proj, "qtype", None)
+    if mlp_fusion_check(x_2d, qtype, module.training):
+        import xe_linear
+        x_2d = x_2d.contiguous()
+        return module.down_proj(
+            xe_linear.mlp_forward_xpu(
+                x_2d, module.gate_proj.weight.data, module.up_proj.weight.data,
+                x_2d.size(0), x_2d.size(1), module.gate_proj.out_len,
+                act, qtype
+            )
+        )
+    else:
+        return module.down_proj(module.act_fn(module.gate_proj(x)) * module.up_proj(x))
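The fallback branch at the end of fuse_mlp_base is the standard gated-MLP computation, down_proj(act_fn(gate_proj(x)) * up_proj(x)); the fused branch hands that same computation to the xe_linear XPU kernel in a single call, guarded by mlp_fusion_check on the flattened, contiguous input and the weight qtype. A minimal, self-contained reference of the unfused path is sketched below; the class name and layer sizes are illustrative, not taken from the commit.

# Reference (unfused) gated MLP equivalent to the fallback branch above.
# Layer sizes are illustrative; real Gemma-2 configs define their own dimensions.
import torch
import torch.nn as nn

class GatedMLPSketch(nn.Module):
    def __init__(self, hidden_size: int = 2304, intermediate_size: int = 9216):
        super().__init__()
        self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
        self.act_fn = nn.GELU(approximate="tanh")  # a GELU variant, matching the GELU act id passed for Gemma 2

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # down_proj(act_fn(gate_proj(x)) * up_proj(x)), exactly as in the else branch
        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))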
7 changes: 6 additions & 1 deletion python/llm/src/ipex_llm/transformers/models/gemma2.py
@@ -34,7 +34,8 @@
 import torch
 
 from typing import Optional, Tuple
-from ipex_llm.transformers.models.common import merge_qkv_base
+from ipex_llm.transformers.models.common import merge_qkv_base, fuse_mlp_base
+from ipex_llm.transformers.models.utils import GELU
 from ipex_llm.transformers.models.utils import should_use_fuse_rope, use_sdp, use_sdp_causal
 from transformers.cache_utils import Cache
 from transformers.models.gemma2.modeling_gemma2 import Gemma2Model, Gemma2Attention
@@ -177,3 +178,7 @@ def gemma2_attention_forward(
     attn_weights = None
 
     return attn_output, attn_weights, past_key_value
+
+
+def gemma2_mlp_forward(self, x: torch.Tensor):
+    return fuse_mlp_base(self, GELU, x)

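End to end, nothing changes for callers: once a Gemma-2 checkpoint is loaded through ipex_llm's transformers wrapper, _optimize_post installs gemma2_mlp_forward on every Gemma2MLP module. Below is a hedged usage sketch assuming the usual ipex_llm loading flow on an Intel XPU; the model id and keyword arguments are examples, not part of this commit.

# Illustrative loading flow; the fused MLP kernel targets Intel XPU and a
# supported low-bit weight type (here via load_in_4bit).
from ipex_llm.transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-2-9b-it",   # example checkpoint
    load_in_4bit=True,        # low-bit weights, as expected by the fused path
    trust_remote_code=True,
)
model = model.to("xpu")       # the xe_linear kernel runs on Intel XPU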