
Commit 211b211

feat(server): add support for llamav2 (#633)
1 parent 3b71c38 commit 211b211

2 files changed: +58 −20 lines changed

server/text_generation_server/models/custom_modeling/flash_llama_modeling.py

Lines changed: 57 additions & 19 deletions
@@ -39,6 +39,7 @@
     TensorParallelEmbedding,
     PositionRotaryEmbedding,
     TensorParallelHead,
+    get_linear,
 )


@@ -59,7 +60,8 @@ def forward(self, hidden_states, residual=None):
                 hidden_states += residual
             residual = hidden_states

-            variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
+            hidden_states = hidden_states.to(torch.float32)
+            variance = hidden_states.pow(2).mean(-1, keepdim=True)
             hidden_states = hidden_states * torch.rsqrt(
                 variance + self.variance_epsilon
             )
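The change above casts `hidden_states` itself to float32 before taking the variance, so the subsequent `torch.rsqrt` scaling also runs in float32 rather than the model dtype. A minimal standalone sketch of that RMSNorm computation (the helper name `rms_norm_fp32` is hypothetical, not the class in this file):

```python
import torch

def rms_norm_fp32(hidden_states: torch.Tensor, weight: torch.Tensor,
                  eps: float = 1e-6) -> torch.Tensor:
    # Cast once; the variance and the rsqrt scaling are then both computed in
    # fp32, mirroring the behaviour after this change.
    input_dtype = hidden_states.dtype
    hidden_states = hidden_states.to(torch.float32)
    variance = hidden_states.pow(2).mean(-1, keepdim=True)
    hidden_states = hidden_states * torch.rsqrt(variance + eps)
    return weight * hidden_states.to(input_dtype)

x = torch.randn(2, 8, dtype=torch.float16)
w = torch.ones(8, dtype=torch.float16)
print(rms_norm_fp32(x, w).dtype)  # torch.float16
```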
@@ -94,6 +96,27 @@ def forward(self, hidden_states, residual=None):
         return normed_hidden_states, res


+def _load_gqa(config, prefix: str, weights):
+    w = [
+        weights.get_sharded(f"{prefix}.q_proj.weight", dim=0),
+        weights.get_sharded(f"{prefix}.k_proj.weight", dim=0),
+        weights.get_sharded(f"{prefix}.v_proj.weight", dim=0),
+    ]
+    weight = torch.cat(w, dim=0)
+    weight = weight.to(dtype=weights.dtype).to(device=weights.device)
+    bias = None
+    assert config.hidden_size % config.num_attention_heads == 0
+    head_size = config.hidden_size // config.num_attention_heads
+    assert config.num_attention_heads % weights.process_group.size() == 0
+    num_heads = config.num_attention_heads // weights.process_group.size()
+    num_key_value_heads = config.num_key_value_heads // weights.process_group.size()
+    assert list(weight.shape) == [
+        (num_heads + 2 * num_key_value_heads) * head_size,
+        config.hidden_size,
+    ], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}"
+    return TensorParallelColumnLinear(get_linear(weight, bias, config.quantize))
+
+
 class FlashLlamaAttention(torch.nn.Module):
     def __init__(
         self,
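The new `_load_gqa` helper concatenates the shard-local q/k/v projection weights into one fused matrix and asserts that its shape is `(num_heads + 2 * num_key_value_heads) * head_size` by `hidden_size`. A minimal sketch of that shape arithmetic with toy values and random stand-ins for the sharded weights (no `weights` object or `TensorParallelColumnLinear` involved):

```python
import torch

# Illustrative single-shard values (world_size = 1), not real Llama 2 sizes.
hidden_size = 1024
num_attention_heads = 8
num_key_value_heads = 2
head_size = hidden_size // num_attention_heads  # 128

# Random stand-ins for q_proj/k_proj/v_proj; in the server these come from
# weights.get_sharded(...) and are already split across shards along dim 0.
q = torch.randn(num_attention_heads * head_size, hidden_size)
k = torch.randn(num_key_value_heads * head_size, hidden_size)
v = torch.randn(num_key_value_heads * head_size, hidden_size)

weight = torch.cat([q, k, v], dim=0)
expected = [(num_attention_heads + 2 * num_key_value_heads) * head_size, hidden_size]
assert list(weight.shape) == expected  # [1536, 1024]
```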
@@ -118,22 +141,29 @@ def __init__(
                 f"and `num_shards`: {weights.process_group.size()}"
             )
         self.num_heads = self.num_heads // weights.process_group.size()
-        self.query_key_value = TensorParallelColumnLinear.load_multi(
-            config,
-            prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
-            dim=0,
-            weights=weights,
-            bias=False,
+        self.num_key_value_heads = (
+            config.num_key_value_heads // weights.process_group.size()
         )
+        if config.num_attention_heads != config.num_key_value_heads:
+            self.query_key_value = _load_gqa(config, prefix, weights)
+        else:
+            self.query_key_value = TensorParallelColumnLinear.load_multi(
+                config,
+                prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
+                dim=0,
+                weights=weights,
+                bias=False,
+            )
         self.o_proj = TensorParallelRowLinear.load(
             config,
             prefix=f"{prefix}.o_proj",
             weights=weights,
             bias=False,
         )
+        self.num_groups = self.num_heads // self.num_key_value_heads
         self.kv_head_mapping = torch.arange(
-            0, self.num_heads, dtype=torch.int32, device=weights.device
-        )
+            0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
+        ).repeat_interleave(self.num_groups)

     def forward(
         self,
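With grouped-query attention, each KV head serves `num_heads // num_key_value_heads` query heads on a shard, and the reworked `kv_head_mapping` encodes that by repeating every KV-head index once per group. A small self-contained example with illustrative head counts:

```python
import torch

num_heads = 8              # query heads on this shard (illustrative)
num_key_value_heads = 2    # KV heads on this shard (illustrative)
num_groups = num_heads // num_key_value_heads  # 4 query heads per KV head

kv_head_mapping = torch.arange(
    0, num_key_value_heads, dtype=torch.int32
).repeat_interleave(num_groups)

# Query head i attends against the cached keys/values of KV head
# kv_head_mapping[i].
print(kv_head_mapping)  # tensor([0, 0, 0, 0, 1, 1, 1, 1], dtype=torch.int32)
```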
@@ -148,26 +178,33 @@ def forward(
         max_s,
     ):
         qkv = self.query_key_value(hidden_states)
-        qkv = qkv.view(-1, 3, self.num_heads, self.head_size)
+        query, kv = qkv.split(
+            [
+                self.head_size * self.num_heads,
+                2 * self.head_size * self.num_key_value_heads,
+            ],
+            dim=1,
+        )
+        query = query.view(-1, self.num_heads, self.head_size)
+        kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size)

-        # Inplace rotary
-        self.rotary_emb(qkv[:, 0], cos, sin)
-        self.rotary_emb(qkv[:, 1], cos, sin)
+        self.rotary_emb(query, cos, sin)
+        self.rotary_emb(torch.select(kv, dim=1, index=0), cos, sin)

         vllm_cache_ops.reshape_and_cache(
-            qkv[:, 1], qkv[:, 2], kv_cache[0], kv_cache[1], slots
+            kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots
         )

         # output tensor
-        attn_output = torch.empty_like(qkv[:, 0])
+        attn_output = torch.empty_like(query)

         # Prefill
         if cu_seqlen_prefill is not None:
             # flash attention
             attention(
-                qkv[:, 0],
-                qkv[:, 1],
-                qkv[:, 2],
+                query,
+                torch.select(kv, dim=1, index=0),
+                torch.select(kv, dim=1, index=1),
                 attn_output,
                 cu_seqlen_prefill,
                 max_s,
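Because the fused projection now emits `num_heads` query chunks plus `2 * num_key_value_heads` key/value chunks per token, its output is split by size rather than reshaped into three equal parts. A shape-only sketch with dummy activations and illustrative shard-local head counts:

```python
import torch

num_tokens = 5
num_heads, num_key_value_heads, head_size = 8, 2, 128  # illustrative values

# Dummy output of the fused query_key_value projection for num_tokens tokens.
qkv = torch.randn(num_tokens, (num_heads + 2 * num_key_value_heads) * head_size)

query, kv = qkv.split(
    [head_size * num_heads, 2 * head_size * num_key_value_heads], dim=1
)
query = query.view(-1, num_heads, head_size)
kv = kv.view(-1, 2, num_key_value_heads, head_size)

print(query.shape)                             # torch.Size([5, 8, 128])
print(torch.select(kv, dim=1, index=0).shape)  # keys:   torch.Size([5, 2, 128])
print(torch.select(kv, dim=1, index=1).shape)  # values: torch.Size([5, 2, 128])
```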
@@ -179,7 +216,7 @@ def forward(
             block_size = kv_cache[1].shape[3]
             vllm_attention_ops.single_query_cached_kv_attention(
                 attn_output,
-                qkv[:, 0],
+                query,
                 kv_cache[0],
                 kv_cache[1],
                 self.kv_head_mapping,
@@ -316,6 +353,7 @@ def __init__(self, config, weights):

         self.head_size = self.layers[0].self_attn.head_size
         self.num_heads = self.layers[0].self_attn.num_heads
+        self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads

     def forward(
         self,

server/text_generation_server/models/flash_llama.py

Lines changed: 1 addition & 1 deletion
@@ -69,7 +69,7 @@ def __init__(
             model=model,
             tokenizer=tokenizer,
             num_layers=len(model.model.layers),
-            num_kv_heads=model.model.num_heads,
+            num_kv_heads=model.model.num_key_value_heads,
            head_size=model.model.head_size,
            dtype=dtype,
            device=device,
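Passing `model.model.num_key_value_heads` as `num_kv_heads` means the paged KV cache is sized per KV head rather than per query head. A back-of-the-envelope sketch of the per-block difference under assumed values (illustrative arithmetic only, not the server's allocation code; `kv_block_bytes` is a hypothetical helper):

```python
# Rough KV-cache sizing for one block of token slots (fp16 = 2 bytes).
block_size = 16
head_size = 128
num_heads = 64            # query heads; the value passed before this change
num_key_value_heads = 8   # KV heads; the value passed after this change

def kv_block_bytes(n_kv_heads: int) -> int:
    # keys + values for one block of block_size token slots
    return 2 * block_size * n_kv_heads * head_size * 2

print(kv_block_bytes(num_heads))            # 524288 bytes
print(kv_block_bytes(num_key_value_heads))  # 65536 bytes, 8x smaller per block
```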
