add conditions to determine triton ws
Summary:
This improves the out-of-the-box (OOTB) UX of the flash_attention operator. Right now it breaks because warp spec has not yet been upstreamed to main. Add predicates to detect whether warp spec is available.

Additionally, we only enable ThunderKittens in OSS, as the fbcode version is outdated.
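
Condensed, the fix is a capability probe on triton.language plus a gate at benchmark registration. The standalone sketch below illustrates that pattern; register_benchmark and HAS_CUDA_124 are simplified stand-ins for the real tritonbench symbols, while the probe itself mirrors the helper added in the diff.

# Standalone sketch of the pattern this commit introduces. register_benchmark
# and HAS_CUDA_124 are stand-ins, not the real tritonbench decorator/flag.


def has_warp_spec():
    # Warp specialization is only present in Triton builds that expose
    # tl.async_task.
    import triton.language as tl

    return hasattr(tl, "async_task")


HAS_CUDA_124 = True  # stand-in; tritonbench derives this from the CUDA toolkit


def register_benchmark(enabled=True):
    # Stand-in decorator: drop the benchmark entirely when it is disabled.
    def wrap(fn):
        return fn if enabled else None

    return wrap


@register_benchmark(enabled=HAS_CUDA_124 and has_warp_spec())
def triton_tutorial_flash_v2_ws(q, k, v):
    # Body elided; the point is that the benchmark is only registered when a
    # warp-spec-capable Triton is present.
    ...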

Reviewed By: FindHao

Differential Revision: D66402556

fbshipit-source-id: f15b079d747e96af87cff0a7ce22d41e63d62843
xuzhao9 authored and facebook-github-bot committed Dec 6, 2024
1 parent 7fb9e5b commit e300fcf
Showing 2 changed files with 12 additions and 4 deletions.
9 changes: 5 additions & 4 deletions tritonbench/operators/flash_attention/operator.py
@@ -138,6 +138,7 @@
     register_metric,
     register_x_val,
 )
+from tritonbench.utils.triton_utils import has_warp_spec


 def parse_op_args(args: List[str]):
@@ -294,7 +295,7 @@ def triton_tutorial_flash_v2_tma(
             q, k, v, self.causal, self.sm_scale, "tma"
         )

-    @register_benchmark(enabled=HAS_CUDA_124)
+    @register_benchmark(enabled=HAS_CUDA_124 and has_warp_spec())
     def triton_tutorial_flash_v2_ws(
         self,
         q: torch.Tensor,
@@ -306,7 +307,7 @@ def triton_tutorial_flash_v2_ws(
             q, k, v, self.causal, self.sm_scale, "ws"
         )

-    @register_benchmark(enabled=HAS_CUDA_124)
+    @register_benchmark(enabled=HAS_CUDA_124 and has_warp_spec())
     def triton_tutorial_flash_v2_tma_ws(
         self,
         q: torch.Tensor,
@@ -318,7 +319,7 @@ def triton_tutorial_flash_v2_tma_ws(
             q, k, v, self.causal, self.sm_scale, "tma_ws"
         )

-    @register_benchmark(enabled=HAS_CUDA_124)
+    @register_benchmark(enabled=HAS_CUDA_124 and has_warp_spec())
     def triton_tutorial_flash_v2_tma_ws_persistent(
         self,
         q: torch.Tensor,
@@ -415,7 +416,7 @@ def colfax_cutlass(self, q, k, v):
             default_scale,
         )

-    @register_benchmark(enabled=bool(tk_fwd is not None))
+    @register_benchmark(enabled=not IS_FBCODE and bool(tk_fwd is not None))
     def tk(self, q, k, v):
         o = torch.zeros_like(v)
         l_tensor = torch.zeros_like(o).to(torch.float32)
7 changes: 7 additions & 0 deletions tritonbench/utils/triton_utils.py
@@ -0,0 +1,7 @@
+# utils to identify triton versions
+
+
+def has_warp_spec():
+    import triton.language as tl
+
+    return hasattr(tl, "async_task")
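
A quick sanity check of the new helper (hypothetical call site, assuming Triton and tritonbench are importable) might look like:

# Hypothetical quick check; the import path matches the new file above.
from tritonbench.utils.triton_utils import has_warp_spec

# Prints True only on Triton builds that ship warp specialization
# (i.e. expose tl.async_task); False on builds without it.
print(has_warp_spec())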
