Commit

fixing possible issues
erfanzar committed Oct 23, 2024
1 parent 47d20cd commit 9d5909f
Showing 4 changed files with 38 additions and 32 deletions.
8 changes: 5 additions & 3 deletions src/jax_flash_attn2/flash_attention.py
@@ -173,9 +173,6 @@ def __call__(
f"Query heads ({num_q_heads}) must be divisible by "
f"key/value heads ({num_kv_heads})"
)

bias = self._handle_bias(bias, num_q_heads, num_kv_heads)

if self.config.platform == Platform.TRITON:
return self._compute_triton(query, key, value, bias)
elif self.config.platform == Platform.PALLAS:
@@ -192,6 +189,7 @@ def _compute_triton(
) -> chex.Array:
"""Computes attention using Triton backend."""
# fmt:off
bias = self._handle_bias(bias, query.shape[2], key.shape[2])
if query.shape[2] == key.shape[2] or os.environ.get("FORCE_MHA", "false") in ["true", "1", "on"]:
key, value = self.repeat_kv_heads(key, value, query.shape[2] // key.shape[2])

@@ -223,6 +221,8 @@ def _compute_pallas(
bias: Optional[chex.Array],
) -> chex.Array:
"""Computes attention using Pallas backend."""

bias = self._handle_bias(bias, query.shape[2], key.shape[2])
key, value = self.repeat_kv_heads(key, value, query.shape[2] // key.shape[2])

if self.config.backend == Backend.GPU:
@@ -272,6 +272,8 @@ def _compute_jax(
bias: Optional[chex.Array],
) -> chex.Array:
"""Computes attention using JAX backend."""

bias = self._handle_bias(bias, query.shape[2], key.shape[2])
key, value = self.repeat_kv_heads(key, value, query.shape[2] // key.shape[2])
return jax_flash_attn_2_mu(
query_state=query,
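
This change moves the `_handle_bias` call out of `__call__` and into each backend path (`_compute_triton`, `_compute_pallas`, `_compute_jax`), each of which now normalizes the bias for its own head layout and, where needed, repeats the key/value heads to match the query heads. For context, below is a minimal sketch of what a `repeat_kv_heads` helper of this kind typically does for grouped-query attention; it is an illustration assuming the (batch, seq, heads, head_dim) layout used in the tests, not the library's actual implementation.

```python
import jax.numpy as jnp


def repeat_kv_heads(key, value, n_rep: int):
    """Tile key/value heads along the head axis (axis=2) so every query head
    has a matching KV head. Illustrative sketch, not library code."""
    if n_rep == 1:
        return key, value
    return jnp.repeat(key, n_rep, axis=2), jnp.repeat(value, n_rep, axis=2)
```
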
3 changes: 2 additions & 1 deletion src/jax_flash_attn2/triton_kernels/gqa_kernel.py
@@ -1115,6 +1115,7 @@ def _fwd_attn_kernel_call_with_residual(


@functools.partial(custom_vjp, nondiff_argnums=[4, 5, 6])
@functools.partial(jax.jit, static_argnums=[4, 5, 6])
def _flash_gqa_attn2(
query: chex.Array,
key: chex.Array,
@@ -1240,7 +1241,7 @@ def _test_forward():
def _test_backward():
"""Tests the backward pass of the attention mechanism."""
q_key, k_key, v_key = jrnd.split(jrnd.PRNGKey(8), 3)
B, QH, KVH, QS, KS, D = 1, 32, 32, 1024, 1024, 128
B, QH, KVH, QS, KS, D = 1, 32, 16, 1024, 1024, 128
blocksize_k = 16
blocksize_q = 16
q = jax.nn.initializers.normal(2)(q_key, (B, QS, QH, D), dtype=jnp.float16)
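
The decorator added in this hunk stacks `jax.jit` with `static_argnums=[4, 5, 6]` beneath `custom_vjp` with `nondiff_argnums=[4, 5, 6]`, so the same trailing arguments are treated both as static compile-time values and as non-differentiable inputs. A toy sketch of that decorator ordering on a made-up function (not the kernel's real signature):

```python
import functools

import jax
import jax.numpy as jnp
from jax import custom_vjp


# `scale` plays the role of the kernel's static/non-differentiable arguments
# (block sizes, config flags): static for jit, non-differentiable for the VJP.
@functools.partial(custom_vjp, nondiff_argnums=[1])
@functools.partial(jax.jit, static_argnums=[1])
def scaled_square(x, scale: float):
    return scale * x * x


def _fwd(x, scale):
    # Return the primal output plus the residual needed by the backward rule.
    return scale * x * x, x


def _bwd(scale, x, g):
    # d/dx (scale * x**2) = 2 * scale * x; one gradient per differentiable input.
    return (2.0 * scale * x * g,)


scaled_square.defvjp(_fwd, _bwd)

value, grad = jax.value_and_grad(scaled_square)(jnp.float32(3.0), 2.0)
```
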
4 changes: 2 additions & 2 deletions src/jax_flash_attn2/triton_kernels/mha_kernel.py
@@ -1038,6 +1038,7 @@ def _fwd_attn_kernel_call_with_residual(


@functools.partial(custom_vjp, nondiff_argnums=[4, 5, 6])
@functools.partial(jax.jit, static_argnums=[4, 5, 6])
def _flash_attn2(
query: chex.Array,
key: chex.Array,
@@ -1076,7 +1077,6 @@ def _flash_attn2(
_fwd_attn_kernel_call_with_residual,
_bwd_attn_kernel_call,
)

triton_flash_mha_attn_2_gpu = _flash_attn2
__all__ = ["triton_flash_mha_attn_2_gpu"]

@@ -1130,7 +1130,7 @@ def _test_forward():
"""Tests the forward pass of the attention mechanism."""
q_key, k_key, v_key = jrnd.split(jrnd.PRNGKey(8), 3)
q_key, k_key, v_key = jrnd.split(jrnd.PRNGKey(8), 3)
B, QH, KVH, QS, KS, D = 1, 32, 8, 1024, 1024, 128
B, QH, KVH, QS, KS, D = 1, 32, 32, 1024, 1024, 128
blocksize_k = 64
blocksize_q = 128
q = jax.nn.initializers.normal(2)(q_key, (B, QS, QH, D), dtype=jnp.float16)
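
The test-shape edits in these two kernels follow the head-count rule quoted from flash_attention.py above: the GQA test now exercises grouped heads (32 query heads over 16 key/value heads), while the MHA test uses equal counts (32 and 32). A small illustrative check of those constraints, written as a sketch rather than library code:

```python
def check_head_counts(num_q_heads: int, num_kv_heads: int, mha: bool = False) -> None:
    """Validate the head-count relationship the kernels above rely on."""
    if num_q_heads % num_kv_heads != 0:
        raise ValueError(
            f"Query heads ({num_q_heads}) must be divisible by "
            f"key/value heads ({num_kv_heads})"
        )
    if mha and num_q_heads != num_kv_heads:
        raise ValueError("The MHA kernel expects num_q_heads == num_kv_heads.")


check_head_counts(32, 16)            # GQA test shape
check_head_counts(32, 32, mha=True)  # MHA test shape
```
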
55 changes: 29 additions & 26 deletions tests/test_triton.py
@@ -5,20 +5,19 @@
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "1.0"
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
os.environ["JAX_TRACEBACK_FILTERING"] = "off"
os.environ["XLA_FLAGS"] = "--xla_gpu_enable_command_buffer="

sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), "../src"))

import jax
from jax import numpy as jnp
from jax import random as jrnd

from jax_flash_attn2 import get_cached_flash_attention
from jax_flash_attn2 import create_flash_attention

USE_BIAS = True


def _attn_refrence(query_states, key_states, value_states, bias):
def _gqa_attn_refrence(query_states, key_states, value_states, bias):
b, qs, num_q_heads, d = query_states.shape
num_kv_heads = value_states.shape[2]
ks = value_states.shape[1]
@@ -63,13 +62,32 @@ def _attn_refrence(query_states, key_states, value_states, bias):
)


def _mha_attn_refrence(query_states, key_states, value_states, bias):
d = query_states.shape[-1]

attention_weight = jnp.einsum("bqhd,bkhd->bhqk", query_states * (d**-0.5), key_states)

if bias is not None:
attention_weight = jnp.add(attention_weight, bias)
attention_weight = jax.nn.softmax(attention_weight)

return jnp.einsum("bhqk,bkhd->bqhd", attention_weight, value_states)


flash_attn = create_flash_attention(
backend="gpu",
platform="triton",
blocksize_q=64,
blocksize_k=64,
softmax_scale=None,
)


def test_forward():
"""Tests the forward pass of the attention mechanism."""
q_key, k_key, v_key = jrnd.split(jrnd.PRNGKey(8), 3)
q_key, k_key, v_key = jrnd.split(jrnd.PRNGKey(8), 3)
B, QH, KVH, QS, KS, D = 1, 32, 8, 1024, 1024, 128
blocksize_k = 64
blocksize_q = 128
q = jax.nn.initializers.normal(2)(q_key, (B, QS, QH, D), dtype=jnp.float16)
k = jax.nn.initializers.normal(2)(k_key, (B, KS, KVH, D), dtype=jnp.float16)
v = jax.nn.initializers.normal(2)(v_key, (B, KS, KVH, D), dtype=jnp.float16)
@@ -82,27 +100,19 @@ def test_forward():
if USE_BIAS
else None
)
flash_attn = get_cached_flash_attention(
backend="gpu",
platform="triton",
blocksize_q=blocksize_q,
blocksize_k=blocksize_k,
softmax_scale=None,
)
print("QKV Allocated")
co = flash_attn(q, k, v, b) # passes 256K on 24G GPU 3090
print(co[-1, -1, -1, :5])
fo = _attn_refrence(q, k, v, b)
fo = _gqa_attn_refrence(q, k, v, b)
print(fo[-1, -1, -1, :5])
print("Results are Close" if jnp.allclose(co, fo, 0, 0.125) else "Wrong results!")


def test_backward():
"""Tests the backward pass of the attention mechanism."""
"""Tests the backward pass of the attention mechanism."""

q_key, k_key, v_key = jrnd.split(jrnd.PRNGKey(8), 3)
B, QH, KVH, QS, KS, D = 1, 32, 32, 1024, 1024, 128
blocksize_k = 16
blocksize_q = 16
q = jax.nn.initializers.normal(2)(q_key, (B, QS, QH, D), dtype=jnp.float16)
k = jax.nn.initializers.normal(2)(k_key, (B, KS, KVH, D), dtype=jnp.float16)
v = jax.nn.initializers.normal(2)(v_key, (B, KS, KVH, D), dtype=jnp.float16)
@@ -116,25 +126,18 @@ def test_backward():
else None
)

flash_attn = get_cached_flash_attention(
backend="gpu",
platform="triton",
blocksize_q=blocksize_q,
blocksize_k=blocksize_k,
softmax_scale=None,
)
try:
co = jax.grad(lambda *x: flash_attn(*x).sum())(q, k, v, b)
print("Custom op backward pass gradients:")
print(co[-1][-1, -1, :5]) # Print last 5 elements of last head of last batch
print(co[-1, -1, -1, :5]) # Print last 5 elements of last head of last batch
except Exception as er:
print(f"Custom op backward pass failed: {er}")
co = None

try:
fo = jax.grad(lambda *x: _attn_refrence(*x).sum())(q, k, v, b)
fo = jax.grad(lambda *x: _mha_attn_refrence(*x).sum())(q, k, v, b)

print(fo[-1, -1, -1, :5]) # Print last 5 elements of last head of last batch
print(fo[-1, -1, -1, :5])
except Exception as e:
print(f"Flax backward pass failed : {e}")
fo = None
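
The `bias` passed to both the kernels and the reference implementations in this test file is an additive term applied to the `bhqk` attention logits before the softmax (see `_mha_attn_refrence` above); its construction sits in the collapsed lines of the diff. As a hedged sketch only, assuming the hidden code derives it from a boolean attention mask, such a bias is commonly built like this:

```python
import jax.numpy as jnp


def mask_to_bias(mask, dtype=jnp.float16):
    """Turn a boolean mask of shape (B, H, QS, KS), True where attention is
    allowed, into an additive bias: 0 where allowed, a very large negative
    value where masked. Hypothetical helper for illustration only."""
    zero = jnp.asarray(0.0, dtype=dtype)
    neg = jnp.asarray(jnp.finfo(dtype).min, dtype=dtype)
    return jnp.where(mask, zero, neg)
```
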
