Add Pallas GPU decode attention to MaxText inference
tohaowu committed Nov 26, 2024
1 parent f29bf3a commit 2a92879
Showing 2 changed files with 61 additions and 1 deletion.
60 changes: 60 additions & 0 deletions MaxText/layers/attentions.py
@@ -24,6 +24,8 @@
from jax import lax
from jax.ad_checkpoint import checkpoint_name
from jax.experimental import shard_map
from jax.experimental.pallas.ops.gpu import attention as pallas_attention
from jax.experimental.pallas.ops.gpu import decode_attention as pallas_decode_attention
from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_kernel
from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_mask
import jax.numpy as jnp
@@ -255,9 +257,67 @@ def apply_attention(
Use `dot_product` instead."""
)
return self.cudnn_flash_attention(query, key, value, decoder_segment_ids, model_mode), None, None
elif self.attention_kernel == "pallas_gpu":
if model_mode == common_types.MODEL_MODE_AUTOREGRESSIVE:
return self.gqa_wrapped_maxtext_compatible(query, key, value)
key = jnp.repeat(key, self.num_query_heads // self.num_kv_heads, axis=2)
value = jnp.repeat(value, self.num_query_heads // self.num_kv_heads, axis=2)
return self.pallas_attention_kernel(query, key, value, decoder_segment_ids, model_mode)
else:
raise ValueError(f"Unexpected attention kernel {self.attention_kernel=}.")

  def gqa_wrapped_maxtext_compatible(
      self,
      q, k, v,
      start_idx=None, kv_seq_len=None,
      sm_scale: float | None = None,
      block_h: int = 16,
      block_k: int = 128,
      k_splits: int = 16,
      num_warps: int | None = None,
      num_stages: int = 2,
      grid: tuple[int, ...] | None = None,
      interpret: bool = False,
      debug: bool = False,
  ):
    """Wraps pallas_decode_attention.gqa to return (out, max, sum) in MaxText's attention interface."""
    sm_scale = sm_scale if sm_scale is not None else (1 / math.sqrt(q.shape[-1]))
    batch_size, q_length, q_heads, head_dim = q.shape
    k_seq_len, kv_heads = k.shape[1], k.shape[2]
    assert q_heads % kv_heads == 0

    q_reshaped = q.reshape(batch_size, q_length, kv_heads, q_heads // kv_heads, head_dim)

    # Re-compute the attention logits to get local_max and local_sum; the einsum
    # keeps the batch and query-length axes aligned for any batch size.
    attn_weights = jnp.einsum("bqhgd,bkhd->bqhgk", q_reshaped, k) * sm_scale
    local_max = jnp.max(attn_weights, axis=-1, keepdims=True)
    local_exps = jnp.exp(attn_weights - local_max)
    local_sum = jnp.sum(local_exps, axis=-1, keepdims=True)

    # Reshape q to match gqa's expected [batch, q_heads, head_dim] shape.
    q_for_gqa = q.reshape(batch_size * q_length, q_heads, head_dim)

    # Use the original gqa kernel to get the attention output.
    local_out_gqa = pallas_decode_attention.gqa(
        q_for_gqa, k, v, start_idx, kv_seq_len, sm_scale, block_h, block_k,
        k_splits, num_warps, num_stages, grid, interpret, debug
    )

    # Reshape gqa's output to include q_length.
    local_out = local_out_gqa.reshape(batch_size, q_length, q_heads, head_dim)

    # Reshape local_max and local_sum to match MaxText's expected shapes.
    local_max = local_max.reshape(batch_size, q_length, q_heads, 1)
    local_sum = local_sum.reshape(batch_size, q_length, q_heads, 1)

    return local_out, local_max, local_sum

  def pallas_attention_kernel(self, query, key, value, decoder_segment_ids, model_mode):
    """Pallas MHA kernel for the prefill stage."""
    sm_scale = 1.0 / math.sqrt(query.shape[-1])
    # Causal masking is only enabled in autoregressive mode.
    out = pallas_attention.mha(
        query, key, value, decoder_segment_ids, sm_scale=sm_scale,
        causal=(model_mode == common_types.MODEL_MODE_AUTOREGRESSIVE),
    )
    return out, None, None

def ragged_attention(
self, query: Array, key: Array | KVTensor, value: Array | KVTensor, lengths: Array, block_size: int
) -> tuple[Array, Array, Array]:
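
Note: the decode path above wraps pallas_decode_attention.gqa and recomputes the softmax max/sum statistics that MaxText's attention interface expects. The following is a minimal pure-JAX sketch of the grouped-query decode step that path targets, using assumed toy shapes and a hypothetical helper name (gqa_decode_reference); it illustrates the shape contract only and does not call the Pallas kernel.

import jax
import jax.numpy as jnp

def gqa_decode_reference(q, k, v, sm_scale=None):
  """Pure-JAX reference for one grouped-query decode step.

  q: [batch, q_heads, head_dim]             (a single query token per sequence)
  k, v: [batch, kv_seq, kv_heads, head_dim]
  Returns (out, max, sum) with softmax statistics kept per query head.
  """
  batch, q_heads, head_dim = q.shape
  kv_heads = k.shape[2]
  group = q_heads // kv_heads
  sm_scale = sm_scale if sm_scale is not None else 1.0 / jnp.sqrt(head_dim)

  # Group query heads so each group shares one KV head, then form logits.
  qg = q.reshape(batch, kv_heads, group, head_dim)
  logits = jnp.einsum("bhgd,bkhd->bhgk", qg, k) * sm_scale
  m = jnp.max(logits, axis=-1, keepdims=True)
  e = jnp.exp(logits - m)
  s = jnp.sum(e, axis=-1, keepdims=True)
  out = jnp.einsum("bhgk,bkhd->bhgd", e / s, v)
  return (out.reshape(batch, q_heads, head_dim),
          m.reshape(batch, q_heads, 1),
          s.reshape(batch, q_heads, 1))

# Toy shapes (assumptions, not MaxText defaults): batch=2, q_heads=8, kv_heads=2.
kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(kq, (2, 8, 64))
k = jax.random.normal(kk, (2, 128, 2, 64))
v = jax.random.normal(kv, (2, 128, 2, 64))
out, m, s = gqa_decode_reference(q, k, v)
print(out.shape, m.shape, s.shape)  # (2, 8, 64) (2, 8, 1) (2, 8, 1)
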
2 changes: 1 addition & 1 deletion MaxText/pyconfig.py
@@ -68,7 +68,7 @@ def validate_kv_quant_axis(s: str, quantize_kvcache: bool) -> None:


def validate_attention_kernel(s: str) -> None:
  valid_attention_kernels = ("autoselected", "dot_product", "flash", "cudnn_flash_te")
  valid_attention_kernels = ("autoselected", "dot_product", "flash", "cudnn_flash_te", "pallas_gpu")
  if s not in valid_attention_kernels:  # currently supported attention
    raise ValueError("Invalid attention kernel was passed. Valid options ", valid_attention_kernels)

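
Note: with the validator widened, pallas_gpu becomes an accepted value for MaxText's attention setting. A quick sanity check of the validator is sketched below; it is a standalone copy of the function from this diff so it runs without importing MaxText, and the rejected value "pallas_tpu" is just a hypothetical example.

# Standalone copy of the widened validator from this diff, for illustration only.
def validate_attention_kernel(s: str) -> None:
  valid_attention_kernels = ("autoselected", "dot_product", "flash", "cudnn_flash_te", "pallas_gpu")
  if s not in valid_attention_kernels:
    raise ValueError("Invalid attention kernel was passed. Valid options ", valid_attention_kernels)

validate_attention_kernel("pallas_gpu")    # accepted after this change
try:
  validate_attention_kernel("pallas_tpu")  # hypothetical value, still rejected
except ValueError as err:
  print(err)
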
