diff --git a/MaxText/layers/attentions.py b/MaxText/layers/attentions.py
index 57f7afce..e52df2e7 100644
--- a/MaxText/layers/attentions.py
+++ b/MaxText/layers/attentions.py
@@ -24,6 +24,8 @@
 from jax import lax
 from jax.ad_checkpoint import checkpoint_name
 from jax.experimental import shard_map
+from jax.experimental.pallas.ops.gpu import attention as pallas_attention
+from jax.experimental.pallas.ops.gpu import decode_attention as pallas_decode_attention
 from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_kernel
 from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_mask
 import jax.numpy as jnp
@@ -255,9 +257,67 @@ def apply_attention(
           Use `dot_product` instead."""
       )
       return self.cudnn_flash_attention(query, key, value, decoder_segment_ids, model_mode), None, None
+    elif self.attention_kernel == "pallas_gpu":
+      if model_mode == common_types.MODEL_MODE_AUTOREGRESSIVE:
+        return self.gqa_wrapped_maxtext_compatible(query, key, value)
+      # Prefill/training path: repeat K/V heads so the MHA kernel sees one KV head per query head.
+      key = jnp.repeat(key, self.num_query_heads // self.num_kv_heads, axis=2)
+      value = jnp.repeat(value, self.num_query_heads // self.num_kv_heads, axis=2)
+      return self.pallas_attention_kernel(query, key, value, decoder_segment_ids, model_mode)
     else:
       raise ValueError(f"Unexpected attention kernel {self.attention_kernel=}.")
 
+  def gqa_wrapped_maxtext_compatible(
+      self, q, k, v,
+      start_idx=None, kv_seq_len=None,
+      sm_scale: float | None = None,
+      block_h: int = 16,
+      block_k: int = 128,
+      k_splits: int = 16,
+      num_warps: int | None = None,
+      num_stages: int = 2,
+      grid: tuple[int, ...] | None = None,
+      interpret: bool = False,
+      debug: bool = False,
+  ):
+    """Wraps the Pallas GPU decode (GQA) kernel and returns the (out, max, sum) triple MaxText expects."""
+    sm_scale = sm_scale if sm_scale is not None else (1 / math.sqrt(q.shape[-1]))
+    batch_size, q_length, q_heads, head_dim = q.shape
+    kv_heads = k.shape[2]
+    assert q_heads % kv_heads == 0
+
+    q_reshaped = q.reshape(batch_size, q_length, kv_heads, q_heads // kv_heads, head_dim)
+
+    # Re-compute attention weights to get local_max and local_sum.
+    attn_weights = jnp.einsum("bqhgd,bkhd->bqhgk", q_reshaped, k) * sm_scale
+    local_max = jnp.max(attn_weights, axis=-1, keepdims=True)
+    local_exps = jnp.exp(attn_weights - local_max)
+    local_sum = jnp.sum(local_exps, axis=-1, keepdims=True)
+
+    # Reshape q to the [batch, num_heads, head_dim] layout the gqa kernel expects.
+    q_for_gqa = q.reshape(batch_size * q_length, q_heads, head_dim)
+
+    # Use the original gqa function to get the attention output.
+    local_out_gqa = pallas_decode_attention.gqa(
+        q_for_gqa, k, v, start_idx, kv_seq_len, sm_scale, block_h, block_k,
+        k_splits, num_warps, num_stages, grid, interpret, debug
+    )
+
+    # Reshape gqa's output to include q_length.
+    local_out = local_out_gqa.reshape(batch_size, q_length, q_heads, head_dim)
+
+    # Reshape local_max and local_sum to match MaxText requirements.
+    local_max = local_max.reshape(batch_size, q_length, q_heads, 1)
+    local_sum = local_sum.reshape(batch_size, q_length, q_heads, 1)
+
+    return local_out, local_max, local_sum
+
+  def pallas_attention_kernel(self, query, key, value, decoder_segment_ids, model_mode):
+    """Pallas MHA kernel for the prefill stage."""
+    sm_scale = 1.0 / math.sqrt(query.shape[-1])
+    out = pallas_attention.mha(query, key, value, decoder_segment_ids, sm_scale=sm_scale, causal=(model_mode == common_types.MODEL_MODE_AUTOREGRESSIVE))  # causal only in autoregressive mode
+    return out, None, None
+
   def ragged_attention(
       self, query: Array, key: Array | KVTensor, value: Array | KVTensor, lengths: Array, block_size: int
   ) -> tuple[Array, Array, Array]:
diff --git a/MaxText/pyconfig.py b/MaxText/pyconfig.py
index 5adc8819..e006aab0 100644
--- a/MaxText/pyconfig.py
+++ b/MaxText/pyconfig.py
@@ -68,7 +68,7 @@ def validate_kv_quant_axis(s: str, quantize_kvcache: bool) -> None:
 
 
 def validate_attention_kernel(s: str) -> None:
-  valid_attention_kernels = ("autoselected", "dot_product", "flash", "cudnn_flash_te")
+  valid_attention_kernels = ("autoselected", "dot_product", "flash", "cudnn_flash_te", "pallas_gpu")
   if s not in valid_attention_kernels:  # currently supported attention
     raise ValueError("Invalid attention kernel was passed. Valid options ", valid_attention_kernels)
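
For reference, below is a minimal pure-JAX sketch of the grouped-query trick the new pallas_gpu prefill branch relies on: repeating K/V heads so that a multi-head attention kernel sees one key/value head per query head. The reference einsum attention here stands in for pallas_attention.mha; the shapes and names are illustrative assumptions, not MaxText code.

    import math
    import jax
    import jax.numpy as jnp

    def reference_gqa(q, k, v):
      # q: [batch, q_len, num_q_heads, head_dim]; k, v: [batch, kv_len, num_kv_heads, head_dim]
      num_q_heads, num_kv_heads = q.shape[2], k.shape[2]
      # Repeat K/V heads so every query head has a matching key/value head,
      # mirroring the jnp.repeat calls in the pallas_gpu prefill branch above.
      k = jnp.repeat(k, num_q_heads // num_kv_heads, axis=2)
      v = jnp.repeat(v, num_q_heads // num_kv_heads, axis=2)
      scale = 1.0 / math.sqrt(q.shape[-1])
      logits = jnp.einsum("bqhd,bkhd->bhqk", q, k) * scale
      probs = jax.nn.softmax(logits, axis=-1)
      return jnp.einsum("bhqk,bkhd->bqhd", probs, v)

    q = jnp.ones((2, 8, 16, 64))  # 16 query heads
    k = jnp.ones((2, 8, 4, 64))   # 4 KV heads
    v = jnp.ones((2, 8, 4, 64))
    out = reference_gqa(q, k, v)  # [2, 8, 16, 64]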
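With the pyconfig validator extended, the new kernel would be selected the same way as the existing options, e.g. by overriding the attention setting on the command line. The invocation below is illustrative; exact paths and flags depend on the MaxText version.

    python3 MaxText/train.py MaxText/configs/base.yml attention=pallas_gpu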