diff --git a/jax_flash_attn2/_custom_call_lib/lib.py b/jax_flash_attn2/_custom_call_lib/lib.py
index 27477b7..dec310d 100644
--- a/jax_flash_attn2/_custom_call_lib/lib.py
+++ b/jax_flash_attn2/_custom_call_lib/lib.py
@@ -339,6 +339,7 @@ def get_or_create_triton_kernel(
     fn,
     arg_dtypes,
     scalar_args,
+    device,
     *,
     num_warps,
     num_stages,
@@ -352,7 +353,6 @@ def get_or_create_triton_kernel(
         num_warps = 4
     if num_stages is None:
         num_stages = 3
-    device = 0
     if compute_capability is None:
         compute_capability = triton_kernel_call_lib.get_compute_capability(device)
     if num_ctas > 1 and compute_capability < 90:
@@ -504,6 +504,7 @@ def triton_kernel_call_lowering(
     zeroed_outputs,
     debug,
     serialized_metadata,
+    device,
     **metaparams,
 ):
     if jaxlib.version.__version_info__ < (0, 3, 22) and input_output_aliases:
@@ -600,6 +601,7 @@ def prune_configs(configs, named_args, **kwargs):
             fn,
             arg_dtypes,
             scalar_args,
+            device,
             num_warps=params["num_warps"],
             num_stages=params["num_stages"],
             num_ctas=params["num_ctas"],
@@ -702,6 +704,7 @@ def triton_call(
     zeroed_outputs: (Sequence[int] | Callable[[dict[str, Any]], Sequence[int]]) = (),
     debug: bool = False,
     serialized_metadata: bytes = b"",
+    device: int = 0,
     **metaparams: Any,
 ) -> Any:
     """Calls a Triton kernel with `jax.Array` arguments.
@@ -730,6 +733,7 @@ def triton_call(
       debug: Prints out intermediate IRs if True for debugging purposes.
       serialized_metadata: Arbitrary metadata that will be added into the
         serialized kernel call.
+      device: Device id in the current process on which to compile the Triton kernel.
       **metaparams: Additional keyword arguments that will be provided to a `grid`
         (if it is a function) and to the Triton kernel as `constexpr` arguments.
@@ -774,6 +778,7 @@ def triton_call(
         zeroed_outputs=zeroed_outputs,
         debug=debug,
         serialized_metadata=serialized_metadata,
+        device=device,
         **metaparams,
     )
     return tree_util.tree_unflatten(out_tree, out_flat)
diff --git a/jax_flash_attn2/triton_kernels/gqa_kernel.py b/jax_flash_attn2/triton_kernels/gqa_kernel.py
index aedf5f8..34c7067 100644
--- a/jax_flash_attn2/triton_kernels/gqa_kernel.py
+++ b/jax_flash_attn2/triton_kernels/gqa_kernel.py
@@ -32,7 +32,7 @@
 from jax import custom_vjp, random as jrnd
 from jax import numpy as jnp
 from triton import language as tl
-from fjformer.jax_triton import triton_call
+from jax_flash_attn2._custom_call_lib import triton_call

 FLASH_ATTN_BWD_ = True
@@ -383,6 +383,7 @@ def _fwd_attention_kernel_call(
     stride_qb, stride_qm, stride_qh, stride_qg, stride_qd = get_strides(query.shape)
     stride_kb, stride_kn, stride_kh, stride_kd = get_strides(key.shape)
     stride_vb, stride_vn, stride_vh, stride_vd = get_strides(value.shape)
+    device_id = getattr(getattr(query, "device", None), "id", 0)
     out, lse = triton_call(
         query,
         key,
@@ -429,6 +430,7 @@ def _fwd_attention_kernel_call(
             num_groups,
         ),
         name="triton::ops::_fwd_attn_kernel",
+        device=device_id,
         **metaparams,
     )
     return out.reshape(batch, seqlen_q, num_q_heads, headdim), lse
@@ -783,6 +785,7 @@ def _bwd_attention_kernel_call(

     num_warps = 4 if headdim <= 64 else 8

+    device_id = getattr(getattr(query, "device", None), "id", 0)
     # kernel kwargs
     metaparams = dict(
         BLOCK_M=128,
@@ -824,6 +827,7 @@ def _bwd_attention_kernel_call(
         ),
         kernel=_bwd_do_attention_kernel,
         name="triton::ops::_bwd_do_attention_kernel",
+        device=device_id,
         **metaparams,
     )
     metaparams = dict(
@@ -889,6 +893,7 @@ def _bwd_attention_kernel_call(
         ),
         out_shape=bwd_kernel_out_shapes,
         name="triton::ops::_bwd_attention_kernel",
+        device=device_id,
         **metaparams,
     )
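Not part of the patch: a minimal sketch of how a caller would use the new `device` keyword, mirroring the `getattr(getattr(query, "device", None), "id", 0)` fallback introduced in `gqa_kernel.py`. It assumes the vendored `triton_call` keeps jax-triton's `kernel`/`out_shape`/`grid` keywords; `add_kernel`, `add`, and the array sizes are hypothetical and exist only for illustration.

```python
import jax
import jax.numpy as jnp
import triton
import triton.language as tl

from jax_flash_attn2._custom_call_lib import triton_call


@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, BLOCK: tl.constexpr):
    # Illustrative element-wise add; assumes x.size is a multiple of BLOCK.
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK + tl.arange(0, BLOCK)
    tl.store(out_ptr + offsets, tl.load(x_ptr + offsets) + tl.load(y_ptr + offsets))


def add(x: jax.Array, y: jax.Array, block: int = 256) -> jax.Array:
    # Same fallback pattern as the diff: tracers and uncommitted arrays may not
    # expose `.device.id`, so default to device 0.
    device_id = getattr(getattr(x, "device", None), "id", 0)
    return triton_call(
        x,
        y,
        kernel=add_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
        grid=(x.size // block,),
        name="triton::ops::add_kernel",
        device=device_id,  # new keyword introduced by this change
        BLOCK=block,
    )


x = jnp.arange(1024, dtype=jnp.float32)
y = jnp.ones(1024, dtype=jnp.float32)
print(add(x, y)[:4])
```

The practical effect of the change is that `get_or_create_triton_kernel` now queries the compute capability of the caller-specified device instead of the previously hardcoded `device = 0`, so kernels compiled from arrays placed on other local devices target the correct GPU.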