Flash Attention for Neuron #883
@@ -0,0 +1,124 @@
from functools import partial
import jax
import jax.numpy as jnp
from jax import custom_vjp

# Number of NeuronCores to shard each kernel over: 2 on NC_v3d devices, otherwise 1.
lnc = 2 if jax.devices()[0].device_kind == "NC_v3d" else 1


@partial(custom_vjp, nondiff_argnums=(4, 5))
def flash_attention(query, key, value, bias, causal, softmax_scale):
    out, _ = _mha_forward(query, key, value, bias, causal, softmax_scale)
    return out


def _mha_forward(query, key, value, bias, causal, softmax_scale):
    # Get the batch size, sequence length, number of heads, and hidden dimension.
    batch_size, q_seq_len, num_heads, d_model = query.shape

    # Transpose the query, key, and value tensors into the layout expected by the NKI kernels.
    q = query.transpose(0, 2, 3, 1)  # [batch_size, num_heads, d_model, q_seq_len]
    k = key.transpose(0, 2, 3, 1)  # [batch_size, num_heads, d_model, kv_seq_len]
    v = value.transpose(0, 2, 1, 3)  # [batch_size, num_heads, kv_seq_len, d_model]

    import neuronxcc.nki.language as nl
Review comment: Please add pylint here.
    from neuronxcc.nki.kernels.attention import flash_fwd

    seed = jnp.array([1])

    # Call the NKI kernel; duplicate the kernel if we cannot shard on num_heads.
    if (num_heads % 2) == 0 and (num_heads // 2 > 0):
        grid = batch_size, nl.nc(lnc) * (num_heads // lnc)
    else:
        grid = batch_size, num_heads

    if bias is not None:
        assert bias.ndim == 4, f"Neuron flash_attention expects bias.ndim == 4, got {bias.ndim}"
        attn_output, lse = flash_fwd[grid](
            q,
            k,
            v,
            seed,
            bias,
            use_causal_mask=causal,
            softmax_scale=softmax_scale,
            mixed_precision=True,
            dropout_p=0.0,
        )
    else:
        attn_output, lse = flash_fwd[grid](
            q,
            k,
            v,
            seed,
            use_causal_mask=causal,
            softmax_scale=softmax_scale,
            mixed_precision=True,
            dropout_p=0.0,
        )

    # Transpose the output back to the original shape.
    attn_output = attn_output.transpose(0, 2, 1, 3)  # [batch_size, q_seq_len, num_heads, d_model]

    # Return the output and the residuals needed by the backward pass.
    return attn_output, (lse, attn_output, q, k, v, bias)


def _mha_backward(causal, softmax_scale, res, d_attn_output):
    lse, o, q, k, v, bias = res
    batch_size, num_heads, d_model, seq_len = q.shape

    # Transpose the stored output and the incoming gradient into the kernel layout.
    o = o.transpose(0, 2, 3, 1)
    dy = d_attn_output.transpose(0, 2, 3, 1)

    # Transpose the v tensor.
    v = jnp.transpose(v, axes=(0, 1, 3, 2))
    seed = jnp.array([1])

    from neuronxcc.nki.kernels.attention import flash_attn_bwd
    import neuronxcc.nki.language as nl

    # Call the NKI kernel; duplicate the kernel if we cannot shard on num_heads.
    if (num_heads % 2) == 0 and (num_heads // 2 > 0):
        grid = batch_size, nl.nc(lnc) * (num_heads // lnc)
    else:
        grid = batch_size, num_heads

    if bias is not None:
        assert bias.ndim == 4, f"Neuron flash_attention expects bias.ndim == 4, got {bias.ndim}"
        d_query, d_key, d_value = flash_attn_bwd[grid](
            q,
            k,
            v,
            o,
            dy,
            lse,
            seed,
            bias,
            use_causal_mask=causal,
            mixed_precision=True,
            dropout_p=0.0,
            softmax_scale=softmax_scale,
        )
    else:
        d_query, d_key, d_value = flash_attn_bwd[grid](
            q,
            k,
            v,
            o,
            dy,
            lse,
            seed,
            use_causal_mask=causal,
            mixed_precision=True,
            dropout_p=0.0,
            softmax_scale=softmax_scale,
        )

    # Transpose the gradients back to the original [batch, seq_len, heads, head_dim] shape.
    d_query = d_query.transpose(0, 3, 1, 2)
    d_key = d_key.transpose(0, 3, 1, 2)
    d_value = d_value.transpose(0, 3, 1, 2)

    # No gradient is returned for the bias argument.
    return d_query, d_key, d_value, None


flash_attention.defvjp(_mha_forward, _mha_backward)
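
A minimal usage sketch of the wrapper above (not part of this PR): it assumes a trn1/trn2 host where Neuron is the default JAX backend, uses the import path from the tests below, and picks illustrative shapes and dtype. Because defvjp wires _mha_forward/_mha_backward into flash_attention, jax.grad dispatches to the NKI backward kernel.

import jax
import jax.numpy as jnp

from axlearn.common.flash_attention.neuron_attention import flash_attention

# Illustrative shapes in the [batch, seq_len, num_heads, per_head_dim] layout used above.
batch, seq_len, num_heads, head_dim = 1, 2048, 8, 128
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(k1, (batch, seq_len, num_heads, head_dim), dtype=jnp.bfloat16)
k = jax.random.normal(k2, (batch, seq_len, num_heads, head_dim), dtype=jnp.bfloat16)
v = jax.random.normal(k3, (batch, seq_len, num_heads, head_dim), dtype=jnp.bfloat16)

# Forward pass with bias=None and causal=True; causal and softmax_scale are the non-diff arguments.
out = flash_attention(q, k, v, None, True, 1.0 / head_dim**0.5)


# Gradients with respect to q, k, v flow through _mha_backward via the custom VJP.
def loss(q, k, v):
    return flash_attention(q, k, v, None, True, 1.0 / head_dim**0.5).sum()


dq, dk, dv = jax.grad(loss, argnums=(0, 1, 2))(q, k, v)
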
@@ -0,0 +1,132 @@
# Copyright © 2024 Amazon Inc.
"""Tests for Flash attention on Neuron. Tested on trn1 & trn2."""
import functools

import chex
import jax
import jax.numpy as jnp
import pytest

from axlearn.common.flash_attention.neuron_attention import flash_attention
from axlearn.common.flash_attention.utils import mha_reference


if jax.default_backend() != "neuron":
    pytestmark = pytest.mark.skip(reason="Incompatible hardware, AWS Neuron only test.")


@pytest.mark.parametrize(
    "batch_size,seq_len,num_heads,per_head_dim",
    [
        (1, 2048, 1, 64),
        (2, 2048, 2, 64),
        (1, 2048, 1, 128),
        (2, 2048, 2, 128),
        (1, 2048, 8, 128),
        (2, 2048, 8, 128),
    ],
)
@pytest.mark.parametrize("use_fwd", [True, False])
@pytest.mark.parametrize("causal", [True, False])
@pytest.mark.parametrize("input_dtype", [jnp.float16, jnp.bfloat16, jnp.float32])
def test_fwd_against_ref(
    batch_size: int,
    seq_len: int,
    num_heads: int,
    per_head_dim: int,
    use_fwd: bool,
    causal: bool,
    input_dtype: jnp.dtype,
):
    sm_scale = 1.0 / (per_head_dim**0.5)
    k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
    q = jax.random.normal(k1, (batch_size, seq_len, num_heads, per_head_dim), dtype=input_dtype)
    k = jax.random.normal(k2, (batch_size, seq_len, num_heads, per_head_dim), dtype=input_dtype)
    v = jax.random.normal(k3, (batch_size, seq_len, num_heads, per_head_dim), dtype=input_dtype)

    bias = None
    segment_ids = None

    if use_fwd:

        @jax.jit
        def impl(q, k, v, bias):
            fn = functools.partial(
                flash_attention,
                causal=causal,
                softmax_scale=sm_scale,
            )
            out, _ = jax.vjp(fn, q, k, v, bias)
            return out

    else:
        impl = functools.partial(
            flash_attention,
            causal=causal,
            softmax_scale=sm_scale,
        )

    o = impl(q, k, v, bias)
    o_ref = mha_reference(q, k, v, bias, segment_ids, causal=causal, softmax_scale=sm_scale)
    chex.assert_trees_all_close(o, o_ref, atol=0.05)


@pytest.mark.parametrize(
    "batch_size,num_heads,seq_len,per_head_dim",
    [
        (1, 1, 2048, 64),
        (2, 2, 2048, 64),
        (1, 1, 2048, 128),
        (2, 2, 2048, 128),
        (1, 8, 2048, 128),
        (2, 8, 2048, 128),
    ],
)
@pytest.mark.parametrize("causal", [True, False])
@pytest.mark.parametrize("input_dtype", [jnp.bfloat16, jnp.float16, jnp.float32])
def test_bwd_against_ref(
    batch_size: int,
    num_heads: int,
    seq_len: int,
    per_head_dim: int,
    causal: bool,
    input_dtype: jnp.dtype,
):
    sm_scale = 1.0 / (per_head_dim**0.5)
    q = jax.random.normal(
        jax.random.PRNGKey(0), (batch_size, seq_len, num_heads, per_head_dim), dtype=input_dtype
    )
    k = jax.random.normal(
        jax.random.PRNGKey(1), (batch_size, seq_len, num_heads, per_head_dim), dtype=input_dtype
    )
    v = jax.random.normal(
        jax.random.PRNGKey(2), (batch_size, seq_len, num_heads, per_head_dim), dtype=input_dtype
    )

    bias = None
    segment_ids = None

    def fn(q, k, v, bias):
        return flash_attention(
            q,
            k,
            v,
            bias,
            causal=causal,
            softmax_scale=sm_scale,
        ).sum()

    def ref_fn(q, k, v, bias, segment_ids):
        return mha_reference(
            q,
            k,
            v,
            bias,
            segment_ids,
            causal=causal,
            softmax_scale=sm_scale,
        ).sum()

    jax_grads = jax.grad(fn, argnums=(0, 1, 2))(q, k, v, bias)
    jax_ref_grads = jax.grad(ref_fn, argnums=(0, 1, 2))(q, k, v, bias, segment_ids)
    chex.assert_trees_all_close(jax_grads, jax_ref_grads, atol=0.07)
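
A brief usage note (not part of the diff): on a trn1 or trn2 instance these tests are collected by pytest in the usual way, e.g. running pytest axlearn/common/flash_attention from the repository root (directory inferred from the import path above); on other backends the module-level skip marker disables them.
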
@@ -75,7 +75,7 @@ def mha_reference(


 def flash_attention_implementation(
-    backend: Literal["cpu", "tpu", "gpu", "xla"],
+    backend: Literal["cpu", "tpu", "gpu", "xla", "neuron"],
     *,
     mask: Optional[MaskFn] = None,
     softmax_scale: float,
@@ -159,6 +159,21 @@ def jit_attn(query, key, value, bias, segment_ids):

         return jit_attn

+    elif backend == "neuron":
+        from axlearn.common.flash_attention.neuron_attention import (
+            flash_attention as neuron_flash_attention,
+        )

Review comment: On-demand imports are kind of risky. We can live with them for functions inside neuron_attention.py, but can we at least get this as a header import for files outside of neuron_attention.py?

+
+        # shard_map-decorated function needs to be jitted.
+        @jax.jit
+        def jit_attn(query, key, value, bias, segment_ids):
+            if segment_ids is not None:
+                raise NotImplementedError("Sequence packing is not supported on the Neuron backend.")
+            return neuron_flash_attention(query, key, value, bias, causal, softmax_scale)
+
+        return jit_attn
+
     elif backend in ("cpu", "xla"):
         if backend == "cpu":
             logging.warning("Flash attention CPU backend is for testing only.")
Review comment: Can we get support for segment IDs and dropout as well? Both are quite needed nowadays.
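
On the segment-ID request above, one possible direction (a sketch only, not part of this PR) is to lower segment IDs into the additive 4-D bias the kernel already accepts, assuming the bias layout is [batch, num_heads, q_seq_len, kv_seq_len] as the ndim == 4 assertion suggests. The helper name segment_ids_to_bias and the choice of masking value are assumptions.

import jax.numpy as jnp


def segment_ids_to_bias(segment_ids, num_heads, dtype=jnp.float32):
    """Hypothetical helper: turns [batch, seq_len] segment IDs into an additive attention bias.

    Cross-segment positions receive a large negative bias so softmax assigns them ~zero weight;
    the dtype's most negative finite value is used instead of -inf to stay finite in mixed precision.
    """
    # [batch, q_seq, kv_seq]: True where the query and key positions share a segment.
    same_segment = segment_ids[:, :, None] == segment_ids[:, None, :]
    neg = jnp.finfo(dtype).min
    bias = jnp.where(same_segment, jnp.array(0, dtype=dtype), jnp.array(neg, dtype=dtype))
    # Broadcast to [batch, num_heads, q_seq, kv_seq].
    batch, q_seq, kv_seq = bias.shape
    return jnp.broadcast_to(bias[:, None, :, :], (batch, num_heads, q_seq, kv_seq))

Whether the NKI kernels tolerate such large-magnitude bias values in bfloat16, and how dropout would be threaded through dropout_p, would still need to be verified on hardware.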