Flash Attention for Neuron #883

Open · wants to merge 1 commit into base: main
124 changes: 124 additions & 0 deletions axlearn/common/flash_attention/neuron_attention.py
@@ -0,0 +1,124 @@
from functools import partial
import jax
import jax.numpy as jnp
from jax import custom_vjp

# Number of logical NeuronCores to shard the attention heads over (2 when the device kind reports NC_v3d, else 1).
lnc = 2 if jax.devices()[0].device_kind == "NC_v3d" else 1

@partial(custom_vjp, nondiff_argnums=(4, 5))
def flash_attention(query, key, value, bias, causal, softmax_scale):
    out, _ = _mha_forward(query, key, value, bias, causal, softmax_scale)
    return out


def _mha_forward(query, key, value, bias, causal, softmax_scale):
Contributor comment: Can we get support for segment IDs and dropout as well? Both are quite needed nowadays.

    # Get the batch size, sequence length, number of heads, and hidden dimension.
    batch_size, q_seq_len, num_heads, d_model = query.shape

    # Transpose the query, key, and value tensors into the layouts the NKI kernels expect.
    q = query.transpose(0, 2, 3, 1)  # [batch_size, num_heads, d_model, q_seq_len]
    k = key.transpose(0, 2, 3, 1)  # [batch_size, num_heads, d_model, kv_seq_len]
    v = value.transpose(0, 2, 1, 3)  # [batch_size, num_heads, kv_seq_len, d_model]

    import neuronxcc.nki.language as nl
Contributor comment: Please add a pylint disable here.
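For example (an illustrative sketch, not part of the diff), the on-demand import could carry an inline disable so the checker accepts it:

    import neuronxcc.nki.language as nl  # pylint: disable=import-outside-toplevel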

    from neuronxcc.nki.kernels.attention import flash_fwd
    seed = jnp.array([1])

    # Call the NKI kernel; duplicate the kernel if we cannot shard on num_heads.
    if (num_heads % 2) == 0 and (num_heads // 2 > 0):
        grid = batch_size, nl.nc(lnc) * (num_heads // lnc)
    else:
        grid = batch_size, num_heads

    if bias is not None:
        assert bias.ndim == 4, f"Neuron flash_attention expects bias.ndim == 4, but got {bias.ndim}"
        attn_output, lse = flash_fwd[grid](
            q,
            k,
            v,
            seed,
            bias,
            use_causal_mask=causal,
            softmax_scale=softmax_scale,
            mixed_precision=True,
            dropout_p=0.0,
        )
    else:
        attn_output, lse = flash_fwd[grid](
            q,
            k,
            v,
            seed,
            use_causal_mask=causal,
            softmax_scale=softmax_scale,
            mixed_precision=True,
            dropout_p=0.0,
        )

    # Transpose the output back to the original [batch_size, q_seq_len, num_heads, d_model] shape.
    attn_output = attn_output.transpose(0, 2, 1, 3)

    return attn_output, (lse, attn_output, q, k, v, bias)


def _mha_backward(causal, softmax_scale, res, d_attn_output):
    lse, o, q, k, v, bias = res
    batch_size, num_heads, d_model, seq_len = q.shape

    # Transpose the stored output and the incoming gradient to [batch_size, num_heads, d_model, q_seq_len].
    o = o.transpose(0, 2, 3, 1)
    dy = d_attn_output.transpose(0, 2, 3, 1)

    # Transpose v to [batch_size, num_heads, d_model, kv_seq_len].
    v = jnp.transpose(v, axes=(0, 1, 3, 2))
    seed = jnp.array([1])

    from neuronxcc.nki.kernels.attention import flash_attn_bwd
    import neuronxcc.nki.language as nl

    # Call the NKI kernel; duplicate the kernel if we cannot shard on num_heads.
    if (num_heads % 2) == 0 and (num_heads // 2 > 0):
        grid = batch_size, nl.nc(lnc) * (num_heads // lnc)
    else:
        grid = batch_size, num_heads

    if bias is not None:
        assert bias.ndim == 4, f"Neuron flash_attention expects bias.ndim == 4, but got {bias.ndim}"
        d_query, d_key, d_value = flash_attn_bwd[grid](
            q,
            k,
            v,
            o,
            dy,
            lse,
            seed,
            bias,
            use_causal_mask=causal,
            mixed_precision=True,
            dropout_p=0.0,
            softmax_scale=softmax_scale,
        )
    else:
        d_query, d_key, d_value = flash_attn_bwd[grid](
            q,
            k,
            v,
            o,
            dy,
            lse,
            seed,
            use_causal_mask=causal,
            mixed_precision=True,
            dropout_p=0.0,
            softmax_scale=softmax_scale,
        )

    # Transpose the gradients back to the original [batch_size, seq_len, num_heads, head_dim] shape.
    d_query = d_query.transpose(0, 3, 1, 2)
    d_key = d_key.transpose(0, 3, 1, 2)
    d_value = d_value.transpose(0, 3, 1, 2)

    return d_query, d_key, d_value, None


flash_attention.defvjp(_mha_forward, _mha_backward)
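For reference, a minimal usage sketch of the kernel wrapper above (illustrative only; the shapes, bias=None, and the softmax scale mirror the tests below, and a Neuron device is assumed):

import functools
import jax
import jax.numpy as jnp
from axlearn.common.flash_attention.neuron_attention import flash_attention

batch, seq_len, num_heads, per_head_dim = 2, 2048, 8, 128
q = jax.random.normal(jax.random.PRNGKey(0), (batch, seq_len, num_heads, per_head_dim), dtype=jnp.bfloat16)
k = jax.random.normal(jax.random.PRNGKey(1), (batch, seq_len, num_heads, per_head_dim), dtype=jnp.bfloat16)
v = jax.random.normal(jax.random.PRNGKey(2), (batch, seq_len, num_heads, per_head_dim), dtype=jnp.bfloat16)

# causal and softmax_scale are the non-differentiable arguments of the custom_vjp.
attn = functools.partial(flash_attention, causal=True, softmax_scale=1.0 / per_head_dim**0.5)
out = attn(q, k, v, None)  # bias=None; output shape is [batch, seq_len, num_heads, per_head_dim]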
132 changes: 132 additions & 0 deletions axlearn/common/flash_attention/neuron_attention_test.py
@@ -0,0 +1,132 @@
# Copyright © 2024 Amazon Inc.
"""Tests for Flash attention on Neuron. Tested on trn1 & trn2."""
import functools

import chex
import jax
import jax.numpy as jnp
import pytest

from axlearn.common.flash_attention.neuron_attention import flash_attention
from axlearn.common.flash_attention.utils import mha_reference


if jax.default_backend() != "neuron":
    pytestmark = pytest.mark.skip(reason="Incompatible hardware, AWS Neuron only test.")


@pytest.mark.parametrize(
    "batch_size,seq_len,num_heads,per_head_dim",
    [
        (1, 2048, 1, 64),
        (2, 2048, 2, 64),
        (1, 2048, 1, 128),
        (2, 2048, 2, 128),
        (1, 2048, 8, 128),
        (2, 2048, 8, 128),
    ],
)
@pytest.mark.parametrize("use_fwd", [True, False])
@pytest.mark.parametrize("causal", [True, False])
@pytest.mark.parametrize("input_dtype", [jnp.float16, jnp.bfloat16, jnp.float32])
def test_fwd_against_ref(
    batch_size: int,
    seq_len: int,
    num_heads: int,
    per_head_dim: int,
    use_fwd: bool,
    causal: bool,
    input_dtype: jnp.dtype,
):
    sm_scale = 1.0 / (per_head_dim**0.5)
    k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
    q = jax.random.normal(k1, (batch_size, seq_len, num_heads, per_head_dim), dtype=input_dtype)
    k = jax.random.normal(k2, (batch_size, seq_len, num_heads, per_head_dim), dtype=input_dtype)
    v = jax.random.normal(k3, (batch_size, seq_len, num_heads, per_head_dim), dtype=input_dtype)

    bias = None
    segment_ids = None

    if use_fwd:

        @jax.jit
        def impl(q, k, v, bias):
            fn = functools.partial(
                flash_attention,
                causal=causal,
                softmax_scale=sm_scale,
            )
            out, _ = jax.vjp(fn, q, k, v, bias)
            return out

    else:
        impl = functools.partial(
            flash_attention,
            causal=causal,
            softmax_scale=sm_scale,
        )

    o = impl(q, k, v, bias)
    o_ref = mha_reference(q, k, v, bias, segment_ids, causal=causal, softmax_scale=sm_scale)
    chex.assert_trees_all_close(o, o_ref, atol=0.05)


@pytest.mark.parametrize(
    "batch_size,num_heads,seq_len,per_head_dim",
    [
        (1, 1, 2048, 64),
        (2, 2, 2048, 64),
        (1, 1, 2048, 128),
        (2, 2, 2048, 128),
        (1, 8, 2048, 128),
        (2, 8, 2048, 128),
    ],
)
@pytest.mark.parametrize("causal", [True, False])
@pytest.mark.parametrize("input_dtype", [jnp.bfloat16, jnp.float16, jnp.float32])
def test_bwd_against_ref(
    batch_size: int,
    num_heads: int,
    seq_len: int,
    per_head_dim: int,
    causal: bool,
    input_dtype: jnp.dtype,
):
    sm_scale = 1.0 / (per_head_dim**0.5)
    q = jax.random.normal(
        jax.random.PRNGKey(0), (batch_size, seq_len, num_heads, per_head_dim), dtype=input_dtype
    )
    k = jax.random.normal(
        jax.random.PRNGKey(1), (batch_size, seq_len, num_heads, per_head_dim), dtype=input_dtype
    )
    v = jax.random.normal(
        jax.random.PRNGKey(2), (batch_size, seq_len, num_heads, per_head_dim), dtype=input_dtype
    )

    bias = None
    segment_ids = None

    def fn(q, k, v, bias):
        return flash_attention(
            q,
            k,
            v,
            bias,
            causal=causal,
            softmax_scale=sm_scale,
        ).sum()

    def ref_fn(q, k, v, bias, segment_ids):
        return mha_reference(
            q,
            k,
            v,
            bias,
            segment_ids,
            causal=causal,
            softmax_scale=sm_scale,
        ).sum()

    jax_grads = jax.grad(fn, argnums=(0, 1, 2))(q, k, v, bias)
    jax_ref_grads = jax.grad(ref_fn, argnums=(0, 1, 2))(q, k, v, bias, segment_ids)
    chex.assert_trees_all_close(jax_grads, jax_ref_grads, atol=0.07)
17 changes: 16 additions & 1 deletion axlearn/common/flash_attention/utils.py
@@ -75,7 +75,7 @@ def mha_reference(


def flash_attention_implementation(
-   backend: Literal["cpu", "tpu", "gpu", "xla"],
+   backend: Literal["cpu", "tpu", "gpu", "xla", "neuron"],
    *,
    mask: Optional[MaskFn] = None,
    softmax_scale: float,
@@ -159,6 +159,21 @@ def jit_attn(query, key, value, bias, segment_ids):

        return jit_attn

elif backend == "neuron":
from axlearn.common.flash_attention.neuron_attention import (
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

On demand import is kind of risky, we can live with it for functions inside neuron_attention.py, can we at least get it as a header import for files outside of neuron_attention.py?

flash_attention as neuron_flash_attention,
)
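One way the module-level import could look (an illustrative sketch, not part of the diff, assuming utils.py should still import cleanly on hosts without the Neuron toolchain):

# At the top of utils.py (sketch): fall back gracefully when the Neuron kernel is unavailable.
try:
    from axlearn.common.flash_attention.neuron_attention import (
        flash_attention as neuron_flash_attention,
    )
except ImportError:
    neuron_flash_attention = None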

        # shard_map-decorated function needs to be jitted.
        @jax.jit
        def jit_attn(query, key, value, bias, segment_ids):
            if segment_ids is not None:
                raise NotImplementedError("Sequence packing is not supported on the Neuron backend.")
            return neuron_flash_attention(query, key, value, bias, causal, softmax_scale)

        return jit_attn

elif backend in ("cpu", "xla"):
if backend == "cpu":
logging.warning("Flash attention CPU backend is for testing only.")
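With this change, the Neuron kernel can be selected through the existing dispatch. A sketch, assuming the remaining parameters of flash_attention_implementation keep their defaults and query/key/value are arrays shaped as in the tests:

from axlearn.common.flash_attention.utils import flash_attention_implementation

# Returns a jitted callable with signature (query, key, value, bias, segment_ids).
attn_fn = flash_attention_implementation("neuron", softmax_scale=1.0 / 128**0.5)
out = attn_fn(query, key, value, None, None)  # segment_ids must be None on the Neuron backend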