optimize glm4v vision attention (#12369)
MeouSker77 authored Nov 8, 2024
1 parent 2dfcc36 commit dc34e8c
Showing 2 changed files with 85 additions and 7 deletions.
python/llm/src/ipex_llm/transformers/models/common.py (63 additions, 0 deletions)
@@ -51,6 +51,69 @@ def merge_qkv_base(module: torch.nn.Module, attention_class)
     del module.q_proj, module.k_proj, module.v_proj
 
 
+def padding_linear_hd(linear: torch.nn.Linear,
+                      old_head_dim: int, new_head_dim: int) -> torch.nn.Linear:
+    in_features, out_features = linear.in_features, linear.out_features
+
+    weight = linear.weight.data
+    weight = weight.view(-1, old_head_dim, in_features)
+    new_weight = torch.empty([weight.size(0), new_head_dim, in_features],
+                             dtype=weight.dtype, device=weight.device)
+    new_weight[:, :old_head_dim, :] = weight
+    new_weight[:, old_head_dim:, :] = 0
+    new_weight = new_weight.view(-1, in_features)
+    if linear.bias is not None:
+        bias = linear.bias.data
+        bias = bias.view(-1, old_head_dim)
+        new_bias = torch.empty([bias.size(0), new_head_dim],
+                               dtype=bias.dtype, device=bias.device)
+        new_bias[:, :old_head_dim] = bias
+        new_bias[:, old_head_dim:] = 0
+        new_bias = new_bias.flatten()
+
+        new_linear = torch.nn.Linear(0, 0, bias=True)
+        new_linear.bias = torch.nn.Parameter(new_bias, requires_grad=False)
+    else:
+        new_linear = torch.nn.Linear(0, 0, bias=False)
+    new_linear.weight = torch.nn.Parameter(new_weight, requires_grad=False)
+    new_linear.in_features = new_weight.size(1)
+    new_linear.out_features = new_weight.size(0)
+    return new_linear
+
+
+def padding_attention_hd_base(module: torch.nn.Module, attention_class,
+                              old_head_dim: int, new_head_dim: int):
+    if (
+        isinstance(attention_class, str) and module.__class__.__name__ == attention_class
+        or not isinstance(attention_class, str) and isinstance(module, attention_class)
+    ) and module.head_dim == old_head_dim:
+        module.q_proj = padding_linear_hd(module.q_proj, old_head_dim, new_head_dim)
+        module.k_proj = padding_linear_hd(module.k_proj, old_head_dim, new_head_dim)
+        module.v_proj = padding_linear_hd(module.v_proj, old_head_dim, new_head_dim)
+        module.head_dim = new_head_dim
+        module.old_head_dim = old_head_dim
+
+
+def padding_states_hd(states: torch.Tensor, old_head_dim: int, new_head_dim: int):
+    bsz, num_heads, seq_len, head_dim = states.size()
+    if head_dim == old_head_dim and old_head_dim < new_head_dim:
+        new_states = torch.empty([bsz, num_heads, seq_len, new_head_dim],
+                                 dtype=states.dtype, device=states.device)
+        new_states[:, :, :, :old_head_dim] = states
+        new_states[:, :, :, old_head_dim:] = 0
+        return new_states
+    return states
+
+
+def padding_qkv_hd(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
+                   old_head_dim: int, new_head_dim: int):
+    return (
+        padding_states_hd(q, old_head_dim, new_head_dim),
+        padding_states_hd(k, old_head_dim, new_head_dim),
+        padding_states_hd(v, old_head_dim, new_head_dim),
+    )
+
+
 def fuse_mlp_base(module: torch.nn.Module, act: int, x: torch.Tensor):
     from ipex_llm.transformers.models.utils import mlp_fusion_check
     x_2d = x.view(-1, x.size(-1))
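Note: the padding helpers above only append zero-valued output channels per head, so the original projection results survive unchanged in the first old_head_dim slots. A minimal sketch of that property (not part of the commit; the 144-feature toy layer and two-head layout are illustrative, and it assumes an ipex_llm build that contains this change):

import torch
from ipex_llm.transformers.models.common import padding_linear_hd

torch.manual_seed(0)
# Toy projection: hidden size 144, 2 heads of head_dim 72 (shapes are illustrative).
linear = torch.nn.Linear(144, 2 * 72, bias=True)
padded = padding_linear_hd(linear, old_head_dim=72, new_head_dim=80)

x = torch.randn(1, 144)
ref = linear(x).view(1, 2, 72)   # per-head outputs of the original layer
out = padded(x).view(1, 2, 80)   # per-head outputs of the padded layer

assert torch.allclose(out[..., :72], ref)                # first 72 dims unchanged
assert torch.equal(out[..., 72:], torch.zeros(1, 2, 8))  # padded dims are all zeros

Because only zero rows are appended per head, the layer simply grows from 72 to 80 output channels per head; no weights are recomputed or retrained.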
python/llm/src/ipex_llm/transformers/models/minicpmv.py (22 additions, 7 deletions)
@@ -26,8 +26,9 @@
 from threading import Thread
 from typing import Optional, List
 from torch.nn.functional import linear
-from ipex_llm.transformers.models.common import merge_qkv_base
+from ipex_llm.transformers.models.common import merge_qkv_base, padding_qkv_hd
 from ipex_llm.transformers.models.common import attention_softmax
+from ipex_llm.transformers.models.utils import use_sdp_non_causal
 from transformers import AutoProcessor, TextIteratorStreamer
 from transformers.generation.logits_process import RepetitionPenaltyLogitsProcessor

@@ -52,14 +53,28 @@ def siglip_attention_forward(
     qkv = qkv.transpose(1, 2)
     query_states, key_states, value_states = qkv.chunk(3, dim=1)
 
-    attn_weights = torch.matmul(query_states * self.scale, key_states.transpose(2, 3))
-    if attention_mask is not None:
-        attn_weights = attn_weights + attention_mask
+    query_states, key_states, value_states = padding_qkv_hd(
+        query_states, key_states, value_states,
+        72, 80
+    )
+
+    if use_sdp_non_causal(query_states.size(-1), query_states.device, query_states.dtype):
+        import xe_addons
+        attn_weights = None
+        attn_output = xe_addons.sdp_non_causal(query_states, key_states.contiguous(),
+                                               value_states.contiguous(), attention_mask)
+    else:
+        attn_weights = torch.matmul(query_states * self.scale, key_states.transpose(2, 3))
+        if attention_mask is not None:
+            attn_weights = attn_weights + attention_mask
 
-    attn_weights = attention_softmax(attn_weights)
-    attn_weights = torch.nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
-    attn_output = torch.matmul(attn_weights, value_states)
+        attn_weights = attention_softmax(attn_weights)
+
+        attn_weights = torch.nn.functional.dropout(attn_weights,
+                                                   p=self.dropout, training=self.training)
+        attn_output = torch.matmul(attn_weights, value_states)
+
+    attn_output = attn_output[:, :, :, :self.head_dim]
 
     attn_output = attn_output.transpose(1, 2).contiguous()
     attn_output = attn_output.reshape(bsz, q_len, self.embed_dim)
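Note: the new path zero-pads the 72-dim SigLIP heads to 80 before attention, likely so that the fused SDP kernel sees a supported head size, and then slices the output back to self.head_dim. Zero columns in Q and K add nothing to the dot products, and zero columns in V only produce zero output channels, so the result matches unpadded attention as long as the original 1/sqrt(72) scale is kept. A quick check of that equivalence with stock PyTorch SDPA (not part of the commit, independent of the Intel xe_addons XPU kernel, and assuming PyTorch >= 2.1 for the explicit scale argument):

import torch
import torch.nn.functional as F
from ipex_llm.transformers.models.common import padding_qkv_hd

torch.manual_seed(0)
bsz, heads, seq, head_dim = 1, 16, 64, 72
q, k, v = (torch.randn(bsz, heads, seq, head_dim) for _ in range(3))

# Reference: ordinary attention at head_dim = 72.
ref = F.scaled_dot_product_attention(q, k, v)

# Padded: zero-pad q/k/v to head_dim = 80, attend, then drop the padding.
# The scale must stay 1/sqrt(72); SDPA would otherwise default to 1/sqrt(80).
qp, kp, vp = padding_qkv_hd(q, k, v, 72, 80)
out = F.scaled_dot_product_attention(qp, kp, vp, scale=head_dim ** -0.5)
out = out[..., :head_dim]

assert torch.allclose(out, ref, atol=1e-5)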
