feat(server): add support for llamav2 (#633)
Narsil authored Jul 18, 2023
1 parent 3b71c38 commit 211b211
Showing 2 changed files with 58 additions and 20 deletions.
@@ -39,6 +39,7 @@
TensorParallelEmbedding,
PositionRotaryEmbedding,
TensorParallelHead,
get_linear,
)


@@ -59,7 +60,8 @@ def forward(self, hidden_states, residual=None):
hidden_states += residual
residual = hidden_states

variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(
variance + self.variance_epsilon
)
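
The hunk above changes the RMSNorm fallback so the residual-added hidden states are cast to float32 once, before the variance is computed, instead of only inside the pow/mean call. A minimal standalone sketch of that computation, with an assumed eps and a cast back to the input dtype (names are illustrative, not taken from this file):

import torch

def rms_norm_fp32(hidden_states: torch.Tensor, weight: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    # Cast once up front so pow/mean/rsqrt all run in float32, as in the change above.
    input_dtype = hidden_states.dtype
    hidden_states = hidden_states.to(torch.float32)
    variance = hidden_states.pow(2).mean(-1, keepdim=True)
    hidden_states = hidden_states * torch.rsqrt(variance + eps)
    # Assumed: scale by the learned weight and return in the original dtype.
    return weight * hidden_states.to(input_dtype)
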
@@ -94,6 +96,27 @@ def forward(self, hidden_states, residual=None):
return normed_hidden_states, res


def _load_gqa(config, prefix: str, weights):
w = [
weights.get_sharded(f"{prefix}.q_proj.weight", dim=0),
weights.get_sharded(f"{prefix}.k_proj.weight", dim=0),
weights.get_sharded(f"{prefix}.v_proj.weight", dim=0),
]
weight = torch.cat(w, dim=0)
weight = weight.to(dtype=weights.dtype).to(device=weights.device)
bias = None
assert config.hidden_size % config.num_attention_heads == 0
head_size = config.hidden_size // config.num_attention_heads
assert config.num_attention_heads % weights.process_group.size() == 0
num_heads = config.num_attention_heads // weights.process_group.size()
num_key_value_heads = config.num_key_value_heads // weights.process_group.size()
assert list(weight.shape) == [
(num_heads + 2 * num_key_value_heads) * head_size,
config.hidden_size,
], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}"
return TensorParallelColumnLinear(get_linear(weight, bias, config.quantize))
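
To make the shape assertion in _load_gqa concrete, here is the arithmetic for an assumed Llama-2-70B-style configuration on a single shard (illustrative numbers, not read from this repository):

# Assumed GQA geometry, process_group.size() == 1:
hidden_size = 8192
num_attention_heads = 64
num_key_value_heads = 8                            # 8 K/V heads shared by 64 query heads
head_size = hidden_size // num_attention_heads     # 128

# Rows of the fused [q; k; v] weight that _load_gqa concatenates along dim=0:
fused_rows = (num_attention_heads + 2 * num_key_value_heads) * head_size
assert fused_rows == 10240                         # 8192 query rows + 2 * 1024 key/value rows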


class FlashLlamaAttention(torch.nn.Module):
def __init__(
self,
@@ -118,22 +141,29 @@ def __init__(
f"and `num_shards`: {weights.process_group.size()}"
)
self.num_heads = self.num_heads // weights.process_group.size()
self.query_key_value = TensorParallelColumnLinear.load_multi(
config,
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
dim=0,
weights=weights,
bias=False,
self.num_key_value_heads = (
config.num_key_value_heads // weights.process_group.size()
)
if config.num_attention_heads != config.num_key_value_heads:
self.query_key_value = _load_gqa(config, prefix, weights)
else:
self.query_key_value = TensorParallelColumnLinear.load_multi(
config,
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
dim=0,
weights=weights,
bias=False,
)
self.o_proj = TensorParallelRowLinear.load(
config,
prefix=f"{prefix}.o_proj",
weights=weights,
bias=False,
)
self.num_groups = self.num_heads // self.num_key_value_heads
self.kv_head_mapping = torch.arange(
0, self.num_heads, dtype=torch.int32, device=weights.device
)
0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
).repeat_interleave(self.num_groups)
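
The kv_head_mapping built above tells the paged attention kernel which K/V head each query head shares under grouped-query attention. A quick check of the pattern with assumed per-shard sizes:

import torch

num_heads = 64                   # query heads on this shard (assumed)
num_key_value_heads = 8          # K/V heads on this shard (assumed)
num_groups = num_heads // num_key_value_heads      # 8 query heads per K/V head

kv_head_mapping = torch.arange(0, num_key_value_heads, dtype=torch.int32).repeat_interleave(num_groups)
# tensor([0, 0, 0, 0, 0, 0, 0, 0, 1, 1, ..., 7, 7]), one entry per query head
assert kv_head_mapping.shape[0] == num_heads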

def forward(
self,
@@ -148,26 +178,33 @@ def forward(
max_s,
):
qkv = self.query_key_value(hidden_states)
qkv = qkv.view(-1, 3, self.num_heads, self.head_size)
query, kv = qkv.split(
[
self.head_size * self.num_heads,
2 * self.head_size * self.num_key_value_heads,
],
dim=1,
)
query = query.view(-1, self.num_heads, self.head_size)
kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size)

# Inplace rotary
self.rotary_emb(qkv[:, 0], cos, sin)
self.rotary_emb(qkv[:, 1], cos, sin)
self.rotary_emb(query, cos, sin)
self.rotary_emb(torch.select(kv, dim=1, index=0), cos, sin)

vllm_cache_ops.reshape_and_cache(
qkv[:, 1], qkv[:, 2], kv_cache[0], kv_cache[1], slots
kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots
)

# output tensor
attn_output = torch.empty_like(qkv[:, 0])
attn_output = torch.empty_like(query)

# Prefill
if cu_seqlen_prefill is not None:
# flash attention
attention(
qkv[:, 0],
qkv[:, 1],
qkv[:, 2],
query,
torch.select(kv, dim=1, index=0),
torch.select(kv, dim=1, index=1),
attn_output,
cu_seqlen_prefill,
max_s,
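
With GQA the fused projection output can no longer be viewed as [tokens, 3, num_heads, head_size]; the hunk above splits it into a query block and a smaller K/V block instead. A self-contained sketch of that split with assumed sizes:

import torch

tokens, num_heads, num_key_value_heads, head_size = 4, 64, 8, 128    # assumed sizes

qkv = torch.randn(tokens, (num_heads + 2 * num_key_value_heads) * head_size)
query, kv = qkv.split(
    [head_size * num_heads, 2 * head_size * num_key_value_heads],
    dim=1,
)
query = query.view(-1, num_heads, head_size)             # [tokens, 64, 128]
kv = kv.view(-1, 2, num_key_value_heads, head_size)      # [tokens, 2, 8, 128]
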
@@ -179,7 +216,7 @@ def forward(
block_size = kv_cache[1].shape[3]
vllm_attention_ops.single_query_cached_kv_attention(
attn_output,
qkv[:, 0],
query,
kv_cache[0],
kv_cache[1],
self.kv_head_mapping,
@@ -316,6 +353,7 @@ def __init__(self, config, weights):

self.head_size = self.layers[0].self_attn.head_size
self.num_heads = self.layers[0].self_attn.num_heads
self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads

def forward(
self,
server/text_generation_server/models/flash_llama.py (2 changes: 1 addition & 1 deletion)
@@ -69,7 +69,7 @@ def __init__(
model=model,
tokenizer=tokenizer,
num_layers=len(model.model.layers),
num_kv_heads=model.model.num_heads,
num_kv_heads=model.model.num_key_value_heads,
head_size=model.model.head_size,
dtype=dtype,
device=device,
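
The one-line change in flash_llama.py feeds num_key_value_heads into cache allocation, so with GQA the KV cache is sized per K/V head rather than per query head. A rough per-token, per-layer estimate under assumed fp16 sizes (not measured from this code):

num_attention_heads = 64         # assumed
num_key_value_heads = 8          # assumed GQA config
head_size = 128
bytes_per_value = 2              # fp16

gqa_kv_bytes = 2 * num_key_value_heads * head_size * bytes_per_value     # K and V: 4096 bytes
mha_kv_bytes = 2 * num_attention_heads * head_size * bytes_per_value     # would be 32768 bytes
print(gqa_kv_bytes, mha_kv_bytes)                                        # an 8x reduction per token per layer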
