Skip to content

Commit

Permalink
adding mha-gpu benchmarks
Browse files Browse the repository at this point in the history
  • Loading branch information
erfanzar committed Oct 24, 2024
1 parent 2f1ca04 commit 834fcea
Show file tree
Hide file tree
Showing 457 changed files with 1,797 additions and 6 deletions.
155 changes: 155 additions & 0 deletions benchmarks/mha/triton/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
# GPU Attention Mechanism Benchmark

This benchmark compares the performance between two attention implementations:
1. JAX Flash Attention 2
2. JAX Scaled Dot-Product Attention (SDPA) via `jax.nn`

## Methodology
The comparison focuses primarily on Multi-Head Attention (MHA) operations. Note that some backward pass benchmarks are omitted since Flash Attention consistently outperforms across all sizes and block configurations.

## Error Codes and Limitations

### Triton Implementation
- **100ms timeout**: Indicates suboptimal BlockSizes leading to CUDA shared memory overflow
- **300ms timeout**: Signals initialization failures in QKV matrices

### JAX Implementation
- **100ms**: Represents attention computation failure (Out of Memory)
- **300ms**: Indicates function compilation failure

## Notes
- Backward pass benchmarks are partially excluded due to Flash Attention's consistent superior performance
- All measurements were conducted under identical hardware conditions
- Results are reproducible under specified configurations
- JAX version `0.4.35`, Cuda version `12.6.0`, GPU `RTX 3090 TI`

## Benchmark Code

```python
import os

import jax
import jaxlib
import triton
from jax import nn
from jax import numpy as jnp
from jax import random as jrnd

from jax_flash_attn2 import get_cached_flash_attention

benchmark_configs = []
for mode in ["bwd", "fwd"]:
for batch_size in [1, 2, 4]:
for bias in [True, False]:
for headdim in [64, 128, 256]:
for num_heads in [8, 16, 32]:
for blocksize_q in [32, 64, 128]:
for blocksize_k in [32, 64, 128]:
benchmark_configs.append(
triton.testing.Benchmark(
x_names=["seqlen"],
x_vals=[1024, 2048, 4096, 6144, 8192],
line_arg="provider",
line_vals=["triton-block-ptr", "triton-ptr-block", "jax"],
line_names=["Triton-BlockPtr", "Triton-PtrBlock", "Jax"],
styles=[("green", "-"), ("blue", "-."), ("blue", ":")],
ylabel="MS",
plot_name=f"batch_size={batch_size}-bias={bias}-headdim={headdim}-num_heads={num_heads}-blocksize_q={blocksize_q}-blocksize_k={blocksize_k}-mode={mode}",
args={
"BATCH": batch_size,
"H": num_heads,
"HEAD_DIM": headdim,
"mode": mode,
"BIAS": bias,
"blocksize_k": blocksize_k,
"blocksize_q": blocksize_q,
},
)
)


@triton.testing.perf_report(benchmark_configs)
def mha_attention_benchmark(
seqlen,
H,
BATCH,
HEAD_DIM,
mode,
BIAS,
blocksize_k,
blocksize_q,
provider,
):
try:
q_key, k_key, v_key = jrnd.split(jrnd.PRNGKey(8), 3)
query = jax.nn.initializers.normal(2)(
q_key, (BATCH, seqlen, H, HEAD_DIM), dtype=jnp.float16
)
key = jax.nn.initializers.normal(2)(
k_key, (BATCH, seqlen, H, HEAD_DIM), dtype=jnp.float16
)
value = jax.nn.initializers.normal(2)(
v_key, (BATCH, seqlen, H, HEAD_DIM), dtype=jnp.float16
)
bias = (
jnp.where(
jrnd.randint(v_key, (BATCH, 1, seqlen, seqlen), 0, 4) > 2,
jnp.finfo(jnp.float16).min,
0,
)
if BIAS
else None
)
if mode == "fwd":
if provider == "triton-block-ptr":
os.environ["FLASH_ATTN_BLOCK_PTR"] = "1"
flash_attn = get_cached_flash_attention(
blocksize_k=blocksize_k,
blocksize_q=blocksize_q,
)
fn = lambda: flash_attn(query, key, value, bias)
elif provider == "triton-ptr-block":
os.environ["FLASH_ATTN_BLOCK_PTR"] = "0"
flash_attn = get_cached_flash_attention(
blocksize_k=blocksize_k,
blocksize_q=blocksize_q,
)
fn = lambda: flash_attn(query, key, value, bias)
elif provider == "jax":
_fn = jax.jit(nn.dot_product_attention)
fn = lambda: _fn(query, key, value, bias).block_until_ready()
elif mode == "bwd":
if provider == "triton-block-ptr":
os.environ["FLASH_ATTN_BLOCK_PTR"] = "1"
flash_attn = get_cached_flash_attention(
blocksize_k=blocksize_k,
blocksize_q=blocksize_q,
)
fn = lambda: jax.grad(lambda *x: flash_attn(*x).sum())(query, key, value, bias)
elif provider == "triton-ptr-block":
os.environ["FLASH_ATTN_BLOCK_PTR"] = "0"
flash_attn = get_cached_flash_attention(
blocksize_k=blocksize_k,
blocksize_q=blocksize_q,
)
fn = lambda: jax.grad(lambda *x: flash_attn(*x).sum())(query, key, value, bias)
elif provider == "jax":
_fn = jax.jit(nn.dot_product_attention)
fn = lambda: jax.grad(lambda *x: _fn(*x).sum())(
query, key, value, bias
).block_until_ready()
try:
ms = triton.testing.do_bench(fn)
except jaxlib.xla_extension.XlaRuntimeError:
ms = 100.0000
return ms
except: # noqa
return 300.0000


if __name__ == "__main__":
mha_attention_benchmark.run(
print_data=True,
save_path="jax-flash-attn2/benchmarks/mha/triton",
)
```
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax
1024.000000,2.239545,2.246270,0.534121
2048.000000,2.802878,2.830848,1.863625
4096.000000,4.071296,4.052608,6.983003
6144.000000,0.001365,0.002048,15.470117
8192.000000,0.001536,0.001024,27.271233
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax
1024.000000,0.365310,0.331610,0.529052
2048.000000,0.654309,0.672502,1.742088
4096.000000,0.950076,0.917385,7.084596
6144.000000,1.453598,2.363994,15.538006
8192.000000,6.834781,6.913000,25.276800
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax
1024.000000,2.652372,2.649856,0.532650
2048.000000,4.689167,4.564451,1.769542
4096.000000,3.866112,3.788493,6.997145
6144.000000,5.734605,5.803213,15.507152
8192.000000,0.001536,0.002048,27.187307
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax
1024.000000,0.340694,0.356367,0.536380
2048.000000,0.584720,0.620229,1.750588
4096.000000,0.858617,0.737234,7.104455
6144.000000,2.303791,1.034298,15.551141
8192.000000,6.617468,4.754402,27.075638
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax
1024.000000,0.294859,0.237292,0.540099
2048.000000,0.573222,0.623379,1.740166
4096.000000,0.974636,0.909623,7.070757
6144.000000,1.605129,1.223835,15.418086
8192.000000,7.616804,7.978400,26.985130
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax
1024.000000,0.243670,0.274491,0.537623
2048.000000,0.521724,0.531087,1.777017
4096.000000,0.917019,0.777973,7.014678
6144.000000,2.103479,1.501221,15.532090
8192.000000,3.026739,8.555988,27.294037
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax
1024.000000,1.689071,1.632223,0.527805
2048.000000,1.774188,1.849037,1.761436
4096.000000,3.593045,3.573333,6.816355
6144.000000,2.518630,2.472345,15.518966
8192.000000,0.001024,0.001365,27.134325
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax
1024.000000,0.268690,0.324626,0.551810
2048.000000,0.669807,0.695737,1.728888
4096.000000,0.962905,0.993254,7.070817
6144.000000,1.448015,1.474580,15.472186
8192.000000,8.485195,6.763800,25.410240
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax
1024.000000,0.250069,0.270965,0.536178
2048.000000,0.588549,0.612682,1.740736
4096.000000,0.851080,0.833411,6.838324
6144.000000,2.407159,2.345758,15.523211
8192.000000,4.450627,1.349848,27.101526
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax
1024.000000,1.728468,1.726268,0.949476
2048.000000,3.036570,3.023552,3.590443
4096.000000,1.823232,1.859328,13.816150
6144.000000,0.410624,0.001536,30.873941
8192.000000,0.202752,0.325632,54.324223
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax
1024.000000,0.449706,0.441662,0.942542
2048.000000,0.725489,0.718128,3.535767
4096.000000,5.083377,4.352596,13.542912
6144.000000,10.709229,10.753654,30.895445
8192.000000,18.391808,18.459818,54.175743
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax
1024.000000,2.644100,2.571638,0.952658
2048.000000,2.912882,2.799400,3.602687
4096.000000,2.237235,2.543002,14.042287
6144.000000,0.001536,0.223232,30.807169
8192.000000,0.314368,0.294912,55.842976
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax
1024.000000,0.397935,0.388066,0.943657
2048.000000,0.682312,0.656810,3.458389
4096.000000,4.978271,4.525303,13.793052
6144.000000,14.645965,14.720819,30.976171
8192.000000,24.110195,24.234497,54.190079
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax
1024.000000,0.428652,0.437393,0.954350
2048.000000,0.605235,0.676507,3.454546
4096.000000,3.610572,4.651672,13.772809
6144.000000,14.301920,14.299665,30.698877
8192.000000,23.950932,23.703951,55.654690
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax
1024.000000,0.420492,0.410696,0.946157
2048.000000,0.573895,0.620707,3.593755
4096.000000,6.201218,5.305306,13.291168
6144.000000,13.638240,13.751891,30.975113
8192.000000,22.504395,22.605482,54.190079
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax
1024.000000,1.148229,1.220246,0.958392
2048.000000,2.376216,2.197991,3.528061
4096.000000,2.406570,2.420053,13.825354
6144.000000,0.001536,0.001536,31.423487
8192.000000,0.350208,0.385024,54.291454
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax
1024.000000,0.496633,0.496985,0.951059
2048.000000,0.713811,0.683929,3.521397
4096.000000,5.992388,4.051164,13.698790
6144.000000,14.462778,14.478235,31.428032
8192.000000,23.660486,23.670952,54.098175
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax
1024.000000,0.433008,0.441802,0.946560
2048.000000,0.633380,0.601894,3.526337
4096.000000,5.002587,5.816985,14.009102
6144.000000,12.504795,12.634414,30.791702
8192.000000,21.693977,21.714382,54.885471
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax
1024.000000,1.958384,1.934998,0.320536
2048.000000,3.945316,3.942776,0.944790
4096.000000,4.544990,4.558117,3.522157
6144.000000,6.691255,6.519808,7.861914
8192.000000,2.664960,2.376704,13.630857
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax
1024.000000,0.299450,0.282076,0.315573
2048.000000,0.335973,0.371345,0.950768
4096.000000,0.976245,0.977657,3.573005
6144.000000,1.223922,1.234974,7.762157
8192.000000,1.466760,1.437840,13.701576
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax
1024.000000,1.730391,1.696853,0.323550
2048.000000,5.271585,5.262336,0.955698
4096.000000,8.617779,8.639847,3.568900
6144.000000,7.742692,7.931449,7.616792
8192.000000,3.658547,3.622707,13.906825
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax
1024.000000,0.300364,0.313451,0.310758
2048.000000,0.431260,0.451925,0.940426
4096.000000,0.650297,0.672193,3.455024
6144.000000,0.825722,0.741293,7.913847
8192.000000,0.879881,0.946141,13.727808
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax
1024.000000,0.320431,0.282806,0.320881
2048.000000,0.504967,0.520851,0.961009
4096.000000,0.796427,0.712784,3.520143
6144.000000,0.958901,1.008452,7.763666
8192.000000,1.179971,1.279453,13.615736
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax
1024.000000,0.301363,0.307574,0.330263
2048.000000,0.455921,0.459615,0.948956
4096.000000,0.759461,0.709308,3.643176
6144.000000,0.855790,0.862116,7.874294
8192.000000,0.934822,1.067411,13.695122
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax
1024.000000,1.199489,1.188126,0.320194
2048.000000,2.732049,2.691692,0.944263
4096.000000,2.953402,2.968529,3.576501
6144.000000,6.294714,6.171462,7.698459
8192.000000,3.475797,3.483989,13.665248
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax
1024.000000,0.384097,0.378137,0.316786
2048.000000,0.418716,0.478603,0.961054
4096.000000,0.971452,1.011760,3.492059
6144.000000,1.341094,1.265911,7.893131
8192.000000,1.407616,1.346013,13.595228
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax
1024.000000,0.330636,0.319132,0.312371
2048.000000,0.417799,0.380138,0.973631
4096.000000,0.808951,0.787719,3.556542
6144.000000,0.995811,1.010753,7.949232
8192.000000,1.114297,1.104483,12.924819
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax
1024.000000,100.000000,100.000000,0.627177
2048.000000,100.000000,100.000000,2.266261
4096.000000,100.000000,100.000000,8.579533
6144.000000,100.000000,100.000000,19.190802
8192.000000,100.000000,100.000000,33.459198
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax
1024.000000,1.690055,1.755562,0.639195
2048.000000,2.717568,2.680005,2.225120
4096.000000,2.809710,2.726766,8.499490
6144.000000,0.003072,0.003072,19.367949
8192.000000,0.423936,0.359424,34.027023
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax
1024.000000,3.770517,3.749226,0.625865
2048.000000,4.167748,4.096273,2.260407
4096.000000,2.687744,2.608128,8.657213
6144.000000,0.189952,0.217600,19.321568
8192.000000,0.407552,0.381952,34.033310
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax
1024.000000,0.526499,0.547620,0.628370
2048.000000,0.723867,0.755144,2.264489
4096.000000,3.418806,4.591451,8.766250
6144.000000,21.938980,23.481184,19.320877
8192.000000,27.087725,27.141266,33.542881
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax
1024.000000,0.480875,0.504145,0.609149
2048.000000,0.797078,0.773144,2.209283
4096.000000,4.941043,6.100704,8.522633
6144.000000,13.501738,13.587289,19.342388
8192.000000,23.328512,22.148439,33.728001
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax
1024.000000,0.509803,0.516781,0.630811
2048.000000,0.745072,0.764514,2.244790
4096.000000,4.551983,3.775602,8.626266
6144.000000,19.617281,19.697336,19.330496
8192.000000,29.000385,29.007162,33.484287
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax
1024.000000,1.779477,1.778627,0.622846
2048.000000,3.437056,3.412864,2.205383
4096.000000,1.805824,1.783040,8.804002
6144.000000,0.001024,0.208384,19.443304
8192.000000,0.376832,0.442368,34.250336
Loading

0 comments on commit 834fcea

Please sign in to comment.