@@ -64,7 +64,16 @@ def chatglm2_model_forward(
         rotary_pos_emb = rotary_pos_emb[position_ids]
     else:
         rotary_pos_emb = rotary_pos_emb[None, :seq_length]
-    rotary_pos_emb = rotary_pos_emb.transpose(0, 1).contiguous()
+    # ipex-llm change start: change rope cache shape
+    # rotary_pos_emb: [bsz, seq_len, rot_dim//2, 2]
+    cos, sin = rotary_pos_emb.permute(3, 0, 1, 2).chunk(2, dim=0)
+    cos = cos.squeeze(0).unsqueeze(1)
+    sin = sin.squeeze(0).unsqueeze(1)
+    cos = cos.repeat_interleave(2, dim=-1)
+    sin = sin.repeat_interleave(2, dim=-1)
+    # cos, sin: [bsz, 1, seq_len, rot_dim]
+    rotary_pos_emb = (cos, sin)
+    # ipex-llm change end
 
     # ipex-llm changes begin:
     # generate `causal_mask` and replace `full_attention_mask` with it
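
Side note, not part of the patch: a minimal shape sketch (with made-up sizes) of the rope-cache conversion added above. The GLM-style `[bsz, seq_len, rot_dim//2, 2]` cache is split into `cos`/`sin` tensors of shape `[bsz, 1, seq_len, rot_dim]` that broadcast over the head dimension.

```python
import torch

bsz, seq_len, rot_dim = 2, 8, 64
# dummy rope cache with the layout the model produces: [bsz, seq_len, rot_dim//2, 2]
rotary_pos_emb = torch.randn(bsz, seq_len, rot_dim // 2, 2)

cos, sin = rotary_pos_emb.permute(3, 0, 1, 2).chunk(2, dim=0)  # 2 x [1, bsz, seq_len, rot_dim//2]
cos = cos.squeeze(0).unsqueeze(1)                              # [bsz, 1, seq_len, rot_dim//2]
sin = sin.squeeze(0).unsqueeze(1)
cos = cos.repeat_interleave(2, dim=-1)                         # [bsz, 1, seq_len, rot_dim]
sin = sin.repeat_interleave(2, dim=-1)
assert cos.shape == (bsz, 1, seq_len, rot_dim)
```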
@@ -76,14 +85,6 @@ def chatglm2_model_forward(
                                       dtype=inputs_embeds.dtype, device=inputs_embeds.device)
             mask_value = torch.finfo(inputs_embeds.dtype).min
             causal_mask.masked_fill_(full_attention_mask, mask_value)
-        elif self.training or (inputs_embeds.device.type != "xpu" and past_key_values is None):
-            full_attention_mask = self.get_masks(input_ids,
-                                                 past_key_values,
-                                                 padding_mask=attention_mask)
-            causal_mask = torch.zeros([batch_size, 1, seq_length, full_attention_mask.size(-1)],
-                                      dtype=inputs_embeds.dtype, device=inputs_embeds.device)
-            mask_value = torch.finfo(inputs_embeds.dtype).min
-            causal_mask.masked_fill_(full_attention_mask, mask_value)
         else:
             causal_mask = None
 
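
For reference, a small sketch (illustrative shapes, not part of the patch) of the additive-mask construction kept in the branch above: a boolean mask with `True` at disallowed positions is turned into a 0 / dtype-min float mask that can be passed to SDPA as `attn_mask`.

```python
import torch

batch_size, seq_length, kv_length = 1, 4, 4
dtype = torch.float16

# True marks positions that must NOT be attended to (upper triangle here)
full_attention_mask = torch.ones(seq_length, kv_length).triu(1).bool()[None, None]

causal_mask = torch.zeros([batch_size, 1, seq_length, kv_length], dtype=dtype)
causal_mask.masked_fill_(full_attention_mask, torch.finfo(dtype).min)
```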
@@ -174,24 +175,20 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
 
 
 @torch.jit.script
-def apply_rotary_pos_emb(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor:
-    # x: [sq, b, np, hn]
-    sq, b, np, hn = x.size(0), x.size(1), x.size(2), x.size(3)
-    rot_dim = rope_cache.shape[-2] * 2
+def rotate_every_two(x: torch.Tensor):
+    x1 = x[:, :, :, ::2]
+    x2 = x[:, :, :, 1::2]
+    x = torch.stack((-x2, x1), dim=-1)
+    return x.flatten(-2)
+
+
+def apply_rotary_pos_emb(x: torch.Tensor, rope_cache: Tuple[torch.Tensor]) -> torch.Tensor:
+    # x: [bsz, n_head, seq_len, head_dim]
+    cos, sin = rope_cache
+    rot_dim = cos.size(-1)
     x, x_pass = x[..., :rot_dim], x[..., rot_dim:]
-    # truncate to support variable sizes
-    rope_cache = rope_cache[:sq]
-    xshaped = x.reshape(sq, -1, np, rot_dim // 2, 2)
-    rope_cache = rope_cache.view(sq, -1, 1, xshaped.size(3), 2)
-    x_out2 = torch.stack(
-        [
-            xshaped[..., 0] * rope_cache[..., 0] - xshaped[..., 1] * rope_cache[..., 1],
-            xshaped[..., 1] * rope_cache[..., 0] + xshaped[..., 0] * rope_cache[..., 1],
-        ],
-        -1,
-    )
-    x_out2 = x_out2.flatten(3)
-    return torch.cat((x_out2, x_pass), dim=-1)
+    x_out = x * cos + rotate_every_two(x) * sin
+    return torch.cat([x_out, x_pass], dim=-1)
 
 
 def chatglm2_attention_forward(
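
Not part of the patch, but a quick numerical sketch of why the rewrite is safe: with the repeat-interleaved `cos`/`sin` built in the first hunk, `x * cos + rotate_every_two(x) * sin` reproduces the original interleaved pairwise rotation. All sizes below are made up for the check, and the old path is re-expressed in the new `[bsz, n_head, seq_len, head_dim]` layout.

```python
import torch

def rotate_every_two(x):
    x1 = x[:, :, :, ::2]
    x2 = x[:, :, :, 1::2]
    return torch.stack((-x2, x1), dim=-1).flatten(-2)

bsz, n_head, seq_len, head_dim, rot_dim = 2, 4, 8, 16, 8
x = torch.randn(bsz, n_head, seq_len, head_dim)

# GLM-style rope cache: [bsz, seq_len, rot_dim//2, 2] holding (cos, sin) pairs
rope_cache = torch.randn(bsz, seq_len, rot_dim // 2, 2)

# --- new path (this patch) ---
cos, sin = rope_cache.permute(3, 0, 1, 2).chunk(2, dim=0)
cos = cos.squeeze(0).unsqueeze(1).repeat_interleave(2, dim=-1)  # [bsz, 1, seq_len, rot_dim]
sin = sin.squeeze(0).unsqueeze(1).repeat_interleave(2, dim=-1)
x_rot, x_pass = x[..., :rot_dim], x[..., rot_dim:]
new_out = torch.cat([x_rot * cos + rotate_every_two(x_rot) * sin, x_pass], dim=-1)

# --- original path, expressed in the same [bsz, n_head, seq_len, head_dim] layout ---
xshaped = x_rot.reshape(bsz, n_head, seq_len, rot_dim // 2, 2)
rc = rope_cache.view(bsz, 1, seq_len, rot_dim // 2, 2)
old_rot = torch.stack(
    [xshaped[..., 0] * rc[..., 0] - xshaped[..., 1] * rc[..., 1],
     xshaped[..., 1] * rc[..., 0] + xshaped[..., 0] * rc[..., 1]],
    dim=-1,
).flatten(-2)
old_out = torch.cat([old_rot, x_pass], dim=-1)

assert torch.allclose(new_out, old_out, atol=1e-6)
```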
@@ -246,7 +243,7 @@ def chatglm2_attention_forward(
                 key_states,
                 value_states,
                 attn_mask=attention_mask,
-                is_causal=q_len > 1 and bsz == 1,
+                is_causal=attention_mask is None and q_len > 1 and bsz == 1,
             )
             attn_weights = None
         else:
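
Not part of the patch: a sketch of the call pattern this hunk produces (tensor shapes are made up, PyTorch >= 2.0 assumed). `is_causal` is now only requested when no explicit mask is supplied, since `torch.nn.functional.scaled_dot_product_attention` rejects a non-None `attn_mask` combined with `is_causal=True`.

```python
import torch
import torch.nn.functional as F

bsz, n_head, q_len, kv_len, head_dim = 1, 2, 5, 5, 16
query_states = torch.randn(bsz, n_head, q_len, head_dim)
key_states = torch.randn(bsz, n_head, kv_len, head_dim)
value_states = torch.randn(bsz, n_head, kv_len, head_dim)
attention_mask = None  # e.g. single-sequence decoding with no padding mask

attn_output = F.scaled_dot_product_attention(
    query_states,
    key_states,
    value_states,
    attn_mask=attention_mask,
    is_causal=attention_mask is None and q_len > 1 and bsz == 1,
)
```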