Commit 5837bc0

fix chatglm3 npu output (#11590)

1 parent 06930ab commit 5837bc0

File tree

  • python/llm/src/ipex_llm/transformers/npu_models

1 file changed: +24, -27 lines
python/llm/src/ipex_llm/transformers/npu_models/chatglm.py

Lines changed: 24 additions & 27 deletions
@@ -64,7 +64,16 @@ def chatglm2_model_forward(
         rotary_pos_emb = rotary_pos_emb[position_ids]
     else:
         rotary_pos_emb = rotary_pos_emb[None, :seq_length]
-    rotary_pos_emb = rotary_pos_emb.transpose(0, 1).contiguous()
+    # ipex-llm change start: change rope cache shape
+    # rotary_pos_emb: [bsz, seq_len, rot_dim//2, 2]
+    cos, sin = rotary_pos_emb.permute(3, 0, 1, 2).chunk(2, dim=0)
+    cos = cos.squeeze(0).unsqueeze(1)
+    sin = sin.squeeze(0).unsqueeze(1)
+    cos = cos.repeat_interleave(2, dim=-1)
+    sin = sin.repeat_interleave(2, dim=-1)
+    # cos, sin: [bsz, 1, seq_len, rot_dim]
+    rotary_pos_emb = (cos, sin)
+    # ipex-llm change end

     # ipex-llm changes begin:
     # generate `causal_mask` and replace `full_attention_mask` with it
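Note (not part of the commit): a minimal shape check of the rope-cache conversion in the hunk above, assuming the [bsz, seq_len, rot_dim//2, 2] layout stated in the comment; the concrete sizes below are made up.

import torch

bsz, seq_len, rot_dim = 1, 8, 64
# ChatGLM-style rope cache: interleaved (cos, sin) pairs per rotary channel
rotary_pos_emb = torch.randn(bsz, seq_len, rot_dim // 2, 2)

cos, sin = rotary_pos_emb.permute(3, 0, 1, 2).chunk(2, dim=0)  # 2 x [1, bsz, seq_len, rot_dim//2]
cos = cos.squeeze(0).unsqueeze(1)                              # [bsz, 1, seq_len, rot_dim//2]
sin = sin.squeeze(0).unsqueeze(1)
cos = cos.repeat_interleave(2, dim=-1)                         # [bsz, 1, seq_len, rot_dim]
sin = sin.repeat_interleave(2, dim=-1)

assert cos.shape == (bsz, 1, seq_len, rot_dim)
assert sin.shape == (bsz, 1, seq_len, rot_dim)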
@@ -76,14 +85,6 @@ def chatglm2_model_forward(
                                   dtype=inputs_embeds.dtype, device=inputs_embeds.device)
         mask_value = torch.finfo(inputs_embeds.dtype).min
         causal_mask.masked_fill_(full_attention_mask, mask_value)
-    elif self.training or (inputs_embeds.device.type != "xpu" and past_key_values is None):
-        full_attention_mask = self.get_masks(input_ids,
-                                             past_key_values,
-                                             padding_mask=attention_mask)
-        causal_mask = torch.zeros([batch_size, 1, seq_length, full_attention_mask.size(-1)],
-                                  dtype=inputs_embeds.dtype, device=inputs_embeds.device)
-        mask_value = torch.finfo(inputs_embeds.dtype).min
-        causal_mask.masked_fill_(full_attention_mask, mask_value)
     else:
         causal_mask = None

@@ -174,24 +175,20 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:


 @torch.jit.script
-def apply_rotary_pos_emb(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor:
-    # x: [sq, b, np, hn]
-    sq, b, np, hn = x.size(0), x.size(1), x.size(2), x.size(3)
-    rot_dim = rope_cache.shape[-2] * 2
+def rotate_every_two(x: torch.Tensor):
+    x1 = x[:, :, :, ::2]
+    x2 = x[:, :, :, 1::2]
+    x = torch.stack((-x2, x1), dim=-1)
+    return x.flatten(-2)
+
+
+def apply_rotary_pos_emb(x: torch.Tensor, rope_cache: Tuple[torch.Tensor]) -> torch.Tensor:
+    # x: [bsz, n_head, seq_len, head_dim]
+    cos, sin = rope_cache
+    rot_dim = cos.size(-1)
     x, x_pass = x[..., :rot_dim], x[..., rot_dim:]
-    # truncate to support variable sizes
-    rope_cache = rope_cache[:sq]
-    xshaped = x.reshape(sq, -1, np, rot_dim // 2, 2)
-    rope_cache = rope_cache.view(sq, -1, 1, xshaped.size(3), 2)
-    x_out2 = torch.stack(
-        [
-            xshaped[..., 0] * rope_cache[..., 0] - xshaped[..., 1] * rope_cache[..., 1],
-            xshaped[..., 1] * rope_cache[..., 0] + xshaped[..., 0] * rope_cache[..., 1],
-        ],
-        -1,
-    )
-    x_out2 = x_out2.flatten(3)
-    return torch.cat((x_out2, x_pass), dim=-1)
+    x_out = x * cos + rotate_every_two(x) * sin
+    return torch.cat([x_out, x_pass], dim=-1)


 def chatglm2_attention_forward(
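Note (not part of the commit): a standalone sketch checking that the cos/sin form of apply_rotary_pos_emb above reproduces the original pairwise rotation it replaces; the shapes and angle values are illustrative only.

import torch

def rotate_every_two(x: torch.Tensor):
    x1 = x[:, :, :, ::2]
    x2 = x[:, :, :, 1::2]
    return torch.stack((-x2, x1), dim=-1).flatten(-2)

bsz, n_head, seq_len, rot_dim = 1, 2, 4, 8
x = torch.randn(bsz, n_head, seq_len, rot_dim)
theta = torch.randn(seq_len, rot_dim // 2)   # per-position rotation angles (made up)
cos, sin = theta.cos(), theta.sin()

# new form: duplicated cos/sin broadcast over batch and heads
cos_r = cos.repeat_interleave(2, dim=-1)     # [seq_len, rot_dim]
sin_r = sin.repeat_interleave(2, dim=-1)
new = x * cos_r + rotate_every_two(x) * sin_r

# old form: rotate consecutive (even, odd) channel pairs explicitly
xp = x.reshape(bsz, n_head, seq_len, rot_dim // 2, 2)
old = torch.stack([xp[..., 0] * cos - xp[..., 1] * sin,
                   xp[..., 1] * cos + xp[..., 0] * sin], dim=-1).flatten(-2)

assert torch.allclose(new, old)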
@@ -246,7 +243,7 @@ def chatglm2_attention_forward(
             key_states,
             value_states,
             attn_mask=attention_mask,
-            is_causal=q_len > 1 and bsz == 1,
+            is_causal=attention_mask is None and q_len > 1 and bsz == 1,
         )
         attn_weights = None
     else:
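Note (not part of the commit): the last hunk enables is_causal only when no explicit mask is passed, because torch.nn.functional.scaled_dot_product_attention errors out if attn_mask and is_causal=True are set together. A minimal sketch of that guarded call, with illustrative names and shapes:

import torch
import torch.nn.functional as F

bsz, n_head, q_len, head_dim = 1, 2, 5, 16
query_states = torch.randn(bsz, n_head, q_len, head_dim)
key_states = torch.randn(bsz, n_head, q_len, head_dim)
value_states = torch.randn(bsz, n_head, q_len, head_dim)
attention_mask = None  # e.g. single-sequence prefill with no padding mask

# is_causal is enabled only for the mask-less, batch-1 prefill case;
# otherwise the explicit attention_mask (or q_len == 1 decoding) governs.
attn_output = F.scaled_dot_product_attention(
    query_states,
    key_states,
    value_states,
    attn_mask=attention_mask,
    is_causal=attention_mask is None and q_len > 1 and bsz == 1,
)
assert attn_output.shape == (bsz, n_head, q_len, head_dim)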

0 commit comments
