     TensorParallelEmbedding,
     PositionRotaryEmbedding,
     TensorParallelHead,
+    get_linear,
 )


@@ -59,7 +60,8 @@ def forward(self, hidden_states, residual=None):
                 hidden_states += residual
             residual = hidden_states

-            variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
+            hidden_states = hidden_states.to(torch.float32)
+            variance = hidden_states.pow(2).mean(-1, keepdim=True)
             hidden_states = hidden_states * torch.rsqrt(
                 variance + self.variance_epsilon
             )
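Note: after this change the cast to `torch.float32` happens once, so both the variance and the `rsqrt` scaling run in full precision. As a rough, self-contained sketch of what this RMSNorm branch computes (`weight` and `eps` below stand in for `self.weight` and `self.variance_epsilon`; the dtype handling later in the file is simplified here):

```python
import torch

def rms_norm_sketch(hidden_states: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6):
    # Single upcast: variance and normalization both run in float32.
    hidden_states = hidden_states.to(torch.float32)
    variance = hidden_states.pow(2).mean(-1, keepdim=True)
    hidden_states = hidden_states * torch.rsqrt(variance + eps)
    # Cast back to the weight dtype before the learned elementwise scale.
    return weight * hidden_states.to(weight.dtype)
```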
@@ -94,6 +96,27 @@ def forward(self, hidden_states, residual=None):
             return normed_hidden_states, res


+def _load_gqa(config, prefix: str, weights):
+    w = [
+        weights.get_sharded(f"{prefix}.q_proj.weight", dim=0),
+        weights.get_sharded(f"{prefix}.k_proj.weight", dim=0),
+        weights.get_sharded(f"{prefix}.v_proj.weight", dim=0),
+    ]
+    weight = torch.cat(w, dim=0)
+    weight = weight.to(dtype=weights.dtype).to(device=weights.device)
+    bias = None
+    assert config.hidden_size % config.num_attention_heads == 0
+    head_size = config.hidden_size // config.num_attention_heads
+    assert config.num_attention_heads % weights.process_group.size() == 0
+    num_heads = config.num_attention_heads // weights.process_group.size()
+    num_key_value_heads = config.num_key_value_heads // weights.process_group.size()
+    assert list(weight.shape) == [
+        (num_heads + 2 * num_key_value_heads) * head_size,
+        config.hidden_size,
+    ], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}"
+    return TensorParallelColumnLinear(get_linear(weight, bias, config.quantize))
+
+
 class FlashLlamaAttention(torch.nn.Module):
     def __init__(
         self,
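With grouped-query attention the q, k and v projections no longer share an output width, so the fused `[q; k; v]` weight is assembled by hand here and its per-shard shape is checked explicitly. A minimal sketch of that sizing arithmetic, using illustrative numbers (a 70B-style config with 2 tensor-parallel shards), not values taken from the diff:

```python
# Illustrative config values; only the arithmetic mirrors _load_gqa above.
hidden_size = 8192
num_attention_heads = 64
num_key_value_heads = 8
world_size = 2  # tensor-parallel shards

head_size = hidden_size // num_attention_heads     # 128
num_heads = num_attention_heads // world_size      # 32 query heads per shard
num_kv_heads = num_key_value_heads // world_size   # 4 kv heads per shard

# q contributes num_heads * head_size output rows; k and v each contribute
# num_kv_heads * head_size, so the fused per-shard weight is:
expected_shape = [(num_heads + 2 * num_kv_heads) * head_size, hidden_size]
print(expected_shape)  # [5120, 8192]
```

One small quirk worth noting: the assert compares against the sharded `num_key_value_heads`, but its error message interpolates `config.num_key_value_heads`, so only the printed expected value (not the check itself) would be off when running with more than one shard.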
@@ -118,22 +141,29 @@ def __init__(
                 f"and `num_shards`: {weights.process_group.size()}"
             )
         self.num_heads = self.num_heads // weights.process_group.size()
-        self.query_key_value = TensorParallelColumnLinear.load_multi(
-            config,
-            prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
-            dim=0,
-            weights=weights,
-            bias=False,
+        self.num_key_value_heads = (
+            config.num_key_value_heads // weights.process_group.size()
         )
+        if config.num_attention_heads != config.num_key_value_heads:
+            self.query_key_value = _load_gqa(config, prefix, weights)
+        else:
+            self.query_key_value = TensorParallelColumnLinear.load_multi(
+                config,
+                prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
+                dim=0,
+                weights=weights,
+                bias=False,
+            )
         self.o_proj = TensorParallelRowLinear.load(
             config,
             prefix=f"{prefix}.o_proj",
             weights=weights,
             bias=False,
         )
+        self.num_groups = self.num_heads // self.num_key_value_heads
         self.kv_head_mapping = torch.arange(
-            0, self.num_heads, dtype=torch.int32, device=weights.device
-        )
+            0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
+        ).repeat_interleave(self.num_groups)

     def forward(
         self,
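`kv_head_mapping` is what the paged-attention kernel uses to find the key/value head backing each query head. With GQA each kv head serves a contiguous group of `num_groups` query heads, which `repeat_interleave` expresses directly; a small sketch with made-up per-shard head counts:

```python
import torch

# Illustrative only: 8 query heads sharing 2 kv heads on this shard.
num_heads = 8
num_key_value_heads = 2
num_groups = num_heads // num_key_value_heads  # 4 query heads per kv head

kv_head_mapping = torch.arange(
    0, num_key_value_heads, dtype=torch.int32
).repeat_interleave(num_groups)
print(kv_head_mapping)  # tensor([0, 0, 0, 0, 1, 1, 1, 1], dtype=torch.int32)
```

In the non-GQA case `num_key_value_heads == num_heads`, so `num_groups` is 1 and the mapping collapses to the previous `arange(0, num_heads)` identity mapping.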
@@ -148,26 +178,33 @@ def forward(
         max_s,
     ):
         qkv = self.query_key_value(hidden_states)
-        qkv = qkv.view(-1, 3, self.num_heads, self.head_size)
+        query, kv = qkv.split(
+            [
+                self.head_size * self.num_heads,
+                2 * self.head_size * self.num_key_value_heads,
+            ],
+            dim=1,
+        )
+        query = query.view(-1, self.num_heads, self.head_size)
+        kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size)

-        # Inplace rotary
-        self.rotary_emb(qkv[:, 0], cos, sin)
-        self.rotary_emb(qkv[:, 1], cos, sin)
+        self.rotary_emb(query, cos, sin)
+        self.rotary_emb(torch.select(kv, dim=1, index=0), cos, sin)

         vllm_cache_ops.reshape_and_cache(
-            qkv[:, 1], qkv[:, 2], kv_cache[0], kv_cache[1], slots
+            kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots
         )

         # output tensor
-        attn_output = torch.empty_like(qkv[:, 0])
+        attn_output = torch.empty_like(query)

         # Prefill
         if cu_seqlen_prefill is not None:
             # flash attention
             attention(
-                qkv[:, 0],
-                qkv[:, 1],
-                qkv[:, 2],
+                query,
+                torch.select(kv, dim=1, index=0),
+                torch.select(kv, dim=1, index=1),
                 attn_output,
                 cu_seqlen_prefill,
                 max_s,
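Because query and key/value now have different widths, the fused projection output is split by column counts instead of being viewed as a uniform `(-1, 3, num_heads, head_size)` tensor. A sketch of the resulting shapes, reusing the hypothetical per-shard sizes from the `_load_gqa` note above:

```python
import torch

# Hypothetical per-shard sizes (see the _load_gqa sketch): 32 query heads, 4 kv heads.
num_heads, num_key_value_heads, head_size = 32, 4, 128
num_tokens = 5

qkv = torch.randn(num_tokens, (num_heads + 2 * num_key_value_heads) * head_size)
query, kv = qkv.split(
    [head_size * num_heads, 2 * head_size * num_key_value_heads], dim=1
)
query = query.view(-1, num_heads, head_size)         # (5, 32, 128)
kv = kv.view(-1, 2, num_key_value_heads, head_size)  # (5, 2, 4, 128): k at index 0, v at index 1
print(query.shape, kv.shape)
```

`torch.select(kv, dim=1, index=0)` returns a view of the key slice, so the in-place rotary embedding (per the removed `# Inplace rotary` comment) writes directly into `kv` before it is pushed to the cache.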
@@ -179,7 +216,7 @@ def forward(
             block_size = kv_cache[1].shape[3]
             vllm_attention_ops.single_query_cached_kv_attention(
                 attn_output,
-                qkv[:, 0],
+                query,
                 kv_cache[0],
                 kv_cache[1],
                 self.kv_head_mapping,
@@ -316,6 +353,7 @@ def __init__(self, config, weights):

         self.head_size = self.layers[0].self_attn.head_size
         self.num_heads = self.layers[0].self_attn.num_heads
+        self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads

     def forward(
         self,