working on sequence sharding
erfanzar committed Nov 16, 2024
1 parent 4dd027c commit 3fea284
Showing 2 changed files with 12 additions and 2 deletions.
7 changes: 6 additions & 1 deletion jax_flash_attn2/_custom_call_lib/lib.py
@@ -339,6 +339,7 @@ def get_or_create_triton_kernel(
     fn,
     arg_dtypes,
     scalar_args,
+    device,
     *,
     num_warps,
     num_stages,
@@ -352,7 +353,6 @@ def get_or_create_triton_kernel(
     num_warps = 4
   if num_stages is None:
     num_stages = 3
-  device = 0
   if compute_capability is None:
     compute_capability = triton_kernel_call_lib.get_compute_capability(device)
   if num_ctas > 1 and compute_capability < 90:
@@ -504,6 +504,7 @@ def triton_kernel_call_lowering(
     zeroed_outputs,
     debug,
     serialized_metadata,
+    device,
     **metaparams,
 ):
   if jaxlib.version.__version_info__ < (0, 3, 22) and input_output_aliases:
@@ -600,6 +601,7 @@ def prune_configs(configs, named_args, **kwargs):
         fn,
         arg_dtypes,
         scalar_args,
+        device,
         num_warps=params["num_warps"],
         num_stages=params["num_stages"],
         num_ctas=params["num_ctas"],
@@ -702,6 +704,7 @@ def triton_call(
     zeroed_outputs: (Sequence[int] | Callable[[dict[str, Any]], Sequence[int]]) = (),
     debug: bool = False,
     serialized_metadata: bytes = b"",
+    device: int = 0,
     **metaparams: Any,
 ) -> Any:
   """Calls a Triton kernel with `jax.Array` arguments.
Expand Down Expand Up @@ -730,6 +733,7 @@ def triton_call(
debug: Prints out intermediate IRs if True for debugging purposes.
serialized_metadata: Arbitrary metadata that will be added into the
serialized kernel call.
device (int): device id in current process to compile triton kernel on
**metaparams: Additional keyword arguments that will be provided to a `grid`
(if it is a function) and to the Triton kernel as `constexpr` arguments.
@@ -774,6 +778,7 @@ def triton_call(
       zeroed_outputs=zeroed_outputs,
       debug=debug,
       serialized_metadata=serialized_metadata,
+      device=device,
       **metaparams,
   )
   return tree_util.tree_unflatten(out_tree, out_flat)
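
With this change, triton_call lets the caller choose which local device the kernel is compiled for instead of always using device 0. Below is a minimal sketch of the new argument in use, assuming the vendored triton_call follows jax-triton's calling convention; the add_one_kernel, shapes, grid, and BLOCK are illustrative, not part of this commit:

import jax
import jax.numpy as jnp
import triton
from triton import language as tl

from jax_flash_attn2._custom_call_lib import triton_call


@triton.jit
def add_one_kernel(x_ptr, o_ptr, BLOCK: tl.constexpr):
    # Illustrative single-block kernel: writes x + 1 into the output buffer.
    offsets = tl.arange(0, BLOCK)
    tl.store(o_ptr + offsets, tl.load(x_ptr + offsets) + 1)


x = jnp.arange(128, dtype=jnp.float32)
out = triton_call(
    x,
    kernel=add_one_kernel,
    out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
    grid=(1,),
    device=0,  # new in this commit: local device id to compile the kernel on
    BLOCK=128,
)

Passing device explicitly matters once operands live on a GPU other than the process default, which is what the gqa_kernel changes below wire up.
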
7 changes: 6 additions & 1 deletion jax_flash_attn2/triton_kernels/gqa_kernel.py
@@ -32,7 +32,7 @@
 from jax import custom_vjp, random as jrnd
 from jax import numpy as jnp
 from triton import language as tl
-from fjformer.jax_triton import triton_call
+from jax_flash_attn2._custom_call_lib import triton_call
 
 FLASH_ATTN_BWD_ = True
 
@@ -383,6 +383,7 @@ def _fwd_attention_kernel_call(
     stride_qb, stride_qm, stride_qh, stride_qg, stride_qd = get_strides(query.shape)
     stride_kb, stride_kn, stride_kh, stride_kd = get_strides(key.shape)
     stride_vb, stride_vn, stride_vh, stride_vd = get_strides(value.shape)
+    device_id = getattr(getattr(query, "device", None), "id", 0)
     out, lse = triton_call(
         query,
         key,
@@ -429,6 +430,7 @@ def _fwd_attention_kernel_call(
             num_groups,
         ),
         name="triton::ops::_fwd_attn_kernel",
+        device=device_id,
         **metaparams,
     )
     return out.reshape(batch, seqlen_q, num_q_heads, headdim), lse
Expand Down Expand Up @@ -783,6 +785,7 @@ def _bwd_attention_kernel_call(

num_warps = 4 if headdim <= 64 else 8

device_id = getattr(getattr(query, "device", None), "id", 0)
# kernel kwargs
metaparams = dict(
BLOCK_M=128,
Expand Down Expand Up @@ -824,6 +827,7 @@ def _bwd_attention_kernel_call(
),
kernel=_bwd_do_attention_kernel,
name="triton::ops::_bwd_do_attention_kernel",
device=device_id,
**metaparams,
)
metaparams = dict(
Expand Down Expand Up @@ -889,6 +893,7 @@ def _bwd_attention_kernel_call(
),
out_shape=bwd_kernel_out_shapes,
name="triton::ops::_bwd_attention_kernel",
device=device_id,
**metaparams,
)
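
Rather than assuming device 0, the forward and backward paths above now read the id off the query array before each triton_call. A standalone sketch of that getattr pattern, with an illustrative placeholder array:

import jax.numpy as jnp

# Illustrative stand-in for the query tensor the attention kernels receive.
query = jnp.zeros((1, 128, 8, 1, 64), dtype=jnp.float16)

# On recent JAX versions a committed, single-device jax.Array exposes a
# `.device` object carrying an integer `.id`; the nested getattr falls back
# to 0 whenever the attribute chain is unavailable (e.g. for tracers inside
# jit, which have no concrete device).
device_id = getattr(getattr(query, "device", None), "id", 0)

Deriving the id from the operand keeps compilation pinned to whichever local GPU actually holds the (possibly sharded) tensors, in line with the sequence-sharding work named in the commit message.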
