diff --git a/benchmarks/mha/triton/README.md b/benchmarks/mha/triton/README.md
new file mode 100644
index 0000000..56270c1
--- /dev/null
+++ b/benchmarks/mha/triton/README.md
@@ -0,0 +1,163 @@
+# GPU Attention Mechanism Benchmark
+
+This benchmark compares the performance of two attention implementations:
+1. JAX Flash Attention 2
+2. JAX Scaled Dot-Product Attention (SDPA) via `jax.nn`
+
+## Methodology
+The comparison focuses primarily on Multi-Head Attention (MHA) operations. Note that some backward pass benchmarks are omitted since Flash Attention consistently outperforms SDPA across all sizes and block configurations.
+
+## Error Codes and Limitations
+
+### Triton Implementation
+- **100ms** (sentinel value, not a measured time): Indicates suboptimal BlockSizes leading to CUDA shared memory overflow
+- **300ms** (sentinel value, not a measured time): Signals initialization failures in QKV matrices
+
+### JAX Implementation
+- **100ms**: Represents attention computation failure (Out of Memory)
+- **300ms**: Indicates function compilation failure
+
+## Notes
+- Backward pass benchmarks are partially excluded due to Flash Attention's consistently superior performance
+- All measurements were conducted under identical hardware conditions
+- Results are reproducible under specified configurations
+- JAX version `0.4.35`, CUDA version `12.6.0`, GPU `RTX 3090 Ti`
+
+## Benchmark Code
+
+```python
+import os
+
+import jax
+import jaxlib
+import triton
+from jax import nn
+from jax import numpy as jnp
+from jax import random as jrnd
+
+from jax_flash_attn2 import get_cached_flash_attention
+
+benchmark_configs = []
+for mode in ["bwd", "fwd"]:
+    for batch_size in [1, 2, 4]:
+        for bias in [True, False]:
+            for headdim in [64, 128, 256]:
+                for num_heads in [8, 16, 32]:
+                    for blocksize_q in [32, 64, 128]:
+                        for blocksize_k in [32, 64, 128]:
+                            benchmark_configs.append(
+                                triton.testing.Benchmark(
+                                    x_names=["seqlen"],
+                                    x_vals=[1024, 2048, 4096, 6144, 8192],
+                                    line_arg="provider",
+                                    line_vals=["triton-block-ptr", "triton-ptr-block", "jax"],
+                                    line_names=["Triton-BlockPtr", "Triton-PtrBlock", "Jax"],
+                                    styles=[("green", "-"), ("blue", "-."), ("blue", ":")],
+                                    ylabel="MS",
+                                    plot_name=f"batch_size={batch_size}-bias={bias}-headdim={headdim}-num_heads={num_heads}-blocksize_q={blocksize_q}-blocksize_k={blocksize_k}-mode={mode}",
+                                    args={
+                                        "BATCH": batch_size,
+                                        "H": num_heads,
+                                        "HEAD_DIM": headdim,
+                                        "mode": mode,
+                                        "BIAS": bias,
+                                        "blocksize_k": blocksize_k,
+                                        "blocksize_q": blocksize_q,
+                                    },
+                                )
+                            )
+
+
+@triton.testing.perf_report(benchmark_configs)
+def mha_attention_benchmark(
+    seqlen,
+    H,
+    BATCH,
+    HEAD_DIM,
+    mode,
+    BIAS,
+    blocksize_k,
+    blocksize_q,
+    provider,
+):
+    """Benchmark one attention call for a single configuration and provider.
+
+    Returns the value produced by ``triton.testing.do_bench`` for the selected
+    provider/mode. Two sentinel values stand in for a timing on failure:
+    ``100.0`` when the benchmarked call raises ``XlaRuntimeError`` and
+    ``300.0`` when any other step (e.g. input setup) raises.
+    """
+    try:
+        q_key, k_key, v_key = jrnd.split(jrnd.PRNGKey(8), 3)
+        query = jax.nn.initializers.normal(2)(
+            q_key, (BATCH, seqlen, H, HEAD_DIM), dtype=jnp.float16
+        )
+        key = jax.nn.initializers.normal(2)(
+            k_key, (BATCH, seqlen, H, HEAD_DIM), dtype=jnp.float16
+        )
+        value = jax.nn.initializers.normal(2)(
+            v_key, (BATCH, seqlen, H, HEAD_DIM), dtype=jnp.float16
+        )
+        bias = (
+            jnp.where(
+                jrnd.randint(v_key, (BATCH, 1, seqlen, seqlen), 0, 4) > 2,
+                jnp.finfo(jnp.float16).min,
+                0,
+            )
+            if BIAS
+            else None
+        )
+        if mode == "fwd":
+            if provider == "triton-block-ptr":
+                os.environ["FLASH_ATTN_BLOCK_PTR"] = "1"
+                flash_attn = get_cached_flash_attention(
+                    blocksize_k=blocksize_k,
+                    blocksize_q=blocksize_q,
+                )
+                # JAX dispatch is asynchronous: block on the result so the kernel itself is timed, matching the "jax" provider.
+                fn = lambda: jax.block_until_ready(flash_attn(query, key, value, bias))
+            elif provider == "triton-ptr-block":
+                os.environ["FLASH_ATTN_BLOCK_PTR"] = "0"
+                flash_attn = get_cached_flash_attention(
+                    blocksize_k=blocksize_k,
+                    blocksize_q=blocksize_q,
+                )
+                fn = lambda: jax.block_until_ready(flash_attn(query, key, value, bias))
+            elif provider == "jax":
+                _fn = jax.jit(nn.dot_product_attention)
+                fn = lambda: _fn(query, key, value, bias).block_until_ready()
+        elif mode == "bwd":
+            if provider == "triton-block-ptr":
+                os.environ["FLASH_ATTN_BLOCK_PTR"] = "1"
+                flash_attn = get_cached_flash_attention(
+                    blocksize_k=blocksize_k,
+                    blocksize_q=blocksize_q,
+                )
+                fn = lambda: jax.block_until_ready(jax.grad(lambda *x: flash_attn(*x).sum())(query, key, value, bias))
+            elif provider == "triton-ptr-block":
+                os.environ["FLASH_ATTN_BLOCK_PTR"] = "0"
+                flash_attn = get_cached_flash_attention(
+                    blocksize_k=blocksize_k,
+                    blocksize_q=blocksize_q,
+                )
+                fn = lambda: jax.block_until_ready(jax.grad(lambda *x: flash_attn(*x).sum())(query, key, value, bias))
+            elif provider == "jax":
+                _fn = jax.jit(nn.dot_product_attention)
+                fn = lambda: jax.grad(lambda *x: _fn(*x).sum())(
+                    query, key, value, bias
+                ).block_until_ready()
+        try:
+            ms = triton.testing.do_bench(fn)
+        except jaxlib.xla_extension.XlaRuntimeError:
+            ms = 100.0000
+        return ms
+    except Exception:  # any other failure is reported as the 300 ms sentinel; Ctrl-C still propagates
+        return 300.0000
+
+
+if __name__ == "__main__":
+    mha_attention_benchmark.run(
+        print_data=True,
+        save_path="jax-flash-attn2/benchmarks/mha/triton",
+    )
+```
\ No newline at end of file
diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=128-num_heads=16-blocksize_q=128-blocksize_k=128-mode=fwd.csv b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=128-num_heads=16-blocksize_q=128-blocksize_k=128-mode=fwd.csv
new file mode 100644
index 0000000..9bb8727
--- /dev/null
+++ b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=128-num_heads=16-blocksize_q=128-blocksize_k=128-mode=fwd.csv
@@ -0,0 +1,6 @@
+seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax
+1024.000000,2.239545,2.246270,0.534121
+2048.000000,2.802878,2.830848,1.863625
+4096.000000,4.071296,4.052608,6.983003
+6144.000000,0.001365,0.002048,15.470117
+8192.000000,0.001536,0.001024,27.271233
diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=128-num_heads=16-blocksize_q=128-blocksize_k=128-mode=fwd.png b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=128-num_heads=16-blocksize_q=128-blocksize_k=128-mode=fwd.png
new file mode 100644
index 0000000..4e4dd86
Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=128-num_heads=16-blocksize_q=128-blocksize_k=128-mode=fwd.png differ
diff --git 
a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=128-num_heads=16-blocksize_q=128-blocksize_k=32-mode=fwd.csv b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=128-num_heads=16-blocksize_q=128-blocksize_k=32-mode=fwd.csv new file mode 100644 index 0000000..8fb321d --- /dev/null +++ b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=128-num_heads=16-blocksize_q=128-blocksize_k=32-mode=fwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,0.365310,0.331610,0.529052 +2048.000000,0.654309,0.672502,1.742088 +4096.000000,0.950076,0.917385,7.084596 +6144.000000,1.453598,2.363994,15.538006 +8192.000000,6.834781,6.913000,25.276800 diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=128-num_heads=16-blocksize_q=128-blocksize_k=32-mode=fwd.png b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=128-num_heads=16-blocksize_q=128-blocksize_k=32-mode=fwd.png new file mode 100644 index 0000000..ed47060 Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=128-num_heads=16-blocksize_q=128-blocksize_k=32-mode=fwd.png differ diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=128-num_heads=16-blocksize_q=128-blocksize_k=64-mode=fwd.csv b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=128-num_heads=16-blocksize_q=128-blocksize_k=64-mode=fwd.csv new file mode 100644 index 0000000..86f9f20 --- /dev/null +++ b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=128-num_heads=16-blocksize_q=128-blocksize_k=64-mode=fwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,2.652372,2.649856,0.532650 +2048.000000,4.689167,4.564451,1.769542 +4096.000000,3.866112,3.788493,6.997145 +6144.000000,5.734605,5.803213,15.507152 +8192.000000,0.001536,0.002048,27.187307 diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=128-num_heads=16-blocksize_q=128-blocksize_k=64-mode=fwd.png 
b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=128-num_heads=16-blocksize_q=128-blocksize_k=64-mode=fwd.png new file mode 100644 index 0000000..367214a Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=128-num_heads=16-blocksize_q=128-blocksize_k=64-mode=fwd.png differ diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=128-num_heads=16-blocksize_q=32-blocksize_k=128-mode=fwd.csv b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=128-num_heads=16-blocksize_q=32-blocksize_k=128-mode=fwd.csv new file mode 100644 index 0000000..640d124 --- /dev/null +++ b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=128-num_heads=16-blocksize_q=32-blocksize_k=128-mode=fwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,0.340694,0.356367,0.536380 +2048.000000,0.584720,0.620229,1.750588 +4096.000000,0.858617,0.737234,7.104455 +6144.000000,2.303791,1.034298,15.551141 +8192.000000,6.617468,4.754402,27.075638 diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=128-num_heads=16-blocksize_q=32-blocksize_k=128-mode=fwd.png b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=128-num_heads=16-blocksize_q=32-blocksize_k=128-mode=fwd.png new file mode 100644 index 0000000..cf5fba8 Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=128-num_heads=16-blocksize_q=32-blocksize_k=128-mode=fwd.png differ diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=128-num_heads=16-blocksize_q=32-blocksize_k=32-mode=fwd.csv b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=128-num_heads=16-blocksize_q=32-blocksize_k=32-mode=fwd.csv new file mode 100644 index 0000000..0f85521 --- /dev/null +++ b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=128-num_heads=16-blocksize_q=32-blocksize_k=32-mode=fwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,0.294859,0.237292,0.540099 +2048.000000,0.573222,0.623379,1.740166 
+4096.000000,0.974636,0.909623,7.070757 +6144.000000,1.605129,1.223835,15.418086 +8192.000000,7.616804,7.978400,26.985130 diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=128-num_heads=16-blocksize_q=32-blocksize_k=32-mode=fwd.png b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=128-num_heads=16-blocksize_q=32-blocksize_k=32-mode=fwd.png new file mode 100644 index 0000000..76baf2c Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=128-num_heads=16-blocksize_q=32-blocksize_k=32-mode=fwd.png differ diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=128-num_heads=16-blocksize_q=32-blocksize_k=64-mode=fwd.csv b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=128-num_heads=16-blocksize_q=32-blocksize_k=64-mode=fwd.csv new file mode 100644 index 0000000..99c2748 --- /dev/null +++ b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=128-num_heads=16-blocksize_q=32-blocksize_k=64-mode=fwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,0.243670,0.274491,0.537623 +2048.000000,0.521724,0.531087,1.777017 +4096.000000,0.917019,0.777973,7.014678 +6144.000000,2.103479,1.501221,15.532090 +8192.000000,3.026739,8.555988,27.294037 diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=128-num_heads=16-blocksize_q=32-blocksize_k=64-mode=fwd.png b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=128-num_heads=16-blocksize_q=32-blocksize_k=64-mode=fwd.png new file mode 100644 index 0000000..c208336 Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=128-num_heads=16-blocksize_q=32-blocksize_k=64-mode=fwd.png differ diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=128-num_heads=16-blocksize_q=64-blocksize_k=128-mode=fwd.csv b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=128-num_heads=16-blocksize_q=64-blocksize_k=128-mode=fwd.csv new file mode 100644 index 0000000..4a59122 --- /dev/null +++ 
b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=128-num_heads=16-blocksize_q=64-blocksize_k=128-mode=fwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,1.689071,1.632223,0.527805 +2048.000000,1.774188,1.849037,1.761436 +4096.000000,3.593045,3.573333,6.816355 +6144.000000,2.518630,2.472345,15.518966 +8192.000000,0.001024,0.001365,27.134325 diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=128-num_heads=16-blocksize_q=64-blocksize_k=128-mode=fwd.png b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=128-num_heads=16-blocksize_q=64-blocksize_k=128-mode=fwd.png new file mode 100644 index 0000000..2307e40 Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=128-num_heads=16-blocksize_q=64-blocksize_k=128-mode=fwd.png differ diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=128-num_heads=16-blocksize_q=64-blocksize_k=32-mode=fwd.csv b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=128-num_heads=16-blocksize_q=64-blocksize_k=32-mode=fwd.csv new file mode 100644 index 0000000..1c19809 --- /dev/null +++ b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=128-num_heads=16-blocksize_q=64-blocksize_k=32-mode=fwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,0.268690,0.324626,0.551810 +2048.000000,0.669807,0.695737,1.728888 +4096.000000,0.962905,0.993254,7.070817 +6144.000000,1.448015,1.474580,15.472186 +8192.000000,8.485195,6.763800,25.410240 diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=128-num_heads=16-blocksize_q=64-blocksize_k=32-mode=fwd.png b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=128-num_heads=16-blocksize_q=64-blocksize_k=32-mode=fwd.png new file mode 100644 index 0000000..babace8 Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=128-num_heads=16-blocksize_q=64-blocksize_k=32-mode=fwd.png differ diff --git 
a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=128-num_heads=16-blocksize_q=64-blocksize_k=64-mode=fwd.csv b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=128-num_heads=16-blocksize_q=64-blocksize_k=64-mode=fwd.csv new file mode 100644 index 0000000..5f5c26d --- /dev/null +++ b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=128-num_heads=16-blocksize_q=64-blocksize_k=64-mode=fwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,0.250069,0.270965,0.536178 +2048.000000,0.588549,0.612682,1.740736 +4096.000000,0.851080,0.833411,6.838324 +6144.000000,2.407159,2.345758,15.523211 +8192.000000,4.450627,1.349848,27.101526 diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=128-num_heads=16-blocksize_q=64-blocksize_k=64-mode=fwd.png b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=128-num_heads=16-blocksize_q=64-blocksize_k=64-mode=fwd.png new file mode 100644 index 0000000..a80bac2 Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=128-num_heads=16-blocksize_q=64-blocksize_k=64-mode=fwd.png differ diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=128-num_heads=32-blocksize_q=128-blocksize_k=128-mode=fwd.csv b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=128-num_heads=32-blocksize_q=128-blocksize_k=128-mode=fwd.csv new file mode 100644 index 0000000..5de6c6d --- /dev/null +++ b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=128-num_heads=32-blocksize_q=128-blocksize_k=128-mode=fwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,1.728468,1.726268,0.949476 +2048.000000,3.036570,3.023552,3.590443 +4096.000000,1.823232,1.859328,13.816150 +6144.000000,0.410624,0.001536,30.873941 +8192.000000,0.202752,0.325632,54.324223 diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=128-num_heads=32-blocksize_q=128-blocksize_k=128-mode=fwd.png 
b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=128-num_heads=32-blocksize_q=128-blocksize_k=128-mode=fwd.png new file mode 100644 index 0000000..d53c0f5 Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=128-num_heads=32-blocksize_q=128-blocksize_k=128-mode=fwd.png differ diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=128-num_heads=32-blocksize_q=128-blocksize_k=32-mode=fwd.csv b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=128-num_heads=32-blocksize_q=128-blocksize_k=32-mode=fwd.csv new file mode 100644 index 0000000..4ba0ad8 --- /dev/null +++ b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=128-num_heads=32-blocksize_q=128-blocksize_k=32-mode=fwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,0.449706,0.441662,0.942542 +2048.000000,0.725489,0.718128,3.535767 +4096.000000,5.083377,4.352596,13.542912 +6144.000000,10.709229,10.753654,30.895445 +8192.000000,18.391808,18.459818,54.175743 diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=128-num_heads=32-blocksize_q=128-blocksize_k=32-mode=fwd.png b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=128-num_heads=32-blocksize_q=128-blocksize_k=32-mode=fwd.png new file mode 100644 index 0000000..3ef20d2 Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=128-num_heads=32-blocksize_q=128-blocksize_k=32-mode=fwd.png differ diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=128-num_heads=32-blocksize_q=128-blocksize_k=64-mode=fwd.csv b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=128-num_heads=32-blocksize_q=128-blocksize_k=64-mode=fwd.csv new file mode 100644 index 0000000..41d2377 --- /dev/null +++ b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=128-num_heads=32-blocksize_q=128-blocksize_k=64-mode=fwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,2.644100,2.571638,0.952658 
+2048.000000,2.912882,2.799400,3.602687 +4096.000000,2.237235,2.543002,14.042287 +6144.000000,0.001536,0.223232,30.807169 +8192.000000,0.314368,0.294912,55.842976 diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=128-num_heads=32-blocksize_q=128-blocksize_k=64-mode=fwd.png b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=128-num_heads=32-blocksize_q=128-blocksize_k=64-mode=fwd.png new file mode 100644 index 0000000..5ce20bf Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=128-num_heads=32-blocksize_q=128-blocksize_k=64-mode=fwd.png differ diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=128-num_heads=32-blocksize_q=32-blocksize_k=128-mode=fwd.csv b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=128-num_heads=32-blocksize_q=32-blocksize_k=128-mode=fwd.csv new file mode 100644 index 0000000..9071898 --- /dev/null +++ b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=128-num_heads=32-blocksize_q=32-blocksize_k=128-mode=fwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,0.397935,0.388066,0.943657 +2048.000000,0.682312,0.656810,3.458389 +4096.000000,4.978271,4.525303,13.793052 +6144.000000,14.645965,14.720819,30.976171 +8192.000000,24.110195,24.234497,54.190079 diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=128-num_heads=32-blocksize_q=32-blocksize_k=128-mode=fwd.png b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=128-num_heads=32-blocksize_q=32-blocksize_k=128-mode=fwd.png new file mode 100644 index 0000000..9d6e612 Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=128-num_heads=32-blocksize_q=32-blocksize_k=128-mode=fwd.png differ diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=128-num_heads=32-blocksize_q=32-blocksize_k=32-mode=fwd.csv b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=128-num_heads=32-blocksize_q=32-blocksize_k=32-mode=fwd.csv new file mode 100644 
index 0000000..97b3eda --- /dev/null +++ b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=128-num_heads=32-blocksize_q=32-blocksize_k=32-mode=fwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,0.428652,0.437393,0.954350 +2048.000000,0.605235,0.676507,3.454546 +4096.000000,3.610572,4.651672,13.772809 +6144.000000,14.301920,14.299665,30.698877 +8192.000000,23.950932,23.703951,55.654690 diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=128-num_heads=32-blocksize_q=32-blocksize_k=32-mode=fwd.png b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=128-num_heads=32-blocksize_q=32-blocksize_k=32-mode=fwd.png new file mode 100644 index 0000000..9de933e Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=128-num_heads=32-blocksize_q=32-blocksize_k=32-mode=fwd.png differ diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=128-num_heads=32-blocksize_q=32-blocksize_k=64-mode=fwd.csv b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=128-num_heads=32-blocksize_q=32-blocksize_k=64-mode=fwd.csv new file mode 100644 index 0000000..626cca0 --- /dev/null +++ b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=128-num_heads=32-blocksize_q=32-blocksize_k=64-mode=fwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,0.420492,0.410696,0.946157 +2048.000000,0.573895,0.620707,3.593755 +4096.000000,6.201218,5.305306,13.291168 +6144.000000,13.638240,13.751891,30.975113 +8192.000000,22.504395,22.605482,54.190079 diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=128-num_heads=32-blocksize_q=32-blocksize_k=64-mode=fwd.png b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=128-num_heads=32-blocksize_q=32-blocksize_k=64-mode=fwd.png new file mode 100644 index 0000000..6a62062 Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=128-num_heads=32-blocksize_q=32-blocksize_k=64-mode=fwd.png differ diff --git 
a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=128-num_heads=32-blocksize_q=64-blocksize_k=128-mode=fwd.csv b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=128-num_heads=32-blocksize_q=64-blocksize_k=128-mode=fwd.csv new file mode 100644 index 0000000..d08f14b --- /dev/null +++ b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=128-num_heads=32-blocksize_q=64-blocksize_k=128-mode=fwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,1.148229,1.220246,0.958392 +2048.000000,2.376216,2.197991,3.528061 +4096.000000,2.406570,2.420053,13.825354 +6144.000000,0.001536,0.001536,31.423487 +8192.000000,0.350208,0.385024,54.291454 diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=128-num_heads=32-blocksize_q=64-blocksize_k=128-mode=fwd.png b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=128-num_heads=32-blocksize_q=64-blocksize_k=128-mode=fwd.png new file mode 100644 index 0000000..e024ba8 Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=128-num_heads=32-blocksize_q=64-blocksize_k=128-mode=fwd.png differ diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=128-num_heads=32-blocksize_q=64-blocksize_k=32-mode=fwd.csv b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=128-num_heads=32-blocksize_q=64-blocksize_k=32-mode=fwd.csv new file mode 100644 index 0000000..8052dab --- /dev/null +++ b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=128-num_heads=32-blocksize_q=64-blocksize_k=32-mode=fwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,0.496633,0.496985,0.951059 +2048.000000,0.713811,0.683929,3.521397 +4096.000000,5.992388,4.051164,13.698790 +6144.000000,14.462778,14.478235,31.428032 +8192.000000,23.660486,23.670952,54.098175 diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=128-num_heads=32-blocksize_q=64-blocksize_k=32-mode=fwd.png 
b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=128-num_heads=32-blocksize_q=64-blocksize_k=32-mode=fwd.png new file mode 100644 index 0000000..a6f5a14 Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=128-num_heads=32-blocksize_q=64-blocksize_k=32-mode=fwd.png differ diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=128-num_heads=32-blocksize_q=64-blocksize_k=64-mode=fwd.csv b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=128-num_heads=32-blocksize_q=64-blocksize_k=64-mode=fwd.csv new file mode 100644 index 0000000..5b7c3c7 --- /dev/null +++ b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=128-num_heads=32-blocksize_q=64-blocksize_k=64-mode=fwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,0.433008,0.441802,0.946560 +2048.000000,0.633380,0.601894,3.526337 +4096.000000,5.002587,5.816985,14.009102 +6144.000000,12.504795,12.634414,30.791702 +8192.000000,21.693977,21.714382,54.885471 diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=128-num_heads=32-blocksize_q=64-blocksize_k=64-mode=fwd.png b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=128-num_heads=32-blocksize_q=64-blocksize_k=64-mode=fwd.png new file mode 100644 index 0000000..ee2ec2a Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=128-num_heads=32-blocksize_q=64-blocksize_k=64-mode=fwd.png differ diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=128-num_heads=8-blocksize_q=128-blocksize_k=128-mode=fwd.csv b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=128-num_heads=8-blocksize_q=128-blocksize_k=128-mode=fwd.csv new file mode 100644 index 0000000..6a9cf35 --- /dev/null +++ b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=128-num_heads=8-blocksize_q=128-blocksize_k=128-mode=fwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,1.958384,1.934998,0.320536 +2048.000000,3.945316,3.942776,0.944790 
+4096.000000,4.544990,4.558117,3.522157 +6144.000000,6.691255,6.519808,7.861914 +8192.000000,2.664960,2.376704,13.630857 diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=128-num_heads=8-blocksize_q=128-blocksize_k=128-mode=fwd.png b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=128-num_heads=8-blocksize_q=128-blocksize_k=128-mode=fwd.png new file mode 100644 index 0000000..9ed8f06 Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=128-num_heads=8-blocksize_q=128-blocksize_k=128-mode=fwd.png differ diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=128-num_heads=8-blocksize_q=128-blocksize_k=32-mode=fwd.csv b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=128-num_heads=8-blocksize_q=128-blocksize_k=32-mode=fwd.csv new file mode 100644 index 0000000..681b013 --- /dev/null +++ b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=128-num_heads=8-blocksize_q=128-blocksize_k=32-mode=fwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,0.299450,0.282076,0.315573 +2048.000000,0.335973,0.371345,0.950768 +4096.000000,0.976245,0.977657,3.573005 +6144.000000,1.223922,1.234974,7.762157 +8192.000000,1.466760,1.437840,13.701576 diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=128-num_heads=8-blocksize_q=128-blocksize_k=32-mode=fwd.png b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=128-num_heads=8-blocksize_q=128-blocksize_k=32-mode=fwd.png new file mode 100644 index 0000000..c1aba22 Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=128-num_heads=8-blocksize_q=128-blocksize_k=32-mode=fwd.png differ diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=128-num_heads=8-blocksize_q=128-blocksize_k=64-mode=fwd.csv b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=128-num_heads=8-blocksize_q=128-blocksize_k=64-mode=fwd.csv new file mode 100644 index 0000000..6b36bdf --- /dev/null +++ 
b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=128-num_heads=8-blocksize_q=128-blocksize_k=64-mode=fwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,1.730391,1.696853,0.323550 +2048.000000,5.271585,5.262336,0.955698 +4096.000000,8.617779,8.639847,3.568900 +6144.000000,7.742692,7.931449,7.616792 +8192.000000,3.658547,3.622707,13.906825 diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=128-num_heads=8-blocksize_q=128-blocksize_k=64-mode=fwd.png b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=128-num_heads=8-blocksize_q=128-blocksize_k=64-mode=fwd.png new file mode 100644 index 0000000..b4dd4c7 Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=128-num_heads=8-blocksize_q=128-blocksize_k=64-mode=fwd.png differ diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=128-num_heads=8-blocksize_q=32-blocksize_k=128-mode=fwd.csv b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=128-num_heads=8-blocksize_q=32-blocksize_k=128-mode=fwd.csv new file mode 100644 index 0000000..adf05e7 --- /dev/null +++ b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=128-num_heads=8-blocksize_q=32-blocksize_k=128-mode=fwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,0.300364,0.313451,0.310758 +2048.000000,0.431260,0.451925,0.940426 +4096.000000,0.650297,0.672193,3.455024 +6144.000000,0.825722,0.741293,7.913847 +8192.000000,0.879881,0.946141,13.727808 diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=128-num_heads=8-blocksize_q=32-blocksize_k=128-mode=fwd.png b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=128-num_heads=8-blocksize_q=32-blocksize_k=128-mode=fwd.png new file mode 100644 index 0000000..6294eaa Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=128-num_heads=8-blocksize_q=32-blocksize_k=128-mode=fwd.png differ diff --git 
a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=128-num_heads=8-blocksize_q=32-blocksize_k=32-mode=fwd.csv b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=128-num_heads=8-blocksize_q=32-blocksize_k=32-mode=fwd.csv new file mode 100644 index 0000000..fdb8b77 --- /dev/null +++ b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=128-num_heads=8-blocksize_q=32-blocksize_k=32-mode=fwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,0.320431,0.282806,0.320881 +2048.000000,0.504967,0.520851,0.961009 +4096.000000,0.796427,0.712784,3.520143 +6144.000000,0.958901,1.008452,7.763666 +8192.000000,1.179971,1.279453,13.615736 diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=128-num_heads=8-blocksize_q=32-blocksize_k=32-mode=fwd.png b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=128-num_heads=8-blocksize_q=32-blocksize_k=32-mode=fwd.png new file mode 100644 index 0000000..7c88b6f Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=128-num_heads=8-blocksize_q=32-blocksize_k=32-mode=fwd.png differ diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=128-num_heads=8-blocksize_q=32-blocksize_k=64-mode=fwd.csv b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=128-num_heads=8-blocksize_q=32-blocksize_k=64-mode=fwd.csv new file mode 100644 index 0000000..f69e994 --- /dev/null +++ b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=128-num_heads=8-blocksize_q=32-blocksize_k=64-mode=fwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,0.301363,0.307574,0.330263 +2048.000000,0.455921,0.459615,0.948956 +4096.000000,0.759461,0.709308,3.643176 +6144.000000,0.855790,0.862116,7.874294 +8192.000000,0.934822,1.067411,13.695122 diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=128-num_heads=8-blocksize_q=32-blocksize_k=64-mode=fwd.png 
b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=128-num_heads=8-blocksize_q=32-blocksize_k=64-mode=fwd.png new file mode 100644 index 0000000..18b3366 Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=128-num_heads=8-blocksize_q=32-blocksize_k=64-mode=fwd.png differ diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=128-num_heads=8-blocksize_q=64-blocksize_k=128-mode=fwd.csv b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=128-num_heads=8-blocksize_q=64-blocksize_k=128-mode=fwd.csv new file mode 100644 index 0000000..e045d00 --- /dev/null +++ b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=128-num_heads=8-blocksize_q=64-blocksize_k=128-mode=fwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,1.199489,1.188126,0.320194 +2048.000000,2.732049,2.691692,0.944263 +4096.000000,2.953402,2.968529,3.576501 +6144.000000,6.294714,6.171462,7.698459 +8192.000000,3.475797,3.483989,13.665248 diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=128-num_heads=8-blocksize_q=64-blocksize_k=128-mode=fwd.png b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=128-num_heads=8-blocksize_q=64-blocksize_k=128-mode=fwd.png new file mode 100644 index 0000000..15d079a Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=128-num_heads=8-blocksize_q=64-blocksize_k=128-mode=fwd.png differ diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=128-num_heads=8-blocksize_q=64-blocksize_k=32-mode=fwd.csv b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=128-num_heads=8-blocksize_q=64-blocksize_k=32-mode=fwd.csv new file mode 100644 index 0000000..feedd26 --- /dev/null +++ b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=128-num_heads=8-blocksize_q=64-blocksize_k=32-mode=fwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,0.384097,0.378137,0.316786 +2048.000000,0.418716,0.478603,0.961054 
+4096.000000,0.971452,1.011760,3.492059 +6144.000000,1.341094,1.265911,7.893131 +8192.000000,1.407616,1.346013,13.595228 diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=128-num_heads=8-blocksize_q=64-blocksize_k=32-mode=fwd.png b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=128-num_heads=8-blocksize_q=64-blocksize_k=32-mode=fwd.png new file mode 100644 index 0000000..0ec9ee4 Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=128-num_heads=8-blocksize_q=64-blocksize_k=32-mode=fwd.png differ diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=128-num_heads=8-blocksize_q=64-blocksize_k=64-mode=fwd.csv b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=128-num_heads=8-blocksize_q=64-blocksize_k=64-mode=fwd.csv new file mode 100644 index 0000000..6f512ed --- /dev/null +++ b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=128-num_heads=8-blocksize_q=64-blocksize_k=64-mode=fwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,0.330636,0.319132,0.312371 +2048.000000,0.417799,0.380138,0.973631 +4096.000000,0.808951,0.787719,3.556542 +6144.000000,0.995811,1.010753,7.949232 +8192.000000,1.114297,1.104483,12.924819 diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=128-num_heads=8-blocksize_q=64-blocksize_k=64-mode=fwd.png b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=128-num_heads=8-blocksize_q=64-blocksize_k=64-mode=fwd.png new file mode 100644 index 0000000..584451d Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=128-num_heads=8-blocksize_q=64-blocksize_k=64-mode=fwd.png differ diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=256-num_heads=16-blocksize_q=128-blocksize_k=128-mode=fwd.csv b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=256-num_heads=16-blocksize_q=128-blocksize_k=128-mode=fwd.csv new file mode 100644 index 0000000..f2a244c --- /dev/null +++ 
b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=256-num_heads=16-blocksize_q=128-blocksize_k=128-mode=fwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,100.000000,100.000000,0.627177 +2048.000000,100.000000,100.000000,2.266261 +4096.000000,100.000000,100.000000,8.579533 +6144.000000,100.000000,100.000000,19.190802 +8192.000000,100.000000,100.000000,33.459198 diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=256-num_heads=16-blocksize_q=128-blocksize_k=128-mode=fwd.png b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=256-num_heads=16-blocksize_q=128-blocksize_k=128-mode=fwd.png new file mode 100644 index 0000000..3a41d69 Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=256-num_heads=16-blocksize_q=128-blocksize_k=128-mode=fwd.png differ diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=256-num_heads=16-blocksize_q=128-blocksize_k=32-mode=fwd.csv b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=256-num_heads=16-blocksize_q=128-blocksize_k=32-mode=fwd.csv new file mode 100644 index 0000000..82a1ad3 --- /dev/null +++ b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=256-num_heads=16-blocksize_q=128-blocksize_k=32-mode=fwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,1.690055,1.755562,0.639195 +2048.000000,2.717568,2.680005,2.225120 +4096.000000,2.809710,2.726766,8.499490 +6144.000000,0.003072,0.003072,19.367949 +8192.000000,0.423936,0.359424,34.027023 diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=256-num_heads=16-blocksize_q=128-blocksize_k=32-mode=fwd.png b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=256-num_heads=16-blocksize_q=128-blocksize_k=32-mode=fwd.png new file mode 100644 index 0000000..466f77c Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=256-num_heads=16-blocksize_q=128-blocksize_k=32-mode=fwd.png differ diff --git 
a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=256-num_heads=16-blocksize_q=128-blocksize_k=64-mode=fwd.csv b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=256-num_heads=16-blocksize_q=128-blocksize_k=64-mode=fwd.csv new file mode 100644 index 0000000..35ac745 --- /dev/null +++ b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=256-num_heads=16-blocksize_q=128-blocksize_k=64-mode=fwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,3.770517,3.749226,0.625865 +2048.000000,4.167748,4.096273,2.260407 +4096.000000,2.687744,2.608128,8.657213 +6144.000000,0.189952,0.217600,19.321568 +8192.000000,0.407552,0.381952,34.033310 diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=256-num_heads=16-blocksize_q=128-blocksize_k=64-mode=fwd.png b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=256-num_heads=16-blocksize_q=128-blocksize_k=64-mode=fwd.png new file mode 100644 index 0000000..499b6fb Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=256-num_heads=16-blocksize_q=128-blocksize_k=64-mode=fwd.png differ diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=256-num_heads=16-blocksize_q=32-blocksize_k=128-mode=fwd.csv b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=256-num_heads=16-blocksize_q=32-blocksize_k=128-mode=fwd.csv new file mode 100644 index 0000000..c05e0ea --- /dev/null +++ b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=256-num_heads=16-blocksize_q=32-blocksize_k=128-mode=fwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,0.526499,0.547620,0.628370 +2048.000000,0.723867,0.755144,2.264489 +4096.000000,3.418806,4.591451,8.766250 +6144.000000,21.938980,23.481184,19.320877 +8192.000000,27.087725,27.141266,33.542881 diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=256-num_heads=16-blocksize_q=32-blocksize_k=128-mode=fwd.png 
b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=256-num_heads=16-blocksize_q=32-blocksize_k=128-mode=fwd.png new file mode 100644 index 0000000..47aba83 Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=256-num_heads=16-blocksize_q=32-blocksize_k=128-mode=fwd.png differ diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=256-num_heads=16-blocksize_q=32-blocksize_k=32-mode=fwd.csv b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=256-num_heads=16-blocksize_q=32-blocksize_k=32-mode=fwd.csv new file mode 100644 index 0000000..9872884 --- /dev/null +++ b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=256-num_heads=16-blocksize_q=32-blocksize_k=32-mode=fwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,0.480875,0.504145,0.609149 +2048.000000,0.797078,0.773144,2.209283 +4096.000000,4.941043,6.100704,8.522633 +6144.000000,13.501738,13.587289,19.342388 +8192.000000,23.328512,22.148439,33.728001 diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=256-num_heads=16-blocksize_q=32-blocksize_k=32-mode=fwd.png b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=256-num_heads=16-blocksize_q=32-blocksize_k=32-mode=fwd.png new file mode 100644 index 0000000..bd9f03a Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=256-num_heads=16-blocksize_q=32-blocksize_k=32-mode=fwd.png differ diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=256-num_heads=16-blocksize_q=32-blocksize_k=64-mode=fwd.csv b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=256-num_heads=16-blocksize_q=32-blocksize_k=64-mode=fwd.csv new file mode 100644 index 0000000..d4526a5 --- /dev/null +++ b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=256-num_heads=16-blocksize_q=32-blocksize_k=64-mode=fwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,0.509803,0.516781,0.630811 +2048.000000,0.745072,0.764514,2.244790 
+4096.000000,4.551983,3.775602,8.626266 +6144.000000,19.617281,19.697336,19.330496 +8192.000000,29.000385,29.007162,33.484287 diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=256-num_heads=16-blocksize_q=32-blocksize_k=64-mode=fwd.png b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=256-num_heads=16-blocksize_q=32-blocksize_k=64-mode=fwd.png new file mode 100644 index 0000000..1124990 Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=256-num_heads=16-blocksize_q=32-blocksize_k=64-mode=fwd.png differ diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=256-num_heads=16-blocksize_q=64-blocksize_k=128-mode=fwd.csv b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=256-num_heads=16-blocksize_q=64-blocksize_k=128-mode=fwd.csv new file mode 100644 index 0000000..4d2c5b9 --- /dev/null +++ b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=256-num_heads=16-blocksize_q=64-blocksize_k=128-mode=fwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,1.779477,1.778627,0.622846 +2048.000000,3.437056,3.412864,2.205383 +4096.000000,1.805824,1.783040,8.804002 +6144.000000,0.001024,0.208384,19.443304 +8192.000000,0.376832,0.442368,34.250336 diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=256-num_heads=16-blocksize_q=64-blocksize_k=128-mode=fwd.png b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=256-num_heads=16-blocksize_q=64-blocksize_k=128-mode=fwd.png new file mode 100644 index 0000000..92f5a5c Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=256-num_heads=16-blocksize_q=64-blocksize_k=128-mode=fwd.png differ diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=256-num_heads=16-blocksize_q=64-blocksize_k=32-mode=fwd.csv b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=256-num_heads=16-blocksize_q=64-blocksize_k=32-mode=fwd.csv new file mode 100644 index 0000000..859ba57 --- /dev/null +++ 
b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=256-num_heads=16-blocksize_q=64-blocksize_k=32-mode=fwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,0.645120,0.610621,0.633214 +2048.000000,1.007605,1.093786,2.315720 +4096.000000,5.101298,4.139323,8.768910 +6144.000000,22.152515,22.280949,19.421843 +8192.000000,30.962130,30.793915,34.203136 diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=256-num_heads=16-blocksize_q=64-blocksize_k=32-mode=fwd.png b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=256-num_heads=16-blocksize_q=64-blocksize_k=32-mode=fwd.png new file mode 100644 index 0000000..e8f7093 Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=256-num_heads=16-blocksize_q=64-blocksize_k=32-mode=fwd.png differ diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=256-num_heads=16-blocksize_q=64-blocksize_k=64-mode=fwd.csv b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=256-num_heads=16-blocksize_q=64-blocksize_k=64-mode=fwd.csv new file mode 100644 index 0000000..8c1942a --- /dev/null +++ b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=256-num_heads=16-blocksize_q=64-blocksize_k=64-mode=fwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,0.569162,0.581191,0.655471 +2048.000000,0.845935,0.862447,2.230848 +4096.000000,3.489118,3.482038,8.624795 +6144.000000,21.376360,21.504879,19.250439 +8192.000000,29.933382,28.852037,34.966530 diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=256-num_heads=16-blocksize_q=64-blocksize_k=64-mode=fwd.png b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=256-num_heads=16-blocksize_q=64-blocksize_k=64-mode=fwd.png new file mode 100644 index 0000000..3007cd0 Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=256-num_heads=16-blocksize_q=64-blocksize_k=64-mode=fwd.png differ diff --git 
a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=256-num_heads=32-blocksize_q=128-blocksize_k=128-mode=fwd.csv b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=256-num_heads=32-blocksize_q=128-blocksize_k=128-mode=fwd.csv new file mode 100644 index 0000000..0acb616 --- /dev/null +++ b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=256-num_heads=32-blocksize_q=128-blocksize_k=128-mode=fwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,100.000000,100.000000,1.137254 +2048.000000,100.000000,100.000000,4.321752 +4096.000000,100.000000,100.000000,17.474201 +6144.000000,100.000000,100.000000,33.737728 +8192.000000,100.000000,100.000000,69.831711 diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=256-num_heads=32-blocksize_q=128-blocksize_k=128-mode=fwd.png b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=256-num_heads=32-blocksize_q=128-blocksize_k=128-mode=fwd.png new file mode 100644 index 0000000..ae4172b Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=256-num_heads=32-blocksize_q=128-blocksize_k=128-mode=fwd.png differ diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=256-num_heads=32-blocksize_q=128-blocksize_k=32-mode=fwd.csv b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=256-num_heads=32-blocksize_q=128-blocksize_k=32-mode=fwd.csv new file mode 100644 index 0000000..ebf4189 --- /dev/null +++ b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=256-num_heads=32-blocksize_q=128-blocksize_k=32-mode=fwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,1.596227,1.925315,1.146032 +2048.000000,7.503242,2.461538,4.465576 +4096.000000,0.122880,0.147456,17.153101 +6144.000000,0.410624,0.412672,38.433838 +8192.000000,0.463872,0.169984,66.060287 diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=256-num_heads=32-blocksize_q=128-blocksize_k=32-mode=fwd.png 
b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=256-num_heads=32-blocksize_q=128-blocksize_k=32-mode=fwd.png new file mode 100644 index 0000000..af82924 Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=256-num_heads=32-blocksize_q=128-blocksize_k=32-mode=fwd.png differ diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=256-num_heads=32-blocksize_q=128-blocksize_k=64-mode=fwd.csv b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=256-num_heads=32-blocksize_q=128-blocksize_k=64-mode=fwd.csv new file mode 100644 index 0000000..a3dcecd --- /dev/null +++ b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=256-num_heads=32-blocksize_q=128-blocksize_k=64-mode=fwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,2.622098,2.547598,1.147609 +2048.000000,3.399095,4.370688,4.341678 +4096.000000,0.172032,0.242688,16.865122 +6144.000000,0.378880,0.450560,39.172657 +8192.000000,0.191488,0.200704,65.827873 diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=256-num_heads=32-blocksize_q=128-blocksize_k=64-mode=fwd.png b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=256-num_heads=32-blocksize_q=128-blocksize_k=64-mode=fwd.png new file mode 100644 index 0000000..95273a3 Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=256-num_heads=32-blocksize_q=128-blocksize_k=64-mode=fwd.png differ diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=256-num_heads=32-blocksize_q=32-blocksize_k=128-mode=fwd.csv b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=256-num_heads=32-blocksize_q=32-blocksize_k=128-mode=fwd.csv new file mode 100644 index 0000000..7ba0b15 --- /dev/null +++ b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=256-num_heads=32-blocksize_q=32-blocksize_k=128-mode=fwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,0.568077,0.613468,1.148729 
+2048.000000,4.238282,6.339657,4.403178 +4096.000000,26.973183,26.835711,16.993557 +6144.000000,23.631531,23.424681,38.667427 +8192.000000,0.216405,0.171691,66.169853 diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=256-num_heads=32-blocksize_q=32-blocksize_k=128-mode=fwd.png b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=256-num_heads=32-blocksize_q=32-blocksize_k=128-mode=fwd.png new file mode 100644 index 0000000..5c0ca07 Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=256-num_heads=32-blocksize_q=32-blocksize_k=128-mode=fwd.png differ diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=256-num_heads=32-blocksize_q=32-blocksize_k=32-mode=fwd.csv b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=256-num_heads=32-blocksize_q=32-blocksize_k=32-mode=fwd.csv new file mode 100644 index 0000000..e829bfe --- /dev/null +++ b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=256-num_heads=32-blocksize_q=32-blocksize_k=32-mode=fwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,0.559732,0.550420,1.139987 +2048.000000,3.346932,3.533494,4.291349 +4096.000000,12.397217,12.419523,16.971603 +6144.000000,25.648642,25.641922,38.407166 +8192.000000,37.331764,36.476246,67.489792 diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=256-num_heads=32-blocksize_q=32-blocksize_k=32-mode=fwd.png b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=256-num_heads=32-blocksize_q=32-blocksize_k=32-mode=fwd.png new file mode 100644 index 0000000..1cf4cc1 Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=256-num_heads=32-blocksize_q=32-blocksize_k=32-mode=fwd.png differ diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=256-num_heads=32-blocksize_q=32-blocksize_k=64-mode=fwd.csv b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=256-num_heads=32-blocksize_q=32-blocksize_k=64-mode=fwd.csv new file mode 100644 
index 0000000..29ccd35 --- /dev/null +++ b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=256-num_heads=32-blocksize_q=32-blocksize_k=64-mode=fwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,0.526928,0.537197,1.147143 +2048.000000,3.086804,4.748772,4.442635 +4096.000000,17.681963,17.747116,17.319195 +6144.000000,32.575573,28.529755,38.559746 +8192.000000,34.401962,34.441387,66.801666 diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=256-num_heads=32-blocksize_q=32-blocksize_k=64-mode=fwd.png b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=256-num_heads=32-blocksize_q=32-blocksize_k=64-mode=fwd.png new file mode 100644 index 0000000..07c4626 Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=256-num_heads=32-blocksize_q=32-blocksize_k=64-mode=fwd.png differ diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=256-num_heads=32-blocksize_q=64-blocksize_k=128-mode=fwd.csv b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=256-num_heads=32-blocksize_q=64-blocksize_k=128-mode=fwd.csv new file mode 100644 index 0000000..d2b48b3 --- /dev/null +++ b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=256-num_heads=32-blocksize_q=64-blocksize_k=128-mode=fwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,1.974837,2.132517,1.147373 +2048.000000,3.062656,3.145728,4.466643 +4096.000000,0.178688,0.189952,16.335476 +6144.000000,0.667648,0.258048,40.501839 +8192.000000,0.185344,0.342016,53.240833 diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=256-num_heads=32-blocksize_q=64-blocksize_k=128-mode=fwd.png b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=256-num_heads=32-blocksize_q=64-blocksize_k=128-mode=fwd.png new file mode 100644 index 0000000..b2e3e68 Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=256-num_heads=32-blocksize_q=64-blocksize_k=128-mode=fwd.png differ diff 
--git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=256-num_heads=32-blocksize_q=64-blocksize_k=32-mode=fwd.csv b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=256-num_heads=32-blocksize_q=64-blocksize_k=32-mode=fwd.csv new file mode 100644 index 0000000..25cd624 --- /dev/null +++ b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=256-num_heads=32-blocksize_q=64-blocksize_k=32-mode=fwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,0.668337,0.653787,1.145579 +2048.000000,4.592921,5.123542,4.474135 +4096.000000,21.205673,21.196133,17.430937 +6144.000000,34.289894,34.251892,37.833729 +8192.000000,34.529278,34.291912,66.630653 diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=256-num_heads=32-blocksize_q=64-blocksize_k=32-mode=fwd.png b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=256-num_heads=32-blocksize_q=64-blocksize_k=32-mode=fwd.png new file mode 100644 index 0000000..24ec9b5 Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=256-num_heads=32-blocksize_q=64-blocksize_k=32-mode=fwd.png differ diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=256-num_heads=32-blocksize_q=64-blocksize_k=64-mode=fwd.csv b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=256-num_heads=32-blocksize_q=64-blocksize_k=64-mode=fwd.csv new file mode 100644 index 0000000..56e6de9 --- /dev/null +++ b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=256-num_heads=32-blocksize_q=64-blocksize_k=64-mode=fwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,0.674864,0.655316,1.138876 +2048.000000,5.248979,4.897822,4.445136 +4096.000000,19.604015,19.653028,17.258350 +6144.000000,30.898075,32.447689,41.032768 +8192.000000,31.323139,31.024536,67.509247 diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=256-num_heads=32-blocksize_q=64-blocksize_k=64-mode=fwd.png 
b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=256-num_heads=32-blocksize_q=64-blocksize_k=64-mode=fwd.png new file mode 100644 index 0000000..64599f1 Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=256-num_heads=32-blocksize_q=64-blocksize_k=64-mode=fwd.png differ diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=256-num_heads=8-blocksize_q=128-blocksize_k=128-mode=fwd.csv b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=256-num_heads=8-blocksize_q=128-blocksize_k=128-mode=fwd.csv new file mode 100644 index 0000000..ba14e70 --- /dev/null +++ b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=256-num_heads=8-blocksize_q=128-blocksize_k=128-mode=fwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,100.000000,100.000000,0.371843 +2048.000000,100.000000,100.000000,1.134560 +4096.000000,100.000000,100.000000,4.364967 +6144.000000,100.000000,100.000000,9.464719 +8192.000000,100.000000,100.000000,16.796684 diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=256-num_heads=8-blocksize_q=128-blocksize_k=128-mode=fwd.png b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=256-num_heads=8-blocksize_q=128-blocksize_k=128-mode=fwd.png new file mode 100644 index 0000000..04ba368 Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=256-num_heads=8-blocksize_q=128-blocksize_k=128-mode=fwd.png differ diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=256-num_heads=8-blocksize_q=128-blocksize_k=32-mode=fwd.csv b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=256-num_heads=8-blocksize_q=128-blocksize_k=32-mode=fwd.csv new file mode 100644 index 0000000..9860a21 --- /dev/null +++ b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=256-num_heads=8-blocksize_q=128-blocksize_k=32-mode=fwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,1.736341,1.675025,0.354892 
+2048.000000,2.768463,2.663974,1.148710 +4096.000000,3.518600,3.634722,4.414208 +6144.000000,4.695625,4.604782,9.674906 +8192.000000,2.117888,2.124544,17.209980 diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=256-num_heads=8-blocksize_q=128-blocksize_k=32-mode=fwd.png b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=256-num_heads=8-blocksize_q=128-blocksize_k=32-mode=fwd.png new file mode 100644 index 0000000..572e556 Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=256-num_heads=8-blocksize_q=128-blocksize_k=32-mode=fwd.png differ diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=256-num_heads=8-blocksize_q=128-blocksize_k=64-mode=fwd.csv b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=256-num_heads=8-blocksize_q=128-blocksize_k=64-mode=fwd.csv new file mode 100644 index 0000000..cf77599 --- /dev/null +++ b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=256-num_heads=8-blocksize_q=128-blocksize_k=64-mode=fwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,3.780096,3.750562,0.369498 +2048.000000,6.896679,6.859847,1.143648 +4096.000000,5.990656,5.984385,4.363189 +6144.000000,5.186304,5.313024,9.639249 +8192.000000,0.001536,0.001536,16.952389 diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=256-num_heads=8-blocksize_q=128-blocksize_k=64-mode=fwd.png b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=256-num_heads=8-blocksize_q=128-blocksize_k=64-mode=fwd.png new file mode 100644 index 0000000..c7e37bf Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=256-num_heads=8-blocksize_q=128-blocksize_k=64-mode=fwd.png differ diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=256-num_heads=8-blocksize_q=32-blocksize_k=128-mode=fwd.csv b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=256-num_heads=8-blocksize_q=32-blocksize_k=128-mode=fwd.csv new file mode 100644 index 
0000000..c0b29f7 --- /dev/null +++ b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=256-num_heads=8-blocksize_q=32-blocksize_k=128-mode=fwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,0.369778,0.384524,0.384959 +2048.000000,0.637275,0.681604,1.157612 +4096.000000,0.939817,0.990062,4.406981 +6144.000000,1.162985,1.177848,9.701514 +8192.000000,7.874099,5.965393,17.207348 diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=256-num_heads=8-blocksize_q=32-blocksize_k=128-mode=fwd.png b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=256-num_heads=8-blocksize_q=32-blocksize_k=128-mode=fwd.png new file mode 100644 index 0000000..542825d Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=256-num_heads=8-blocksize_q=32-blocksize_k=128-mode=fwd.png differ diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=256-num_heads=8-blocksize_q=32-blocksize_k=32-mode=fwd.csv b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=256-num_heads=8-blocksize_q=32-blocksize_k=32-mode=fwd.csv new file mode 100644 index 0000000..3235078 --- /dev/null +++ b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=256-num_heads=8-blocksize_q=32-blocksize_k=32-mode=fwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,0.224335,0.334521,0.352058 +2048.000000,0.663583,0.667566,1.141807 +4096.000000,0.880572,0.841928,4.334485 +6144.000000,2.412283,2.137417,9.139328 +8192.000000,5.922133,12.023428,16.802202 diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=256-num_heads=8-blocksize_q=32-blocksize_k=32-mode=fwd.png b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=256-num_heads=8-blocksize_q=32-blocksize_k=32-mode=fwd.png new file mode 100644 index 0000000..313fddd Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=256-num_heads=8-blocksize_q=32-blocksize_k=32-mode=fwd.png differ diff --git 
a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=256-num_heads=8-blocksize_q=32-blocksize_k=64-mode=fwd.csv b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=256-num_heads=8-blocksize_q=32-blocksize_k=64-mode=fwd.csv new file mode 100644 index 0000000..ed12de9 --- /dev/null +++ b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=256-num_heads=8-blocksize_q=32-blocksize_k=64-mode=fwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,0.306213,0.299393,0.365252 +2048.000000,0.624072,0.639598,1.142366 +4096.000000,0.911616,0.898196,4.494039 +6144.000000,1.088205,1.079662,9.710712 +8192.000000,5.713388,5.074370,17.145876 diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=256-num_heads=8-blocksize_q=32-blocksize_k=64-mode=fwd.png b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=256-num_heads=8-blocksize_q=32-blocksize_k=64-mode=fwd.png new file mode 100644 index 0000000..b1c93ac Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=256-num_heads=8-blocksize_q=32-blocksize_k=64-mode=fwd.png differ diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=256-num_heads=8-blocksize_q=64-blocksize_k=128-mode=fwd.csv b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=256-num_heads=8-blocksize_q=64-blocksize_k=128-mode=fwd.csv new file mode 100644 index 0000000..87f2cbe --- /dev/null +++ b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=256-num_heads=8-blocksize_q=64-blocksize_k=128-mode=fwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,2.278818,2.141492,0.368299 +2048.000000,2.949332,2.987396,1.188242 +4096.000000,4.598272,4.471296,4.455858 +6144.000000,2.391296,2.479872,9.661729 +8192.000000,0.001024,0.001536,16.916683 diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=256-num_heads=8-blocksize_q=64-blocksize_k=128-mode=fwd.png 
b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=256-num_heads=8-blocksize_q=64-blocksize_k=128-mode=fwd.png new file mode 100644 index 0000000..b2e6a8d Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=256-num_heads=8-blocksize_q=64-blocksize_k=128-mode=fwd.png differ diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=256-num_heads=8-blocksize_q=64-blocksize_k=32-mode=fwd.csv b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=256-num_heads=8-blocksize_q=64-blocksize_k=32-mode=fwd.csv new file mode 100644 index 0000000..8377695 --- /dev/null +++ b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=256-num_heads=8-blocksize_q=64-blocksize_k=32-mode=fwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,0.404796,0.392017,0.368611 +2048.000000,0.860952,0.851063,1.156427 +4096.000000,1.218174,1.266258,4.468285 +6144.000000,1.605363,1.592583,9.476838 +8192.000000,7.188713,10.272725,16.921192 diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=256-num_heads=8-blocksize_q=64-blocksize_k=32-mode=fwd.png b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=256-num_heads=8-blocksize_q=64-blocksize_k=32-mode=fwd.png new file mode 100644 index 0000000..03ac13d Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=256-num_heads=8-blocksize_q=64-blocksize_k=32-mode=fwd.png differ diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=256-num_heads=8-blocksize_q=64-blocksize_k=64-mode=fwd.csv b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=256-num_heads=8-blocksize_q=64-blocksize_k=64-mode=fwd.csv new file mode 100644 index 0000000..63ad026 --- /dev/null +++ b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=256-num_heads=8-blocksize_q=64-blocksize_k=64-mode=fwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,0.339307,0.368469,0.377970 +2048.000000,0.776271,0.764972,1.144708 
+4096.000000,1.090184,1.113348,4.415763 +6144.000000,1.715121,1.704115,9.832918 +8192.000000,9.347237,7.519969,17.254618 diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=256-num_heads=8-blocksize_q=64-blocksize_k=64-mode=fwd.png b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=256-num_heads=8-blocksize_q=64-blocksize_k=64-mode=fwd.png new file mode 100644 index 0000000..d24e2a8 Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=256-num_heads=8-blocksize_q=64-blocksize_k=64-mode=fwd.png differ diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=16-blocksize_q=128-blocksize_k=128-mode=bwd.csv b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=16-blocksize_q=128-blocksize_k=128-mode=bwd.csv new file mode 100644 index 0000000..fec5ca0 --- /dev/null +++ b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=16-blocksize_q=128-blocksize_k=128-mode=bwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,52.160511,54.205441,1.698163 +2048.000000,51.294209,54.280193,4.249713 +4096.000000,47.785984,45.697025,16.268471 +6144.000000,37.733376,52.141056,37.903423 +8192.000000,36.900864,37.159935,65.060867 diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=16-blocksize_q=128-blocksize_k=128-mode=bwd.png b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=16-blocksize_q=128-blocksize_k=128-mode=bwd.png new file mode 100644 index 0000000..76e6fbe Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=16-blocksize_q=128-blocksize_k=128-mode=bwd.png differ diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=16-blocksize_q=128-blocksize_k=128-mode=fwd.csv b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=16-blocksize_q=128-blocksize_k=128-mode=fwd.csv new file mode 100644 index 0000000..19bc9ea --- /dev/null +++ 
b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=16-blocksize_q=128-blocksize_k=128-mode=fwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,1.524727,1.521496,0.499171 +2048.000000,1.812220,1.772526,1.731661 +4096.000000,3.122499,3.099109,6.499397 +6144.000000,3.661597,3.467833,14.636735 +8192.000000,2.808627,2.653798,25.612350 diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=16-blocksize_q=128-blocksize_k=128-mode=fwd.png b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=16-blocksize_q=128-blocksize_k=128-mode=fwd.png new file mode 100644 index 0000000..a2bf760 Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=16-blocksize_q=128-blocksize_k=128-mode=fwd.png differ diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=16-blocksize_q=128-blocksize_k=32-mode=bwd.csv b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=16-blocksize_q=128-blocksize_k=32-mode=bwd.csv new file mode 100644 index 0000000..a6901e5 --- /dev/null +++ b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=16-blocksize_q=128-blocksize_k=32-mode=bwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,43.653633,43.868671,1.694281 +2048.000000,41.075714,41.712639,4.043270 +4096.000000,36.769791,37.067265,15.010315 +6144.000000,26.605057,25.541121,33.495552 +8192.000000,12.087296,12.556288,59.853920 diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=16-blocksize_q=128-blocksize_k=32-mode=bwd.png b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=16-blocksize_q=128-blocksize_k=32-mode=bwd.png new file mode 100644 index 0000000..0f055c0 Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=16-blocksize_q=128-blocksize_k=32-mode=bwd.png differ diff --git 
a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=16-blocksize_q=128-blocksize_k=32-mode=fwd.csv b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=16-blocksize_q=128-blocksize_k=32-mode=fwd.csv new file mode 100644 index 0000000..17bff82 --- /dev/null +++ b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=16-blocksize_q=128-blocksize_k=32-mode=fwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,0.303373,0.299032,0.492727 +2048.000000,0.430687,0.428421,1.692620 +4096.000000,0.888595,0.914232,6.617262 +6144.000000,1.204659,1.124719,14.414795 +8192.000000,1.398020,1.397843,25.760183 diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=16-blocksize_q=128-blocksize_k=32-mode=fwd.png b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=16-blocksize_q=128-blocksize_k=32-mode=fwd.png new file mode 100644 index 0000000..c4214ed Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=16-blocksize_q=128-blocksize_k=32-mode=fwd.png differ diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=16-blocksize_q=128-blocksize_k=64-mode=bwd.csv b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=16-blocksize_q=128-blocksize_k=64-mode=bwd.csv new file mode 100644 index 0000000..cba6c81 --- /dev/null +++ b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=16-blocksize_q=128-blocksize_k=64-mode=bwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,49.365501,50.030079,1.730187 +2048.000000,46.869503,47.346176,4.109924 +4096.000000,41.626114,40.020992,14.991632 +6144.000000,41.735168,46.204929,33.457664 +8192.000000,33.533951,34.434048,60.874752 diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=16-blocksize_q=128-blocksize_k=64-mode=bwd.png 
b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=16-blocksize_q=128-blocksize_k=64-mode=bwd.png new file mode 100644 index 0000000..99fb1a1 Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=16-blocksize_q=128-blocksize_k=64-mode=bwd.png differ diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=16-blocksize_q=128-blocksize_k=64-mode=fwd.csv b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=16-blocksize_q=128-blocksize_k=64-mode=fwd.csv new file mode 100644 index 0000000..ddc98f9 --- /dev/null +++ b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=16-blocksize_q=128-blocksize_k=64-mode=fwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,1.104034,1.098036,0.501938 +2048.000000,1.592134,1.653701,1.690128 +4096.000000,3.049275,3.039114,6.403024 +6144.000000,4.123050,4.024746,14.442870 +8192.000000,3.405239,3.427621,25.918463 diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=16-blocksize_q=128-blocksize_k=64-mode=fwd.png b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=16-blocksize_q=128-blocksize_k=64-mode=fwd.png new file mode 100644 index 0000000..75c99b2 Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=16-blocksize_q=128-blocksize_k=64-mode=fwd.png differ diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=16-blocksize_q=32-blocksize_k=128-mode=bwd.csv b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=16-blocksize_q=32-blocksize_k=128-mode=bwd.csv new file mode 100644 index 0000000..7161409 --- /dev/null +++ b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=16-blocksize_q=32-blocksize_k=128-mode=bwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,44.892704,44.855808,1.741563 +2048.000000,43.861504,46.113281,4.452190 
+4096.000000,42.012161,42.339840,16.378704 +6144.000000,38.088707,38.007294,36.154541 +8192.000000,30.499840,30.746624,63.962112 diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=16-blocksize_q=32-blocksize_k=128-mode=bwd.png b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=16-blocksize_q=32-blocksize_k=128-mode=bwd.png new file mode 100644 index 0000000..808e2df Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=16-blocksize_q=32-blocksize_k=128-mode=bwd.png differ diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=16-blocksize_q=32-blocksize_k=128-mode=fwd.csv b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=16-blocksize_q=32-blocksize_k=128-mode=fwd.csv new file mode 100644 index 0000000..c0891c7 --- /dev/null +++ b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=16-blocksize_q=32-blocksize_k=128-mode=fwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,0.275573,0.263368,0.490252 +2048.000000,0.403158,0.384110,1.680093 +4096.000000,0.599696,0.621871,6.501139 +6144.000000,0.773967,0.758887,14.502646 +8192.000000,0.832616,0.823245,25.629450 diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=16-blocksize_q=32-blocksize_k=128-mode=fwd.png b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=16-blocksize_q=32-blocksize_k=128-mode=fwd.png new file mode 100644 index 0000000..00bc54d Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=16-blocksize_q=32-blocksize_k=128-mode=fwd.png differ diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=16-blocksize_q=32-blocksize_k=32-mode=bwd.csv b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=16-blocksize_q=32-blocksize_k=32-mode=bwd.csv new file mode 100644 index 0000000..da734c0 --- /dev/null +++ 
b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=16-blocksize_q=32-blocksize_k=32-mode=bwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,41.683456,40.755711,1.675731 +2048.000000,39.890945,43.104256,3.993672 +4096.000000,34.370049,34.316799,15.082901 +6144.000000,26.870785,26.437119,33.444862 +8192.000000,14.582272,15.560192,59.865280 diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=16-blocksize_q=32-blocksize_k=32-mode=bwd.png b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=16-blocksize_q=32-blocksize_k=32-mode=bwd.png new file mode 100644 index 0000000..5d6e436 Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=16-blocksize_q=32-blocksize_k=32-mode=bwd.png differ diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=16-blocksize_q=32-blocksize_k=32-mode=fwd.csv b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=16-blocksize_q=32-blocksize_k=32-mode=fwd.csv new file mode 100644 index 0000000..6559538 --- /dev/null +++ b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=16-blocksize_q=32-blocksize_k=32-mode=fwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,0.267923,0.288756,0.508724 +2048.000000,0.468710,0.476373,1.648291 +4096.000000,0.708315,0.695910,6.518453 +6144.000000,0.831820,0.863960,14.329547 +8192.000000,1.091042,1.119915,25.502806 diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=16-blocksize_q=32-blocksize_k=32-mode=fwd.png b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=16-blocksize_q=32-blocksize_k=32-mode=fwd.png new file mode 100644 index 0000000..9d6491e Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=16-blocksize_q=32-blocksize_k=32-mode=fwd.png differ diff --git 
a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=16-blocksize_q=32-blocksize_k=64-mode=bwd.csv b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=16-blocksize_q=32-blocksize_k=64-mode=bwd.csv new file mode 100644 index 0000000..cfab6f8 --- /dev/null +++ b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=16-blocksize_q=32-blocksize_k=64-mode=bwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,41.755234,40.783875,1.714975 +2048.000000,40.426498,42.135551,4.034365 +4096.000000,37.232128,43.988480,18.324589 +6144.000000,30.871040,30.644737,34.408016 +8192.000000,22.926849,22.830080,63.136768 diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=16-blocksize_q=32-blocksize_k=64-mode=bwd.png b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=16-blocksize_q=32-blocksize_k=64-mode=bwd.png new file mode 100644 index 0000000..3c4c92e Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=16-blocksize_q=32-blocksize_k=64-mode=bwd.png differ diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=16-blocksize_q=32-blocksize_k=64-mode=fwd.csv b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=16-blocksize_q=32-blocksize_k=64-mode=fwd.csv new file mode 100644 index 0000000..49d7a1a --- /dev/null +++ b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=16-blocksize_q=32-blocksize_k=64-mode=fwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,0.283083,0.289570,0.495058 +2048.000000,0.414450,0.424834,1.668225 +4096.000000,0.668033,0.643956,6.780261 +6144.000000,0.832643,0.743972,14.482822 +8192.000000,0.890290,0.888354,25.620234 diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=16-blocksize_q=32-blocksize_k=64-mode=fwd.png 
b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=16-blocksize_q=32-blocksize_k=64-mode=fwd.png new file mode 100644 index 0000000..aac125c Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=16-blocksize_q=32-blocksize_k=64-mode=fwd.png differ diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=16-blocksize_q=64-blocksize_k=128-mode=bwd.csv b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=16-blocksize_q=64-blocksize_k=128-mode=bwd.csv new file mode 100644 index 0000000..fe31fb8 --- /dev/null +++ b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=16-blocksize_q=64-blocksize_k=128-mode=bwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,46.689278,47.676414,1.656079 +2048.000000,52.752899,47.375870,4.110466 +4096.000000,42.417152,42.398209,14.973857 +6144.000000,33.818623,34.609154,33.393311 +8192.000000,42.629120,34.408958,59.381920 diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=16-blocksize_q=64-blocksize_k=128-mode=bwd.png b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=16-blocksize_q=64-blocksize_k=128-mode=bwd.png new file mode 100644 index 0000000..0b1613d Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=16-blocksize_q=64-blocksize_k=128-mode=bwd.png differ diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=16-blocksize_q=64-blocksize_k=128-mode=fwd.csv b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=16-blocksize_q=64-blocksize_k=128-mode=fwd.csv new file mode 100644 index 0000000..c852660 --- /dev/null +++ b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=16-blocksize_q=64-blocksize_k=128-mode=fwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,0.657344,0.652755,0.531304 +2048.000000,0.974639,0.926502,1.702639 
+4096.000000,1.494973,1.497175,6.492184 +6144.000000,1.905828,2.001053,14.037692 +8192.000000,2.419166,2.368512,25.635157 diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=16-blocksize_q=64-blocksize_k=128-mode=fwd.png b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=16-blocksize_q=64-blocksize_k=128-mode=fwd.png new file mode 100644 index 0000000..f9fae4e Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=16-blocksize_q=64-blocksize_k=128-mode=fwd.png differ diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=16-blocksize_q=64-blocksize_k=32-mode=bwd.csv b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=16-blocksize_q=64-blocksize_k=32-mode=bwd.csv new file mode 100644 index 0000000..1da9f0f --- /dev/null +++ b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=16-blocksize_q=64-blocksize_k=32-mode=bwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,43.619843,43.843071,1.780928 +2048.000000,44.911613,43.822594,4.263554 +4096.000000,36.800003,37.216766,16.327387 +6144.000000,26.360321,25.637888,37.076050 +8192.000000,10.979328,11.812864,64.386047 diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=16-blocksize_q=64-blocksize_k=32-mode=bwd.png b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=16-blocksize_q=64-blocksize_k=32-mode=bwd.png new file mode 100644 index 0000000..4188d68 Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=16-blocksize_q=64-blocksize_k=32-mode=bwd.png differ diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=16-blocksize_q=64-blocksize_k=32-mode=fwd.csv b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=16-blocksize_q=64-blocksize_k=32-mode=fwd.csv new file mode 100644 index 0000000..c2fdd13 --- /dev/null +++ 
b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=16-blocksize_q=64-blocksize_k=32-mode=fwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,0.292285,0.283281,0.503054 +2048.000000,0.297977,0.274140,1.689561 +4096.000000,0.690927,0.658178,6.690127 +6144.000000,0.766031,0.771679,14.178992 +8192.000000,0.952094,0.945233,25.653334 diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=16-blocksize_q=64-blocksize_k=32-mode=fwd.png b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=16-blocksize_q=64-blocksize_k=32-mode=fwd.png new file mode 100644 index 0000000..ec87a09 Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=16-blocksize_q=64-blocksize_k=32-mode=fwd.png differ diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=16-blocksize_q=64-blocksize_k=64-mode=bwd.csv b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=16-blocksize_q=64-blocksize_k=64-mode=bwd.csv new file mode 100644 index 0000000..e00de2a --- /dev/null +++ b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=16-blocksize_q=64-blocksize_k=64-mode=bwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,44.706306,46.061058,1.769020 +2048.000000,44.637695,44.380672,4.279611 +4096.000000,40.455681,40.014336,16.366161 +6144.000000,33.285118,32.748543,35.928623 +8192.000000,23.710209,24.669184,63.973408 diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=16-blocksize_q=64-blocksize_k=64-mode=bwd.png b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=16-blocksize_q=64-blocksize_k=64-mode=bwd.png new file mode 100644 index 0000000..b036489 Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=16-blocksize_q=64-blocksize_k=64-mode=bwd.png differ diff --git 
a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=16-blocksize_q=64-blocksize_k=64-mode=fwd.csv b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=16-blocksize_q=64-blocksize_k=64-mode=fwd.csv new file mode 100644 index 0000000..8cdad73 --- /dev/null +++ b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=16-blocksize_q=64-blocksize_k=64-mode=fwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,0.298225,0.317259,0.488989 +2048.000000,0.259732,0.273314,1.696744 +4096.000000,0.625925,0.606134,6.562274 +6144.000000,0.638592,0.620856,14.674603 +8192.000000,0.759344,0.909012,24.447786 diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=16-blocksize_q=64-blocksize_k=64-mode=fwd.png b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=16-blocksize_q=64-blocksize_k=64-mode=fwd.png new file mode 100644 index 0000000..a4614de Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=16-blocksize_q=64-blocksize_k=64-mode=fwd.png differ diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=32-blocksize_q=128-blocksize_k=128-mode=bwd.csv b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=32-blocksize_q=128-blocksize_k=128-mode=bwd.csv new file mode 100644 index 0000000..06a5539 --- /dev/null +++ b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=32-blocksize_q=128-blocksize_k=128-mode=bwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,52.079617,51.291138,2.198667 +2048.000000,49.489922,50.028545,300.000000 +4096.000000,44.910591,54.361088,300.000000 +6144.000000,37.248001,35.458050,300.000000 +8192.000000,39.242752,37.771263,100.000000 diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=32-blocksize_q=128-blocksize_k=128-mode=bwd.png 
b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=32-blocksize_q=128-blocksize_k=128-mode=bwd.png new file mode 100644 index 0000000..6591316 Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=32-blocksize_q=128-blocksize_k=128-mode=bwd.png differ diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=32-blocksize_q=128-blocksize_k=128-mode=fwd.csv b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=32-blocksize_q=128-blocksize_k=128-mode=fwd.csv new file mode 100644 index 0000000..5380715 --- /dev/null +++ b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=32-blocksize_q=128-blocksize_k=128-mode=fwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,1.082441,1.162919,0.901414 +2048.000000,1.988023,1.901289,3.394349 +4096.000000,2.936218,2.928538,13.135702 +6144.000000,1.553664,2.496307,28.691530 +8192.000000,0.003072,0.002048,51.007648 diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=32-blocksize_q=128-blocksize_k=128-mode=fwd.png b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=32-blocksize_q=128-blocksize_k=128-mode=fwd.png new file mode 100644 index 0000000..0f0701e Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=32-blocksize_q=128-blocksize_k=128-mode=fwd.png differ diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=32-blocksize_q=128-blocksize_k=32-mode=bwd.csv b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=32-blocksize_q=128-blocksize_k=32-mode=bwd.csv new file mode 100644 index 0000000..fbf3ec7 --- /dev/null +++ b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=32-blocksize_q=128-blocksize_k=32-mode=bwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,42.233345,43.237377,2.177752 
+2048.000000,40.549377,40.032768,300.000000 +4096.000000,27.828735,30.327808,300.000000 +6144.000000,10.635264,10.809343,300.000000 +8192.000000,0.002048,0.001024,100.000000 diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=32-blocksize_q=128-blocksize_k=32-mode=bwd.png b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=32-blocksize_q=128-blocksize_k=32-mode=bwd.png new file mode 100644 index 0000000..cfa66dc Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=32-blocksize_q=128-blocksize_k=32-mode=bwd.png differ diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=32-blocksize_q=128-blocksize_k=32-mode=fwd.csv b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=32-blocksize_q=128-blocksize_k=32-mode=fwd.csv new file mode 100644 index 0000000..7a898d7 --- /dev/null +++ b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=32-blocksize_q=128-blocksize_k=32-mode=fwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,0.292623,0.300019,0.895510 +2048.000000,0.640768,0.649501,3.382843 +4096.000000,1.007894,1.090071,12.849183 +6144.000000,1.733537,2.670486,28.478806 +8192.000000,7.564288,7.766414,50.998272 diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=32-blocksize_q=128-blocksize_k=32-mode=fwd.png b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=32-blocksize_q=128-blocksize_k=32-mode=fwd.png new file mode 100644 index 0000000..3104292 Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=32-blocksize_q=128-blocksize_k=32-mode=fwd.png differ diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=32-blocksize_q=128-blocksize_k=64-mode=bwd.csv b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=32-blocksize_q=128-blocksize_k=64-mode=bwd.csv new file mode 100644 
index 0000000..150d598 --- /dev/null +++ b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=32-blocksize_q=128-blocksize_k=64-mode=bwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,47.959042,48.157185,2.170290 +2048.000000,47.318016,47.245827,300.000000 +4096.000000,49.299454,47.794174,300.000000 +6144.000000,32.896000,33.789951,300.000000 +8192.000000,38.889473,35.369984,100.000000 diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=32-blocksize_q=128-blocksize_k=64-mode=bwd.png b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=32-blocksize_q=128-blocksize_k=64-mode=bwd.png new file mode 100644 index 0000000..5259287 Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=32-blocksize_q=128-blocksize_k=64-mode=bwd.png differ diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=32-blocksize_q=128-blocksize_k=64-mode=fwd.csv b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=32-blocksize_q=128-blocksize_k=64-mode=fwd.csv new file mode 100644 index 0000000..c2c96fc --- /dev/null +++ b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=32-blocksize_q=128-blocksize_k=64-mode=fwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,1.071115,1.156618,0.901076 +2048.000000,1.874269,1.843858,3.372844 +4096.000000,3.200220,3.089115,12.916680 +6144.000000,2.786816,2.735616,28.420128 +8192.000000,1.979136,0.002389,51.022846 diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=32-blocksize_q=128-blocksize_k=64-mode=fwd.png b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=32-blocksize_q=128-blocksize_k=64-mode=fwd.png new file mode 100644 index 0000000..7f3b599 Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=32-blocksize_q=128-blocksize_k=64-mode=fwd.png differ diff 
--git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=32-blocksize_q=32-blocksize_k=128-mode=bwd.csv b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=32-blocksize_q=32-blocksize_k=128-mode=bwd.csv new file mode 100644 index 0000000..8c99656 --- /dev/null +++ b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=32-blocksize_q=32-blocksize_k=128-mode=bwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,41.783806,43.025406,2.207317 +2048.000000,40.727043,41.016319,7.793104 +4096.000000,34.960384,37.386238,29.592575 +6144.000000,26.578943,28.072449,300.000000 +8192.000000,15.703552,15.120384,100.000000 diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=32-blocksize_q=32-blocksize_k=128-mode=bwd.png b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=32-blocksize_q=32-blocksize_k=128-mode=bwd.png new file mode 100644 index 0000000..b0a103a Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=32-blocksize_q=32-blocksize_k=128-mode=bwd.png differ diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=32-blocksize_q=32-blocksize_k=128-mode=fwd.csv b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=32-blocksize_q=32-blocksize_k=128-mode=fwd.csv new file mode 100644 index 0000000..6bf0a97 --- /dev/null +++ b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=32-blocksize_q=32-blocksize_k=128-mode=fwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,0.194274,0.197417,0.893069 +2048.000000,0.489115,0.510745,3.374991 +4096.000000,0.749033,0.736170,12.789350 +6144.000000,1.510902,0.927365,26.680799 +8192.000000,9.315708,8.356712,50.984959 diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=32-blocksize_q=32-blocksize_k=128-mode=fwd.png 
b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=32-blocksize_q=32-blocksize_k=128-mode=fwd.png new file mode 100644 index 0000000..80d4d00 Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=32-blocksize_q=32-blocksize_k=128-mode=fwd.png differ diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=32-blocksize_q=32-blocksize_k=32-mode=bwd.csv b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=32-blocksize_q=32-blocksize_k=32-mode=bwd.csv new file mode 100644 index 0000000..dce2588 --- /dev/null +++ b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=32-blocksize_q=32-blocksize_k=32-mode=bwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,42.449921,42.796547,2.238206 +2048.000000,38.190590,37.873154,7.730373 +4096.000000,28.506111,27.966976,29.548096 +6144.000000,13.130239,14.025728,66.506752 +8192.000000,0.001024,0.001024,100.000000 diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=32-blocksize_q=32-blocksize_k=32-mode=bwd.png b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=32-blocksize_q=32-blocksize_k=32-mode=bwd.png new file mode 100644 index 0000000..b475ab2 Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=32-blocksize_q=32-blocksize_k=32-mode=bwd.png differ diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=32-blocksize_q=32-blocksize_k=32-mode=fwd.csv b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=32-blocksize_q=32-blocksize_k=32-mode=fwd.csv new file mode 100644 index 0000000..108a276 --- /dev/null +++ b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=32-blocksize_q=32-blocksize_k=32-mode=fwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,0.268104,0.257740,0.897739 +2048.000000,0.588016,0.590320,3.320127 
+4096.000000,0.758632,0.821480,12.905618 +6144.000000,2.374976,1.114896,28.482285 +8192.000000,7.673276,7.397023,50.958336 diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=32-blocksize_q=32-blocksize_k=32-mode=fwd.png b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=32-blocksize_q=32-blocksize_k=32-mode=fwd.png new file mode 100644 index 0000000..25c6452 Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=32-blocksize_q=32-blocksize_k=32-mode=fwd.png differ diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=32-blocksize_q=32-blocksize_k=64-mode=bwd.csv b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=32-blocksize_q=32-blocksize_k=64-mode=bwd.csv new file mode 100644 index 0000000..944c8f1 --- /dev/null +++ b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=32-blocksize_q=32-blocksize_k=64-mode=bwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,41.471489,41.312256,2.191019 +2048.000000,41.888771,39.794174,8.093368 +4096.000000,31.771135,32.185341,29.881342 +6144.000000,19.444225,20.192257,300.000000 +8192.000000,16.817152,18.491905,100.000000 diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=32-blocksize_q=32-blocksize_k=64-mode=bwd.png b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=32-blocksize_q=32-blocksize_k=64-mode=bwd.png new file mode 100644 index 0000000..a85e15d Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=32-blocksize_q=32-blocksize_k=64-mode=bwd.png differ diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=32-blocksize_q=32-blocksize_k=64-mode=fwd.csv b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=32-blocksize_q=32-blocksize_k=64-mode=fwd.csv new file mode 100644 index 0000000..e793fd8 --- /dev/null +++ 
b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=32-blocksize_q=32-blocksize_k=64-mode=fwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,0.278821,0.203902,0.904601 +2048.000000,0.495527,0.524544,3.348698 +4096.000000,0.831747,0.756073,12.721632 +6144.000000,1.585729,1.062855,28.581675 +8192.000000,6.993608,8.414359,50.848831 diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=32-blocksize_q=32-blocksize_k=64-mode=fwd.png b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=32-blocksize_q=32-blocksize_k=64-mode=fwd.png new file mode 100644 index 0000000..c9ca735 Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=32-blocksize_q=32-blocksize_k=64-mode=fwd.png differ diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=32-blocksize_q=64-blocksize_k=128-mode=bwd.csv b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=32-blocksize_q=64-blocksize_k=128-mode=bwd.csv new file mode 100644 index 0000000..964b7e5 --- /dev/null +++ b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=32-blocksize_q=64-blocksize_k=128-mode=bwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,48.079872,48.340481,2.215052 +2048.000000,47.191551,44.837376,300.000000 +4096.000000,37.274624,37.072895,300.000000 +6144.000000,47.077377,48.559105,300.000000 +8192.000000,30.067713,29.582336,100.000000 diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=32-blocksize_q=64-blocksize_k=128-mode=bwd.png b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=32-blocksize_q=64-blocksize_k=128-mode=bwd.png new file mode 100644 index 0000000..3dc5d4c Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=32-blocksize_q=64-blocksize_k=128-mode=bwd.png differ diff --git 
a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=32-blocksize_q=64-blocksize_k=128-mode=fwd.csv b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=32-blocksize_q=64-blocksize_k=128-mode=fwd.csv new file mode 100644 index 0000000..82362b0 --- /dev/null +++ b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=32-blocksize_q=64-blocksize_k=128-mode=fwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,0.692391,0.690284,0.906840 +2048.000000,0.966737,1.074766,3.366642 +4096.000000,1.568351,1.579410,12.816339 +6144.000000,2.231735,2.267355,27.133463 +8192.000000,2.602624,2.479232,51.085312 diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=32-blocksize_q=64-blocksize_k=128-mode=fwd.png b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=32-blocksize_q=64-blocksize_k=128-mode=fwd.png new file mode 100644 index 0000000..15a0843 Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=32-blocksize_q=64-blocksize_k=128-mode=fwd.png differ diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=32-blocksize_q=64-blocksize_k=32-mode=bwd.csv b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=32-blocksize_q=64-blocksize_k=32-mode=bwd.csv new file mode 100644 index 0000000..f441f2e --- /dev/null +++ b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=32-blocksize_q=64-blocksize_k=32-mode=bwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,39.704063,40.694275,2.149170 +2048.000000,37.292030,37.990913,300.000000 +4096.000000,24.531456,25.107456,300.000000 +6144.000000,5.763072,22.595585,300.000000 +8192.000000,0.001024,0.002048,100.000000 diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=32-blocksize_q=64-blocksize_k=32-mode=bwd.png 
b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=32-blocksize_q=64-blocksize_k=32-mode=bwd.png new file mode 100644 index 0000000..3d5758d Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=32-blocksize_q=64-blocksize_k=32-mode=bwd.png differ diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=32-blocksize_q=64-blocksize_k=32-mode=fwd.csv b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=32-blocksize_q=64-blocksize_k=32-mode=fwd.csv new file mode 100644 index 0000000..74c4cc4 --- /dev/null +++ b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=32-blocksize_q=64-blocksize_k=32-mode=fwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,0.296567,0.293862,0.888089 +2048.000000,0.470762,0.498192,3.372458 +4096.000000,0.729678,0.687250,12.775748 +6144.000000,1.461504,0.935951,28.775423 +8192.000000,6.069362,4.621676,51.770367 diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=32-blocksize_q=64-blocksize_k=32-mode=fwd.png b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=32-blocksize_q=64-blocksize_k=32-mode=fwd.png new file mode 100644 index 0000000..7633f89 Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=32-blocksize_q=64-blocksize_k=32-mode=fwd.png differ diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=32-blocksize_q=64-blocksize_k=64-mode=bwd.csv b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=32-blocksize_q=64-blocksize_k=64-mode=bwd.csv new file mode 100644 index 0000000..f507520 --- /dev/null +++ b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=32-blocksize_q=64-blocksize_k=64-mode=bwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,41.308159,40.687103,2.195997 +2048.000000,38.883327,39.604225,300.000000 
+4096.000000,32.592896,32.568832,300.000000 +6144.000000,21.382656,22.393345,300.000000 +8192.000000,24.840191,22.914560,100.000000 diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=32-blocksize_q=64-blocksize_k=64-mode=bwd.png b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=32-blocksize_q=64-blocksize_k=64-mode=bwd.png new file mode 100644 index 0000000..1a46a95 Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=32-blocksize_q=64-blocksize_k=64-mode=bwd.png differ diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=32-blocksize_q=64-blocksize_k=64-mode=fwd.csv b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=32-blocksize_q=64-blocksize_k=64-mode=fwd.csv new file mode 100644 index 0000000..315306c --- /dev/null +++ b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=32-blocksize_q=64-blocksize_k=64-mode=fwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,0.260127,0.261084,0.907107 +2048.000000,0.489120,0.503194,3.285888 +4096.000000,0.702334,0.687267,12.795319 +6144.000000,0.868415,0.853787,28.554922 +8192.000000,5.292191,5.980858,52.552929 diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=32-blocksize_q=64-blocksize_k=64-mode=fwd.png b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=32-blocksize_q=64-blocksize_k=64-mode=fwd.png new file mode 100644 index 0000000..fd0a03c Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=32-blocksize_q=64-blocksize_k=64-mode=fwd.png differ diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=8-blocksize_q=128-blocksize_k=128-mode=bwd.csv b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=8-blocksize_q=128-blocksize_k=128-mode=bwd.csv new file mode 100644 index 0000000..4b50fae --- /dev/null +++ 
b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=8-blocksize_q=128-blocksize_k=128-mode=bwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,52.571136,54.286335,1.336869 +2048.000000,54.935551,54.233089,2.203917 +4096.000000,48.314369,49.284096,7.858112 +6144.000000,43.418625,44.575745,16.787161 +8192.000000,40.323071,53.779457,30.044502 diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=8-blocksize_q=128-blocksize_k=128-mode=bwd.png b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=8-blocksize_q=128-blocksize_k=128-mode=bwd.png new file mode 100644 index 0000000..797ebf4 Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=8-blocksize_q=128-blocksize_k=128-mode=bwd.png differ diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=8-blocksize_q=128-blocksize_k=128-mode=fwd.csv b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=8-blocksize_q=128-blocksize_k=128-mode=fwd.csv new file mode 100644 index 0000000..f1705ec --- /dev/null +++ b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=8-blocksize_q=128-blocksize_k=128-mode=fwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,0.986422,0.964327,0.296334 +2048.000000,3.023531,3.025749,0.895366 +4096.000000,2.943616,2.952864,3.329914 +6144.000000,10.140613,10.020750,7.312576 +8192.000000,4.462695,4.424192,13.005107 diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=8-blocksize_q=128-blocksize_k=128-mode=fwd.png b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=8-blocksize_q=128-blocksize_k=128-mode=fwd.png new file mode 100644 index 0000000..0e0e639 Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=8-blocksize_q=128-blocksize_k=128-mode=fwd.png differ diff --git 
a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=8-blocksize_q=128-blocksize_k=32-mode=bwd.csv b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=8-blocksize_q=128-blocksize_k=32-mode=bwd.csv new file mode 100644 index 0000000..49e4878 --- /dev/null +++ b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=8-blocksize_q=128-blocksize_k=32-mode=bwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,41.035309,40.819710,1.296570 +2048.000000,39.990784,40.225281,2.215243 +4096.000000,38.446079,37.881859,8.184044 +6144.000000,32.667137,32.935425,17.994085 +8192.000000,26.741760,25.692673,32.180950 diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=8-blocksize_q=128-blocksize_k=32-mode=bwd.png b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=8-blocksize_q=128-blocksize_k=32-mode=bwd.png new file mode 100644 index 0000000..30cf025 Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=8-blocksize_q=128-blocksize_k=32-mode=bwd.png differ diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=8-blocksize_q=128-blocksize_k=32-mode=fwd.csv b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=8-blocksize_q=128-blocksize_k=32-mode=fwd.csv new file mode 100644 index 0000000..67e381d --- /dev/null +++ b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=8-blocksize_q=128-blocksize_k=32-mode=fwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,0.318273,0.318949,0.314035 +2048.000000,0.290389,0.283678,0.868269 +4096.000000,0.856528,0.828110,3.297724 +6144.000000,1.193067,1.092842,7.319801 +8192.000000,1.384435,1.284415,12.749312 diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=8-blocksize_q=128-blocksize_k=32-mode=fwd.png 
b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=8-blocksize_q=128-blocksize_k=32-mode=fwd.png new file mode 100644 index 0000000..fcb3fd1 Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=8-blocksize_q=128-blocksize_k=32-mode=fwd.png differ diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=8-blocksize_q=128-blocksize_k=64-mode=bwd.csv b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=8-blocksize_q=128-blocksize_k=64-mode=bwd.csv new file mode 100644 index 0000000..03678b3 --- /dev/null +++ b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=8-blocksize_q=128-blocksize_k=64-mode=bwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,45.311073,46.057983,1.332440 +2048.000000,45.194752,47.624191,2.219133 +4096.000000,44.857857,45.221375,7.999828 +6144.000000,130.884094,40.868355,17.851673 +8192.000000,45.209602,44.301315,31.801687 diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=8-blocksize_q=128-blocksize_k=64-mode=bwd.png b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=8-blocksize_q=128-blocksize_k=64-mode=bwd.png new file mode 100644 index 0000000..a7227b5 Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=8-blocksize_q=128-blocksize_k=64-mode=bwd.png differ diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=8-blocksize_q=128-blocksize_k=64-mode=fwd.csv b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=8-blocksize_q=128-blocksize_k=64-mode=fwd.csv new file mode 100644 index 0000000..7ba126e --- /dev/null +++ b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=8-blocksize_q=128-blocksize_k=64-mode=fwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,0.749852,0.724501,0.293608 +2048.000000,2.074576,2.071059,0.883282 
+4096.000000,2.840201,2.829121,3.245122 +6144.000000,7.177380,7.196886,7.207584 +8192.000000,4.920247,4.996974,13.035813 diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=8-blocksize_q=128-blocksize_k=64-mode=fwd.png b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=8-blocksize_q=128-blocksize_k=64-mode=fwd.png new file mode 100644 index 0000000..1896e38 Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=8-blocksize_q=128-blocksize_k=64-mode=fwd.png differ diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=8-blocksize_q=32-blocksize_k=128-mode=bwd.csv b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=8-blocksize_q=32-blocksize_k=128-mode=bwd.csv new file mode 100644 index 0000000..e6642c6 --- /dev/null +++ b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=8-blocksize_q=32-blocksize_k=128-mode=bwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,39.062706,39.139839,1.392907 +2048.000000,38.938622,39.918083,2.231972 +4096.000000,38.449661,38.187008,7.998953 +6144.000000,37.007362,36.594688,17.431482 +8192.000000,33.146881,33.517056,32.820320 diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=8-blocksize_q=32-blocksize_k=128-mode=bwd.png b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=8-blocksize_q=32-blocksize_k=128-mode=bwd.png new file mode 100644 index 0000000..0fa595a Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=8-blocksize_q=32-blocksize_k=128-mode=bwd.png differ diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=8-blocksize_q=32-blocksize_k=128-mode=fwd.csv b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=8-blocksize_q=32-blocksize_k=128-mode=fwd.csv new file mode 100644 index 0000000..0ce2023 --- /dev/null +++ 
b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=8-blocksize_q=32-blocksize_k=128-mode=fwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,0.317588,0.326749,0.308598 +2048.000000,0.183474,0.178400,0.890750 +4096.000000,0.581134,0.556604,3.265938 +6144.000000,0.668777,0.643651,7.210877 +8192.000000,0.734979,0.763596,12.884759 diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=8-blocksize_q=32-blocksize_k=128-mode=fwd.png b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=8-blocksize_q=32-blocksize_k=128-mode=fwd.png new file mode 100644 index 0000000..84c1d19 Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=8-blocksize_q=32-blocksize_k=128-mode=fwd.png differ diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=8-blocksize_q=32-blocksize_k=32-mode=bwd.csv b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=8-blocksize_q=32-blocksize_k=32-mode=bwd.csv new file mode 100644 index 0000000..371db1f --- /dev/null +++ b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=8-blocksize_q=32-blocksize_k=32-mode=bwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,35.832832,36.336639,1.322190 +2048.000000,35.884033,35.482113,2.193433 +4096.000000,33.507328,34.103806,7.741800 +6144.000000,29.461504,30.001152,17.455212 +8192.000000,23.644161,24.775167,31.463381 diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=8-blocksize_q=32-blocksize_k=32-mode=bwd.png b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=8-blocksize_q=32-blocksize_k=32-mode=bwd.png new file mode 100644 index 0000000..6113a6a Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=8-blocksize_q=32-blocksize_k=32-mode=bwd.png differ diff --git 
a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=8-blocksize_q=32-blocksize_k=32-mode=fwd.csv b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=8-blocksize_q=32-blocksize_k=32-mode=fwd.csv new file mode 100644 index 0000000..413d845 --- /dev/null +++ b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=8-blocksize_q=32-blocksize_k=32-mode=fwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,0.304420,0.315460,0.300184 +2048.000000,0.235732,0.250502,0.903695 +4096.000000,0.650667,0.647369,3.362326 +6144.000000,0.937984,0.936868,7.363729 +8192.000000,1.094707,1.084745,13.023402 diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=8-blocksize_q=32-blocksize_k=32-mode=fwd.png b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=8-blocksize_q=32-blocksize_k=32-mode=fwd.png new file mode 100644 index 0000000..1e91c06 Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=8-blocksize_q=32-blocksize_k=32-mode=fwd.png differ diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=8-blocksize_q=32-blocksize_k=64-mode=bwd.csv b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=8-blocksize_q=32-blocksize_k=64-mode=bwd.csv new file mode 100644 index 0000000..c4f443d --- /dev/null +++ b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=8-blocksize_q=32-blocksize_k=64-mode=bwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,38.010368,38.346176,1.427182 +2048.000000,37.487617,38.526978,4.504594 +4096.000000,35.168770,35.681282,7.861829 +6144.000000,32.790527,32.652290,17.580915 +8192.000000,28.063232,27.634687,31.056395 diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=8-blocksize_q=32-blocksize_k=64-mode=bwd.png 
b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=8-blocksize_q=32-blocksize_k=64-mode=bwd.png new file mode 100644 index 0000000..f431591 Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=8-blocksize_q=32-blocksize_k=64-mode=bwd.png differ diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=8-blocksize_q=32-blocksize_k=64-mode=fwd.csv b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=8-blocksize_q=32-blocksize_k=64-mode=fwd.csv new file mode 100644 index 0000000..e3ef766 --- /dev/null +++ b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=8-blocksize_q=32-blocksize_k=64-mode=fwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,0.353060,0.307993,0.283132 +2048.000000,0.253584,0.258981,0.887355 +4096.000000,0.585810,0.625254,3.314362 +6144.000000,0.733644,0.665862,7.355601 +8192.000000,0.711191,0.731782,13.027549 diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=8-blocksize_q=32-blocksize_k=64-mode=fwd.png b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=8-blocksize_q=32-blocksize_k=64-mode=fwd.png new file mode 100644 index 0000000..b283536 Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=8-blocksize_q=32-blocksize_k=64-mode=fwd.png differ diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=8-blocksize_q=64-blocksize_k=128-mode=bwd.csv b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=8-blocksize_q=64-blocksize_k=128-mode=bwd.csv new file mode 100644 index 0000000..86b0b9d --- /dev/null +++ b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=8-blocksize_q=64-blocksize_k=128-mode=bwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,44.584961,44.800369,1.348415 +2048.000000,44.342785,44.447742,2.229659 
+4096.000000,42.759682,43.175423,8.166669 +6144.000000,39.245312,39.000061,17.567398 +8192.000000,34.007042,33.820160,30.994370 diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=8-blocksize_q=64-blocksize_k=128-mode=bwd.png b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=8-blocksize_q=64-blocksize_k=128-mode=bwd.png new file mode 100644 index 0000000..fc61b6c Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=8-blocksize_q=64-blocksize_k=128-mode=bwd.png differ diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=8-blocksize_q=64-blocksize_k=128-mode=fwd.csv b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=8-blocksize_q=64-blocksize_k=128-mode=fwd.csv new file mode 100644 index 0000000..0f45409 --- /dev/null +++ b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=8-blocksize_q=64-blocksize_k=128-mode=fwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,0.336947,0.282397,0.299422 +2048.000000,0.920680,0.902962,0.887359 +4096.000000,1.495256,1.456183,3.359261 +6144.000000,1.886525,1.886255,7.241199 +8192.000000,2.268691,2.252764,12.953650 diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=8-blocksize_q=64-blocksize_k=128-mode=fwd.png b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=8-blocksize_q=64-blocksize_k=128-mode=fwd.png new file mode 100644 index 0000000..45e40fa Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=8-blocksize_q=64-blocksize_k=128-mode=fwd.png differ diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=8-blocksize_q=64-blocksize_k=32-mode=bwd.csv b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=8-blocksize_q=64-blocksize_k=32-mode=bwd.csv new file mode 100644 index 0000000..88cebfc --- /dev/null +++ 
b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=8-blocksize_q=64-blocksize_k=32-mode=bwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,38.834496,38.491646,1.431602 +2048.000000,37.678078,38.115837,2.296343 +4096.000000,34.809341,35.026432,7.842674 +6144.000000,29.389824,30.330879,17.783072 +8192.000000,22.793728,22.497280,31.165781 diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=8-blocksize_q=64-blocksize_k=32-mode=bwd.png b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=8-blocksize_q=64-blocksize_k=32-mode=bwd.png new file mode 100644 index 0000000..d2e58f5 Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=8-blocksize_q=64-blocksize_k=32-mode=bwd.png differ diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=8-blocksize_q=64-blocksize_k=32-mode=fwd.csv b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=8-blocksize_q=64-blocksize_k=32-mode=fwd.csv new file mode 100644 index 0000000..a5feb32 --- /dev/null +++ b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=8-blocksize_q=64-blocksize_k=32-mode=fwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,0.309407,0.307079,0.284503 +2048.000000,0.280597,0.260318,0.900728 +4096.000000,0.629191,0.629461,3.296050 +6144.000000,0.761119,0.756234,7.309856 +8192.000000,0.996835,0.847737,13.030063 diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=8-blocksize_q=64-blocksize_k=32-mode=fwd.png b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=8-blocksize_q=64-blocksize_k=32-mode=fwd.png new file mode 100644 index 0000000..660acca Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=8-blocksize_q=64-blocksize_k=32-mode=fwd.png differ diff --git 
a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=8-blocksize_q=64-blocksize_k=64-mode=bwd.csv b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=8-blocksize_q=64-blocksize_k=64-mode=bwd.csv new file mode 100644 index 0000000..cab097f --- /dev/null +++ b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=8-blocksize_q=64-blocksize_k=64-mode=bwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,38.587551,39.071232,1.312983 +2048.000000,38.632446,38.906879,2.226239 +4096.000000,36.688385,37.032448,8.052073 +6144.000000,34.950657,34.782722,17.811047 +8192.000000,30.153728,30.028288,31.120245 diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=8-blocksize_q=64-blocksize_k=64-mode=bwd.png b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=8-blocksize_q=64-blocksize_k=64-mode=bwd.png new file mode 100644 index 0000000..e3b770c Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=8-blocksize_q=64-blocksize_k=64-mode=bwd.png differ diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=8-blocksize_q=64-blocksize_k=64-mode=fwd.csv b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=8-blocksize_q=64-blocksize_k=64-mode=fwd.csv new file mode 100644 index 0000000..fc68892 --- /dev/null +++ b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=8-blocksize_q=64-blocksize_k=64-mode=fwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,0.336735,0.328934,0.301195 +2048.000000,0.379163,0.352202,0.882639 +4096.000000,0.562122,0.576559,3.266600 +6144.000000,0.634808,0.622716,7.243106 +8192.000000,0.846580,0.744171,12.821294 diff --git a/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=8-blocksize_q=64-blocksize_k=64-mode=fwd.png 
b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=8-blocksize_q=64-blocksize_k=64-mode=fwd.png new file mode 100644 index 0000000..9e00b62 Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=False-headdim=64-num_heads=8-blocksize_q=64-blocksize_k=64-mode=fwd.png differ diff --git a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=128-num_heads=16-blocksize_q=128-blocksize_k=128-mode=fwd.csv b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=128-num_heads=16-blocksize_q=128-blocksize_k=128-mode=fwd.csv new file mode 100644 index 0000000..4726df9 --- /dev/null +++ b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=128-num_heads=16-blocksize_q=128-blocksize_k=128-mode=fwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,3.433984,3.445369,0.605748 +2048.000000,4.275917,4.283699,2.361694 +4096.000000,1.285632,1.205760,8.792658 +6144.000000,0.004096,0.002048,19.620121 +8192.000000,0.003072,0.003072,34.009216 diff --git a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=128-num_heads=16-blocksize_q=128-blocksize_k=128-mode=fwd.png b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=128-num_heads=16-blocksize_q=128-blocksize_k=128-mode=fwd.png new file mode 100644 index 0000000..c19422d Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=128-num_heads=16-blocksize_q=128-blocksize_k=128-mode=fwd.png differ diff --git a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=128-num_heads=16-blocksize_q=128-blocksize_k=32-mode=fwd.csv b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=128-num_heads=16-blocksize_q=128-blocksize_k=32-mode=fwd.csv new file mode 100644 index 0000000..705f01a --- /dev/null +++ b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=128-num_heads=16-blocksize_q=128-blocksize_k=32-mode=fwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,0.535561,0.492613,0.624464 +2048.000000,1.897961,1.911615,2.326837 
+4096.000000,5.809407,5.330261,8.732139 +6144.000000,10.589184,10.575407,19.440289 +8192.000000,13.273089,15.072597,34.166271 diff --git a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=128-num_heads=16-blocksize_q=128-blocksize_k=32-mode=fwd.png b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=128-num_heads=16-blocksize_q=128-blocksize_k=32-mode=fwd.png new file mode 100644 index 0000000..6d28b08 Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=128-num_heads=16-blocksize_q=128-blocksize_k=32-mode=fwd.png differ diff --git a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=128-num_heads=16-blocksize_q=128-blocksize_k=64-mode=fwd.csv b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=128-num_heads=16-blocksize_q=128-blocksize_k=64-mode=fwd.csv new file mode 100644 index 0000000..609da8b --- /dev/null +++ b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=128-num_heads=16-blocksize_q=128-blocksize_k=64-mode=fwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,3.847191,3.821503,0.613257 +2048.000000,7.239095,7.185847,2.246083 +4096.000000,4.123648,4.171435,8.773578 +6144.000000,0.003072,0.003072,19.379322 +8192.000000,0.004096,0.003072,34.126064 diff --git a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=128-num_heads=16-blocksize_q=128-blocksize_k=64-mode=fwd.png b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=128-num_heads=16-blocksize_q=128-blocksize_k=64-mode=fwd.png new file mode 100644 index 0000000..f7221e9 Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=128-num_heads=16-blocksize_q=128-blocksize_k=64-mode=fwd.png differ diff --git a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=128-num_heads=16-blocksize_q=32-blocksize_k=128-mode=fwd.csv b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=128-num_heads=16-blocksize_q=32-blocksize_k=128-mode=fwd.csv new file mode 100644 index 0000000..d7a7c90 --- /dev/null +++ 
b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=128-num_heads=16-blocksize_q=32-blocksize_k=128-mode=fwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,0.482413,0.510899,0.601046 +2048.000000,1.535955,1.643900,2.334584 +4096.000000,4.950063,4.959139,8.614684 +6144.000000,10.176513,9.558295,19.547359 +8192.000000,14.409898,12.245503,34.743454 diff --git a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=128-num_heads=16-blocksize_q=32-blocksize_k=128-mode=fwd.png b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=128-num_heads=16-blocksize_q=32-blocksize_k=128-mode=fwd.png new file mode 100644 index 0000000..9ac5e12 Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=128-num_heads=16-blocksize_q=32-blocksize_k=128-mode=fwd.png differ diff --git a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=128-num_heads=16-blocksize_q=32-blocksize_k=32-mode=fwd.csv b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=128-num_heads=16-blocksize_q=32-blocksize_k=32-mode=fwd.csv new file mode 100644 index 0000000..53cc65f --- /dev/null +++ b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=128-num_heads=16-blocksize_q=32-blocksize_k=32-mode=fwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,0.526124,0.565341,0.610957 +2048.000000,1.616399,1.590910,2.278205 +4096.000000,4.975928,4.884836,8.749184 +6144.000000,9.679872,9.752018,19.437422 +8192.000000,12.499969,16.855894,34.408958 diff --git a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=128-num_heads=16-blocksize_q=32-blocksize_k=32-mode=fwd.png b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=128-num_heads=16-blocksize_q=32-blocksize_k=32-mode=fwd.png new file mode 100644 index 0000000..579ac06 Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=128-num_heads=16-blocksize_q=32-blocksize_k=32-mode=fwd.png differ diff --git 
a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=128-num_heads=16-blocksize_q=32-blocksize_k=64-mode=fwd.csv b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=128-num_heads=16-blocksize_q=32-blocksize_k=64-mode=fwd.csv new file mode 100644 index 0000000..7800b0a --- /dev/null +++ b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=128-num_heads=16-blocksize_q=32-blocksize_k=64-mode=fwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,0.506155,0.520014,0.627987 +2048.000000,1.577653,1.612652,2.300623 +4096.000000,5.221132,5.401356,8.793837 +6144.000000,9.650278,9.766809,19.467955 +8192.000000,12.395862,14.794751,34.451454 diff --git a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=128-num_heads=16-blocksize_q=32-blocksize_k=64-mode=fwd.png b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=128-num_heads=16-blocksize_q=32-blocksize_k=64-mode=fwd.png new file mode 100644 index 0000000..540ad68 Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=128-num_heads=16-blocksize_q=32-blocksize_k=64-mode=fwd.png differ diff --git a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=128-num_heads=16-blocksize_q=64-blocksize_k=128-mode=fwd.csv b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=128-num_heads=16-blocksize_q=64-blocksize_k=128-mode=fwd.csv new file mode 100644 index 0000000..494544b --- /dev/null +++ b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=128-num_heads=16-blocksize_q=64-blocksize_k=128-mode=fwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,2.311573,2.272879,0.617300 +2048.000000,3.052764,3.055909,2.232784 +4096.000000,3.858091,3.961173,8.664483 +6144.000000,0.003072,0.003072,19.400097 +8192.000000,0.003072,0.004096,34.452095 diff --git a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=128-num_heads=16-blocksize_q=64-blocksize_k=128-mode=fwd.png 
b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=128-num_heads=16-blocksize_q=64-blocksize_k=128-mode=fwd.png new file mode 100644 index 0000000..26721e6 Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=128-num_heads=16-blocksize_q=64-blocksize_k=128-mode=fwd.png differ diff --git a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=128-num_heads=16-blocksize_q=64-blocksize_k=32-mode=fwd.csv b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=128-num_heads=16-blocksize_q=64-blocksize_k=32-mode=fwd.csv new file mode 100644 index 0000000..4237e76 --- /dev/null +++ b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=128-num_heads=16-blocksize_q=64-blocksize_k=32-mode=fwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,0.558362,0.584712,0.617515 +2048.000000,1.729792,1.829189,2.323383 +4096.000000,5.447936,5.322167,8.762418 +6144.000000,9.970893,9.932799,19.534885 +8192.000000,12.710229,17.453741,34.451454 diff --git a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=128-num_heads=16-blocksize_q=64-blocksize_k=32-mode=fwd.png b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=128-num_heads=16-blocksize_q=64-blocksize_k=32-mode=fwd.png new file mode 100644 index 0000000..d611fd4 Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=128-num_heads=16-blocksize_q=64-blocksize_k=32-mode=fwd.png differ diff --git a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=128-num_heads=16-blocksize_q=64-blocksize_k=64-mode=fwd.csv b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=128-num_heads=16-blocksize_q=64-blocksize_k=64-mode=fwd.csv new file mode 100644 index 0000000..203d06e --- /dev/null +++ b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=128-num_heads=16-blocksize_q=64-blocksize_k=64-mode=fwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,0.510852,0.559451,0.594321 +2048.000000,1.665068,1.738157,2.323330 
+4096.000000,5.622226,5.065240,8.702442 +6144.000000,9.805172,9.890629,19.576332 +8192.000000,12.501845,14.445057,34.137600 diff --git a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=128-num_heads=16-blocksize_q=64-blocksize_k=64-mode=fwd.png b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=128-num_heads=16-blocksize_q=64-blocksize_k=64-mode=fwd.png new file mode 100644 index 0000000..3b637c3 Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=128-num_heads=16-blocksize_q=64-blocksize_k=64-mode=fwd.png differ diff --git a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=128-num_heads=32-blocksize_q=128-blocksize_k=128-mode=fwd.csv b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=128-num_heads=32-blocksize_q=128-blocksize_k=128-mode=fwd.csv new file mode 100644 index 0000000..a2bb17e --- /dev/null +++ b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=128-num_heads=32-blocksize_q=128-blocksize_k=128-mode=fwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,2.753764,2.839609,1.106562 +2048.000000,4.863181,4.897792,4.370005 +4096.000000,0.003072,0.004096,16.703783 +6144.000000,0.003072,0.003072,38.594063 +8192.000000,0.003072,0.004096,67.849213 diff --git a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=128-num_heads=32-blocksize_q=128-blocksize_k=128-mode=fwd.png b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=128-num_heads=32-blocksize_q=128-blocksize_k=128-mode=fwd.png new file mode 100644 index 0000000..67cc04c Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=128-num_heads=32-blocksize_q=128-blocksize_k=128-mode=fwd.png differ diff --git a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=128-num_heads=32-blocksize_q=128-blocksize_k=32-mode=fwd.csv b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=128-num_heads=32-blocksize_q=128-blocksize_k=32-mode=fwd.csv new file mode 100644 index 0000000..19edf6a --- /dev/null +++ 
b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=128-num_heads=32-blocksize_q=128-blocksize_k=32-mode=fwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,1.025221,1.034738,1.094697 +2048.000000,2.964429,2.840934,4.493396 +4096.000000,13.402670,11.080611,17.136703 +6144.000000,18.417051,18.452890,38.565887 +8192.000000,22.925995,22.885717,67.944450 diff --git a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=128-num_heads=32-blocksize_q=128-blocksize_k=32-mode=fwd.png b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=128-num_heads=32-blocksize_q=128-blocksize_k=32-mode=fwd.png new file mode 100644 index 0000000..4b7b89d Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=128-num_heads=32-blocksize_q=128-blocksize_k=32-mode=fwd.png differ diff --git a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=128-num_heads=32-blocksize_q=128-blocksize_k=64-mode=fwd.csv b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=128-num_heads=32-blocksize_q=128-blocksize_k=64-mode=fwd.csv new file mode 100644 index 0000000..bed2809 --- /dev/null +++ b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=128-num_heads=32-blocksize_q=128-blocksize_k=64-mode=fwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,3.742720,3.900340,1.085318 +2048.000000,4.260864,4.048018,4.566476 +4096.000000,0.003072,0.003072,17.301767 +6144.000000,0.003072,0.003072,38.655487 +8192.000000,0.003072,0.003072,67.994911 diff --git a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=128-num_heads=32-blocksize_q=128-blocksize_k=64-mode=fwd.png b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=128-num_heads=32-blocksize_q=128-blocksize_k=64-mode=fwd.png new file mode 100644 index 0000000..80e063a Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=128-num_heads=32-blocksize_q=128-blocksize_k=64-mode=fwd.png differ diff --git 
a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=128-num_heads=32-blocksize_q=32-blocksize_k=128-mode=fwd.csv b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=128-num_heads=32-blocksize_q=32-blocksize_k=128-mode=fwd.csv new file mode 100644 index 0000000..489ba71 --- /dev/null +++ b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=128-num_heads=32-blocksize_q=32-blocksize_k=128-mode=fwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,0.957045,0.994006,1.112871 +2048.000000,2.705583,2.577587,4.502679 +4096.000000,10.638615,9.475910,17.381037 +6144.000000,21.745459,21.520794,38.934013 +8192.000000,23.737686,23.342079,68.062210 diff --git a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=128-num_heads=32-blocksize_q=32-blocksize_k=128-mode=fwd.png b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=128-num_heads=32-blocksize_q=32-blocksize_k=128-mode=fwd.png new file mode 100644 index 0000000..d632103 Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=128-num_heads=32-blocksize_q=32-blocksize_k=128-mode=fwd.png differ diff --git a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=128-num_heads=32-blocksize_q=32-blocksize_k=32-mode=fwd.csv b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=128-num_heads=32-blocksize_q=32-blocksize_k=32-mode=fwd.csv new file mode 100644 index 0000000..474b694 --- /dev/null +++ b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=128-num_heads=32-blocksize_q=32-blocksize_k=32-mode=fwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,0.968704,0.973122,1.091584 +2048.000000,2.630451,2.661146,4.439019 +4096.000000,12.460713,8.881322,17.302126 +6144.000000,21.493555,21.457306,38.924850 +8192.000000,23.603882,23.715498,68.141052 diff --git a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=128-num_heads=32-blocksize_q=32-blocksize_k=32-mode=fwd.png 
b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=128-num_heads=32-blocksize_q=32-blocksize_k=32-mode=fwd.png new file mode 100644 index 0000000..4e8e4c2 Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=128-num_heads=32-blocksize_q=32-blocksize_k=32-mode=fwd.png differ diff --git a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=128-num_heads=32-blocksize_q=32-blocksize_k=64-mode=fwd.csv b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=128-num_heads=32-blocksize_q=32-blocksize_k=64-mode=fwd.csv new file mode 100644 index 0000000..b8e9c96 --- /dev/null +++ b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=128-num_heads=32-blocksize_q=32-blocksize_k=64-mode=fwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,0.911812,0.910364,1.089223 +2048.000000,2.741373,2.567629,4.523314 +4096.000000,12.713426,10.677898,17.298477 +6144.000000,22.441986,22.220596,38.669823 +8192.000000,24.245932,24.235689,68.697090 diff --git a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=128-num_heads=32-blocksize_q=32-blocksize_k=64-mode=fwd.png b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=128-num_heads=32-blocksize_q=32-blocksize_k=64-mode=fwd.png new file mode 100644 index 0000000..ab59446 Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=128-num_heads=32-blocksize_q=32-blocksize_k=64-mode=fwd.png differ diff --git a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=128-num_heads=32-blocksize_q=64-blocksize_k=128-mode=fwd.csv b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=128-num_heads=32-blocksize_q=64-blocksize_k=128-mode=fwd.csv new file mode 100644 index 0000000..4d64466 --- /dev/null +++ b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=128-num_heads=32-blocksize_q=64-blocksize_k=128-mode=fwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,1.777743,1.969821,1.083130 +2048.000000,3.676453,4.038948,4.553192 
+4096.000000,1.687552,1.638912,17.356241 +6144.000000,0.003072,0.003072,35.399170 +8192.000000,0.003072,0.003072,68.041725 diff --git a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=128-num_heads=32-blocksize_q=64-blocksize_k=128-mode=fwd.png b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=128-num_heads=32-blocksize_q=64-blocksize_k=128-mode=fwd.png new file mode 100644 index 0000000..051c1b6 Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=128-num_heads=32-blocksize_q=64-blocksize_k=128-mode=fwd.png differ diff --git a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=128-num_heads=32-blocksize_q=64-blocksize_k=32-mode=fwd.csv b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=128-num_heads=32-blocksize_q=64-blocksize_k=32-mode=fwd.csv new file mode 100644 index 0000000..e7b170c --- /dev/null +++ b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=128-num_heads=32-blocksize_q=64-blocksize_k=32-mode=fwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,0.998172,0.997159,1.083440 +2048.000000,2.897211,2.752906,4.578744 +4096.000000,9.656879,14.192175,17.401037 +6144.000000,22.697779,19.473202,38.521347 +8192.000000,24.537771,24.734037,68.796417 diff --git a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=128-num_heads=32-blocksize_q=64-blocksize_k=32-mode=fwd.png b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=128-num_heads=32-blocksize_q=64-blocksize_k=32-mode=fwd.png new file mode 100644 index 0000000..edc4460 Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=128-num_heads=32-blocksize_q=64-blocksize_k=32-mode=fwd.png differ diff --git a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=128-num_heads=32-blocksize_q=64-blocksize_k=64-mode=fwd.csv b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=128-num_heads=32-blocksize_q=64-blocksize_k=64-mode=fwd.csv new file mode 100644 index 0000000..df80955 --- /dev/null +++ 
b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=128-num_heads=32-blocksize_q=64-blocksize_k=64-mode=fwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,0.951505,0.948169,1.080559 +2048.000000,2.976518,2.638799,4.534225 +4096.000000,10.240512,9.564502,17.289421 +6144.000000,22.150658,20.432896,38.984238 +8192.000000,22.020781,22.278145,68.149246 diff --git a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=128-num_heads=32-blocksize_q=64-blocksize_k=64-mode=fwd.png b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=128-num_heads=32-blocksize_q=64-blocksize_k=64-mode=fwd.png new file mode 100644 index 0000000..bf2cbc3 Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=128-num_heads=32-blocksize_q=64-blocksize_k=64-mode=fwd.png differ diff --git a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=128-num_heads=8-blocksize_q=128-blocksize_k=128-mode=fwd.csv b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=128-num_heads=8-blocksize_q=128-blocksize_k=128-mode=fwd.csv new file mode 100644 index 0000000..53457d7 --- /dev/null +++ b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=128-num_heads=8-blocksize_q=128-blocksize_k=128-mode=fwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,3.040020,2.974981,0.348574 +2048.000000,6.084949,5.996487,1.167124 +4096.000000,5.499700,5.457920,4.349885 +6144.000000,1.137664,1.137664,9.901888 +8192.000000,0.003072,0.003072,17.248333 diff --git a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=128-num_heads=8-blocksize_q=128-blocksize_k=128-mode=fwd.png b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=128-num_heads=8-blocksize_q=128-blocksize_k=128-mode=fwd.png new file mode 100644 index 0000000..80b4742 Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=128-num_heads=8-blocksize_q=128-blocksize_k=128-mode=fwd.png differ diff --git 
a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=128-num_heads=8-blocksize_q=128-blocksize_k=32-mode=fwd.csv b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=128-num_heads=8-blocksize_q=128-blocksize_k=32-mode=fwd.csv new file mode 100644 index 0000000..85c6686 --- /dev/null +++ b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=128-num_heads=8-blocksize_q=128-blocksize_k=32-mode=fwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,0.560635,0.555188,0.364574 +2048.000000,1.076938,1.107759,1.200377 +4096.000000,3.644748,3.445171,4.462516 +6144.000000,6.466183,6.513503,9.924298 +8192.000000,10.267090,10.228271,17.215923 diff --git a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=128-num_heads=8-blocksize_q=128-blocksize_k=32-mode=fwd.png b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=128-num_heads=8-blocksize_q=128-blocksize_k=32-mode=fwd.png new file mode 100644 index 0000000..8350381 Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=128-num_heads=8-blocksize_q=128-blocksize_k=32-mode=fwd.png differ diff --git a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=128-num_heads=8-blocksize_q=128-blocksize_k=64-mode=fwd.csv b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=128-num_heads=8-blocksize_q=128-blocksize_k=64-mode=fwd.csv new file mode 100644 index 0000000..761e591 --- /dev/null +++ b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=128-num_heads=8-blocksize_q=128-blocksize_k=64-mode=fwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,2.581756,2.529858,0.346759 +2048.000000,7.672044,7.718761,1.186732 +4096.000000,10.138625,10.072649,4.452056 +6144.000000,7.129429,7.067989,9.868782 +8192.000000,0.004096,0.003072,17.150057 diff --git a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=128-num_heads=8-blocksize_q=128-blocksize_k=64-mode=fwd.png 
b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=128-num_heads=8-blocksize_q=128-blocksize_k=64-mode=fwd.png new file mode 100644 index 0000000..61f97b1 Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=128-num_heads=8-blocksize_q=128-blocksize_k=64-mode=fwd.png differ diff --git a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=128-num_heads=8-blocksize_q=32-blocksize_k=128-mode=fwd.csv b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=128-num_heads=8-blocksize_q=32-blocksize_k=128-mode=fwd.csv new file mode 100644 index 0000000..efb533d --- /dev/null +++ b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=128-num_heads=8-blocksize_q=32-blocksize_k=128-mode=fwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,0.512348,0.510214,0.358040 +2048.000000,1.105719,1.136913,1.194837 +4096.000000,2.942582,2.856186,4.486412 +6144.000000,5.508301,5.503437,9.821679 +8192.000000,9.229228,9.370795,17.296486 diff --git a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=128-num_heads=8-blocksize_q=32-blocksize_k=128-mode=fwd.png b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=128-num_heads=8-blocksize_q=32-blocksize_k=128-mode=fwd.png new file mode 100644 index 0000000..5733bf4 Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=128-num_heads=8-blocksize_q=32-blocksize_k=128-mode=fwd.png differ diff --git a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=128-num_heads=8-blocksize_q=32-blocksize_k=32-mode=fwd.csv b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=128-num_heads=8-blocksize_q=32-blocksize_k=32-mode=fwd.csv new file mode 100644 index 0000000..814bc11 --- /dev/null +++ b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=128-num_heads=8-blocksize_q=32-blocksize_k=32-mode=fwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,0.599912,0.534570,0.347691 +2048.000000,1.093824,1.108057,1.186531 
+4096.000000,2.889011,2.905771,4.465289 +6144.000000,5.688320,5.695539,9.976750 +8192.000000,9.623296,9.651883,17.107878 diff --git a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=128-num_heads=8-blocksize_q=32-blocksize_k=32-mode=fwd.png b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=128-num_heads=8-blocksize_q=32-blocksize_k=32-mode=fwd.png new file mode 100644 index 0000000..6dd6253 Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=128-num_heads=8-blocksize_q=32-blocksize_k=32-mode=fwd.png differ diff --git a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=128-num_heads=8-blocksize_q=32-blocksize_k=64-mode=fwd.csv b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=128-num_heads=8-blocksize_q=32-blocksize_k=64-mode=fwd.csv new file mode 100644 index 0000000..1e14dde --- /dev/null +++ b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=128-num_heads=8-blocksize_q=32-blocksize_k=64-mode=fwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,0.504931,0.514386,0.364398 +2048.000000,1.116275,1.143132,1.181261 +4096.000000,3.063310,3.014886,4.496333 +6144.000000,5.727286,5.698721,9.904957 +8192.000000,9.104850,9.075433,17.331881 diff --git a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=128-num_heads=8-blocksize_q=32-blocksize_k=64-mode=fwd.png b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=128-num_heads=8-blocksize_q=32-blocksize_k=64-mode=fwd.png new file mode 100644 index 0000000..7188059 Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=128-num_heads=8-blocksize_q=32-blocksize_k=64-mode=fwd.png differ diff --git a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=128-num_heads=8-blocksize_q=64-blocksize_k=128-mode=fwd.csv b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=128-num_heads=8-blocksize_q=64-blocksize_k=128-mode=fwd.csv new file mode 100644 index 0000000..ba9ec37 --- /dev/null +++ 
b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=128-num_heads=8-blocksize_q=64-blocksize_k=128-mode=fwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,1.931442,1.900585,0.360111 +2048.000000,4.193792,4.176502,1.240034 +4096.000000,4.465518,4.435675,4.415443 +6144.000000,5.904384,5.866837,9.832171 +8192.000000,0.003072,0.003072,17.251333 diff --git a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=128-num_heads=8-blocksize_q=64-blocksize_k=128-mode=fwd.png b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=128-num_heads=8-blocksize_q=64-blocksize_k=128-mode=fwd.png new file mode 100644 index 0000000..803e24f Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=128-num_heads=8-blocksize_q=64-blocksize_k=128-mode=fwd.png differ diff --git a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=128-num_heads=8-blocksize_q=64-blocksize_k=32-mode=fwd.csv b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=128-num_heads=8-blocksize_q=64-blocksize_k=32-mode=fwd.csv new file mode 100644 index 0000000..2def99d --- /dev/null +++ b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=128-num_heads=8-blocksize_q=64-blocksize_k=32-mode=fwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,0.541245,0.529996,0.354656 +2048.000000,1.152850,1.207438,1.194703 +4096.000000,3.333174,3.271054,4.501976 +6144.000000,5.994927,6.032437,9.939907 +8192.000000,9.582685,9.535302,17.098560 diff --git a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=128-num_heads=8-blocksize_q=64-blocksize_k=32-mode=fwd.png b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=128-num_heads=8-blocksize_q=64-blocksize_k=32-mode=fwd.png new file mode 100644 index 0000000..9a3f02c Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=128-num_heads=8-blocksize_q=64-blocksize_k=32-mode=fwd.png differ diff --git 
a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=128-num_heads=8-blocksize_q=64-blocksize_k=64-mode=fwd.csv b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=128-num_heads=8-blocksize_q=64-blocksize_k=64-mode=fwd.csv new file mode 100644 index 0000000..cf35a41 --- /dev/null +++ b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=128-num_heads=8-blocksize_q=64-blocksize_k=64-mode=fwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,0.525824,0.522889,0.355479 +2048.000000,1.100346,1.099401,1.181028 +4096.000000,3.237861,3.078695,4.526319 +6144.000000,5.759074,5.770679,9.938994 +8192.000000,9.735083,9.602986,17.277338 diff --git a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=128-num_heads=8-blocksize_q=64-blocksize_k=64-mode=fwd.png b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=128-num_heads=8-blocksize_q=64-blocksize_k=64-mode=fwd.png new file mode 100644 index 0000000..56d55eb Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=128-num_heads=8-blocksize_q=64-blocksize_k=64-mode=fwd.png differ diff --git a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=256-num_heads=16-blocksize_q=128-blocksize_k=128-mode=fwd.csv b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=256-num_heads=16-blocksize_q=128-blocksize_k=128-mode=fwd.csv new file mode 100644 index 0000000..7d8fc7c --- /dev/null +++ b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=256-num_heads=16-blocksize_q=128-blocksize_k=128-mode=fwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,100.000000,100.000000,0.691624 +2048.000000,100.000000,100.000000,2.722255 +4096.000000,100.000000,100.000000,10.320362 +6144.000000,100.000000,100.000000,23.130680 +8192.000000,100.000000,100.000000,40.658432 diff --git a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=256-num_heads=16-blocksize_q=128-blocksize_k=128-mode=fwd.png 
b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=256-num_heads=16-blocksize_q=128-blocksize_k=128-mode=fwd.png new file mode 100644 index 0000000..a8b43ca Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=256-num_heads=16-blocksize_q=128-blocksize_k=128-mode=fwd.png differ diff --git a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=256-num_heads=16-blocksize_q=128-blocksize_k=32-mode=fwd.csv b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=256-num_heads=16-blocksize_q=128-blocksize_k=32-mode=fwd.csv new file mode 100644 index 0000000..8c70e17 --- /dev/null +++ b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=256-num_heads=16-blocksize_q=128-blocksize_k=32-mode=fwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,4.876773,4.929485,0.698022 +2048.000000,3.847987,3.750912,2.704280 +4096.000000,0.488448,0.453632,10.293716 +6144.000000,0.003072,0.003072,23.290936 +8192.000000,0.003072,0.003072,40.944687 diff --git a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=256-num_heads=16-blocksize_q=128-blocksize_k=32-mode=fwd.png b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=256-num_heads=16-blocksize_q=128-blocksize_k=32-mode=fwd.png new file mode 100644 index 0000000..d1e8a0c Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=256-num_heads=16-blocksize_q=128-blocksize_k=32-mode=fwd.png differ diff --git a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=256-num_heads=16-blocksize_q=128-blocksize_k=64-mode=fwd.csv b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=256-num_heads=16-blocksize_q=128-blocksize_k=64-mode=fwd.csv new file mode 100644 index 0000000..6e89b41 --- /dev/null +++ b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=256-num_heads=16-blocksize_q=128-blocksize_k=64-mode=fwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,4.473233,4.461936,0.704037 +2048.000000,4.178651,4.144274,2.670618 
+4096.000000,0.315392,0.572416,10.311595 +6144.000000,0.003072,0.003072,23.181889 +8192.000000,0.004096,0.003072,40.328911 diff --git a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=256-num_heads=16-blocksize_q=128-blocksize_k=64-mode=fwd.png b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=256-num_heads=16-blocksize_q=128-blocksize_k=64-mode=fwd.png new file mode 100644 index 0000000..f78a9b8 Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=256-num_heads=16-blocksize_q=128-blocksize_k=64-mode=fwd.png differ diff --git a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=256-num_heads=16-blocksize_q=32-blocksize_k=128-mode=fwd.csv b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=256-num_heads=16-blocksize_q=32-blocksize_k=128-mode=fwd.csv new file mode 100644 index 0000000..8ef3d36 --- /dev/null +++ b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=256-num_heads=16-blocksize_q=32-blocksize_k=128-mode=fwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,0.860077,0.876893,0.714357 +2048.000000,1.887624,1.907285,2.703474 +4096.000000,10.163199,12.670791,10.478190 +6144.000000,26.590210,26.660456,23.091064 +8192.000000,3.069952,3.113472,40.281601 diff --git a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=256-num_heads=16-blocksize_q=32-blocksize_k=128-mode=fwd.png b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=256-num_heads=16-blocksize_q=32-blocksize_k=128-mode=fwd.png new file mode 100644 index 0000000..038a3af Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=256-num_heads=16-blocksize_q=32-blocksize_k=128-mode=fwd.png differ diff --git a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=256-num_heads=16-blocksize_q=32-blocksize_k=32-mode=fwd.csv b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=256-num_heads=16-blocksize_q=32-blocksize_k=32-mode=fwd.csv new file mode 100644 index 0000000..924c569 --- /dev/null +++ 
b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=256-num_heads=16-blocksize_q=32-blocksize_k=32-mode=fwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,0.807094,0.770943,0.690388 +2048.000000,1.717210,1.670945,2.716416 +4096.000000,10.335172,10.351857,10.006304 +6144.000000,21.918976,20.243073,23.192335 +8192.000000,22.811649,22.941185,40.975967 diff --git a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=256-num_heads=16-blocksize_q=32-blocksize_k=32-mode=fwd.png b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=256-num_heads=16-blocksize_q=32-blocksize_k=32-mode=fwd.png new file mode 100644 index 0000000..1e3bc8e Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=256-num_heads=16-blocksize_q=32-blocksize_k=32-mode=fwd.png differ diff --git a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=256-num_heads=16-blocksize_q=32-blocksize_k=64-mode=fwd.csv b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=256-num_heads=16-blocksize_q=32-blocksize_k=64-mode=fwd.csv new file mode 100644 index 0000000..3eaa0cd --- /dev/null +++ b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=256-num_heads=16-blocksize_q=32-blocksize_k=64-mode=fwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,0.837990,0.825227,0.698040 +2048.000000,1.764505,1.920213,2.667003 +4096.000000,12.916299,9.438209,10.284634 +6144.000000,21.046955,21.005653,23.195648 +8192.000000,19.953665,28.444927,40.342575 diff --git a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=256-num_heads=16-blocksize_q=32-blocksize_k=64-mode=fwd.png b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=256-num_heads=16-blocksize_q=32-blocksize_k=64-mode=fwd.png new file mode 100644 index 0000000..ecdce3e Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=256-num_heads=16-blocksize_q=32-blocksize_k=64-mode=fwd.png differ diff --git 
a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=256-num_heads=16-blocksize_q=64-blocksize_k=128-mode=fwd.csv b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=256-num_heads=16-blocksize_q=64-blocksize_k=128-mode=fwd.csv new file mode 100644 index 0000000..63a888e --- /dev/null +++ b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=256-num_heads=16-blocksize_q=64-blocksize_k=128-mode=fwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,2.267927,2.215847,0.700681 +2048.000000,3.376128,3.412992,2.738209 +4096.000000,0.364544,0.473088,10.226183 +6144.000000,0.003072,0.004096,23.258713 +8192.000000,0.002048,0.004096,40.428688 diff --git a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=256-num_heads=16-blocksize_q=64-blocksize_k=128-mode=fwd.png b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=256-num_heads=16-blocksize_q=64-blocksize_k=128-mode=fwd.png new file mode 100644 index 0000000..e9fd69a Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=256-num_heads=16-blocksize_q=64-blocksize_k=128-mode=fwd.png differ diff --git a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=256-num_heads=16-blocksize_q=64-blocksize_k=32-mode=fwd.csv b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=256-num_heads=16-blocksize_q=64-blocksize_k=32-mode=fwd.csv new file mode 100644 index 0000000..e6bb677 --- /dev/null +++ b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=256-num_heads=16-blocksize_q=64-blocksize_k=32-mode=fwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,0.927602,0.906539,0.709993 +2048.000000,1.966324,1.941163,2.660576 +4096.000000,14.953385,10.312547,10.175126 +6144.000000,24.183807,24.354815,23.311432 +8192.000000,22.935211,22.932821,40.468033 diff --git a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=256-num_heads=16-blocksize_q=64-blocksize_k=32-mode=fwd.png 
b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=256-num_heads=16-blocksize_q=64-blocksize_k=32-mode=fwd.png new file mode 100644 index 0000000..10824e0 Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=256-num_heads=16-blocksize_q=64-blocksize_k=32-mode=fwd.png differ diff --git a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=256-num_heads=16-blocksize_q=64-blocksize_k=64-mode=fwd.csv b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=256-num_heads=16-blocksize_q=64-blocksize_k=64-mode=fwd.csv new file mode 100644 index 0000000..8d7225c --- /dev/null +++ b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=256-num_heads=16-blocksize_q=64-blocksize_k=64-mode=fwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,1.036081,1.023680,0.714913 +2048.000000,2.136603,2.056358,2.678466 +4096.000000,13.453056,11.403092,10.167432 +6144.000000,18.498970,24.134861,23.233200 +8192.000000,24.797186,25.172651,40.913567 diff --git a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=256-num_heads=16-blocksize_q=64-blocksize_k=64-mode=fwd.png b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=256-num_heads=16-blocksize_q=64-blocksize_k=64-mode=fwd.png new file mode 100644 index 0000000..6f0c899 Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=256-num_heads=16-blocksize_q=64-blocksize_k=64-mode=fwd.png differ diff --git a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=256-num_heads=32-blocksize_q=128-blocksize_k=128-mode=fwd.csv b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=256-num_heads=32-blocksize_q=128-blocksize_k=128-mode=fwd.csv new file mode 100644 index 0000000..77a393d --- /dev/null +++ b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=256-num_heads=32-blocksize_q=128-blocksize_k=128-mode=fwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,100.000000,100.000000,1.285133 
+2048.000000,100.000000,100.000000,5.231750 +4096.000000,100.000000,100.000000,21.936249 +6144.000000,100.000000,100.000000,46.858753 +8192.000000,100.000000,100.000000,300.000000 diff --git a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=256-num_heads=32-blocksize_q=128-blocksize_k=128-mode=fwd.png b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=256-num_heads=32-blocksize_q=128-blocksize_k=128-mode=fwd.png new file mode 100644 index 0000000..093ec83 Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=256-num_heads=32-blocksize_q=128-blocksize_k=128-mode=fwd.png differ diff --git a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=256-num_heads=32-blocksize_q=128-blocksize_k=32-mode=fwd.csv b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=256-num_heads=32-blocksize_q=128-blocksize_k=32-mode=fwd.csv new file mode 100644 index 0000000..96bcf0e --- /dev/null +++ b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=256-num_heads=32-blocksize_q=128-blocksize_k=32-mode=fwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,2.857704,3.145728,1.266194 +2048.000000,0.561152,0.583168,5.180035 +4096.000000,0.002048,0.004096,20.644926 +6144.000000,0.003072,0.002048,46.417408 +8192.000000,0.003072,0.003072,80.237572 diff --git a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=256-num_heads=32-blocksize_q=128-blocksize_k=32-mode=fwd.png b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=256-num_heads=32-blocksize_q=128-blocksize_k=32-mode=fwd.png new file mode 100644 index 0000000..6ac39a0 Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=256-num_heads=32-blocksize_q=128-blocksize_k=32-mode=fwd.png differ diff --git a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=256-num_heads=32-blocksize_q=128-blocksize_k=64-mode=fwd.csv b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=256-num_heads=32-blocksize_q=128-blocksize_k=64-mode=fwd.csv new file mode 
100644 index 0000000..1cb79b8 --- /dev/null +++ b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=256-num_heads=32-blocksize_q=128-blocksize_k=64-mode=fwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,3.013710,3.198188,1.286163 +2048.000000,3.422549,21.921453,5.259470 +4096.000000,0.004096,0.002048,20.294281 +6144.000000,0.003072,0.004096,46.174721 +8192.000000,0.004096,0.004096,80.109634 diff --git a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=256-num_heads=32-blocksize_q=128-blocksize_k=64-mode=fwd.png b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=256-num_heads=32-blocksize_q=128-blocksize_k=64-mode=fwd.png new file mode 100644 index 0000000..2ff0676 Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=256-num_heads=32-blocksize_q=128-blocksize_k=64-mode=fwd.png differ diff --git a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=256-num_heads=32-blocksize_q=32-blocksize_k=128-mode=fwd.csv b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=256-num_heads=32-blocksize_q=32-blocksize_k=128-mode=fwd.csv new file mode 100644 index 0000000..be3cc52 --- /dev/null +++ b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=256-num_heads=32-blocksize_q=32-blocksize_k=128-mode=fwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,1.045358,1.053768,1.278692 +2048.000000,6.751124,7.346637,5.193031 +4096.000000,25.534668,23.934977,20.418610 +6144.000000,3.588608,3.576832,46.257294 +8192.000000,0.004096,0.003072,80.263329 diff --git a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=256-num_heads=32-blocksize_q=32-blocksize_k=128-mode=fwd.png b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=256-num_heads=32-blocksize_q=32-blocksize_k=128-mode=fwd.png new file mode 100644 index 0000000..2a8492c Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=256-num_heads=32-blocksize_q=32-blocksize_k=128-mode=fwd.png differ diff --git 
a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=256-num_heads=32-blocksize_q=32-blocksize_k=32-mode=fwd.csv b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=256-num_heads=32-blocksize_q=32-blocksize_k=32-mode=fwd.csv new file mode 100644 index 0000000..64c0c79 --- /dev/null +++ b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=256-num_heads=32-blocksize_q=32-blocksize_k=32-mode=fwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,0.993280,1.006732,1.257991 +2048.000000,5.729147,5.827188,5.154608 +4096.000000,20.154140,20.090538,20.252800 +6144.000000,25.625088,25.831936,46.000641 +8192.000000,5.995520,6.025216,80.070755 diff --git a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=256-num_heads=32-blocksize_q=32-blocksize_k=32-mode=fwd.png b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=256-num_heads=32-blocksize_q=32-blocksize_k=32-mode=fwd.png new file mode 100644 index 0000000..d3cca35 Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=256-num_heads=32-blocksize_q=32-blocksize_k=32-mode=fwd.png differ diff --git a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=256-num_heads=32-blocksize_q=32-blocksize_k=64-mode=fwd.csv b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=256-num_heads=32-blocksize_q=32-blocksize_k=64-mode=fwd.csv new file mode 100644 index 0000000..82f3c9d --- /dev/null +++ b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=256-num_heads=32-blocksize_q=32-blocksize_k=64-mode=fwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,0.946216,1.044312,1.276984 +2048.000000,6.008791,5.687414,5.095865 +4096.000000,21.311342,20.087954,20.641176 +6144.000000,22.704468,22.578516,46.187523 +8192.000000,5.995520,6.011392,79.864830 diff --git a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=256-num_heads=32-blocksize_q=32-blocksize_k=64-mode=fwd.png 
b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=256-num_heads=32-blocksize_q=32-blocksize_k=64-mode=fwd.png new file mode 100644 index 0000000..3522a65 Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=256-num_heads=32-blocksize_q=32-blocksize_k=64-mode=fwd.png differ diff --git a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=256-num_heads=32-blocksize_q=64-blocksize_k=128-mode=fwd.csv b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=256-num_heads=32-blocksize_q=64-blocksize_k=128-mode=fwd.csv new file mode 100644 index 0000000..bcbb2f1 --- /dev/null +++ b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=256-num_heads=32-blocksize_q=64-blocksize_k=128-mode=fwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,2.667362,2.823798,1.283813 +2048.000000,23.664640,2.624853,5.328611 +4096.000000,0.003072,0.003072,20.488480 +6144.000000,0.003072,0.003072,46.539776 +8192.000000,0.004096,0.003072,80.274429 diff --git a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=256-num_heads=32-blocksize_q=64-blocksize_k=128-mode=fwd.png b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=256-num_heads=32-blocksize_q=64-blocksize_k=128-mode=fwd.png new file mode 100644 index 0000000..8fd41fa Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=256-num_heads=32-blocksize_q=64-blocksize_k=128-mode=fwd.png differ diff --git a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=256-num_heads=32-blocksize_q=64-blocksize_k=32-mode=fwd.csv b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=256-num_heads=32-blocksize_q=64-blocksize_k=32-mode=fwd.csv new file mode 100644 index 0000000..1c7eb7f --- /dev/null +++ b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=256-num_heads=32-blocksize_q=64-blocksize_k=32-mode=fwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,1.080277,1.070151,1.282946 +2048.000000,7.777146,7.711573,5.067273 
+4096.000000,23.125845,21.848917,20.461887 +6144.000000,26.096979,26.323626,46.479874 +8192.000000,0.004096,0.004096,81.335617 diff --git a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=256-num_heads=32-blocksize_q=64-blocksize_k=32-mode=fwd.png b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=256-num_heads=32-blocksize_q=64-blocksize_k=32-mode=fwd.png new file mode 100644 index 0000000..a9156ff Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=256-num_heads=32-blocksize_q=64-blocksize_k=32-mode=fwd.png differ diff --git a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=256-num_heads=32-blocksize_q=64-blocksize_k=64-mode=fwd.csv b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=256-num_heads=32-blocksize_q=64-blocksize_k=64-mode=fwd.csv new file mode 100644 index 0000000..74bae2b --- /dev/null +++ b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=256-num_heads=32-blocksize_q=64-blocksize_k=64-mode=fwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,1.052016,1.037632,1.280033 +2048.000000,8.745107,8.687523,5.226451 +4096.000000,23.070105,24.140970,20.682049 +6144.000000,3.602432,3.573248,47.135506 +8192.000000,0.004096,0.003072,80.204926 diff --git a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=256-num_heads=32-blocksize_q=64-blocksize_k=64-mode=fwd.png b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=256-num_heads=32-blocksize_q=64-blocksize_k=64-mode=fwd.png new file mode 100644 index 0000000..304d7a8 Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=256-num_heads=32-blocksize_q=64-blocksize_k=64-mode=fwd.png differ diff --git a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=256-num_heads=8-blocksize_q=128-blocksize_k=128-mode=fwd.csv b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=256-num_heads=8-blocksize_q=128-blocksize_k=128-mode=fwd.csv new file mode 100644 index 0000000..171a04d --- /dev/null +++ 
b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=256-num_heads=8-blocksize_q=128-blocksize_k=128-mode=fwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,100.000000,100.000000,0.408813 +2048.000000,100.000000,100.000000,1.355545 +4096.000000,100.000000,100.000000,5.114824 +6144.000000,100.000000,100.000000,11.536039 +8192.000000,100.000000,100.000000,20.098846 diff --git a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=256-num_heads=8-blocksize_q=128-blocksize_k=128-mode=fwd.png b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=256-num_heads=8-blocksize_q=128-blocksize_k=128-mode=fwd.png new file mode 100644 index 0000000..e9286a5 Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=256-num_heads=8-blocksize_q=128-blocksize_k=128-mode=fwd.png differ diff --git a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=256-num_heads=8-blocksize_q=128-blocksize_k=32-mode=fwd.csv b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=256-num_heads=8-blocksize_q=128-blocksize_k=32-mode=fwd.csv new file mode 100644 index 0000000..14f9beb --- /dev/null +++ b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=256-num_heads=8-blocksize_q=128-blocksize_k=32-mode=fwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,5.189954,5.164364,0.391326 +2048.000000,7.923340,7.836021,1.347659 +4096.000000,0.453120,0.459264,5.204599 +6144.000000,0.003072,0.003072,11.610592 +8192.000000,0.003072,0.003072,20.284737 diff --git a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=256-num_heads=8-blocksize_q=128-blocksize_k=32-mode=fwd.png b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=256-num_heads=8-blocksize_q=128-blocksize_k=32-mode=fwd.png new file mode 100644 index 0000000..0a5de2c Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=256-num_heads=8-blocksize_q=128-blocksize_k=32-mode=fwd.png differ diff --git 
a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=256-num_heads=8-blocksize_q=128-blocksize_k=64-mode=fwd.csv b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=256-num_heads=8-blocksize_q=128-blocksize_k=64-mode=fwd.csv new file mode 100644 index 0000000..dfa6c10 --- /dev/null +++ b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=256-num_heads=8-blocksize_q=128-blocksize_k=64-mode=fwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,4.397952,4.326995,0.399478 +2048.000000,8.233827,8.205313,1.432819 +4096.000000,3.979947,3.919872,5.214045 +6144.000000,0.217088,0.003072,11.678644 +8192.000000,0.004096,0.003072,20.379984 diff --git a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=256-num_heads=8-blocksize_q=128-blocksize_k=64-mode=fwd.png b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=256-num_heads=8-blocksize_q=128-blocksize_k=64-mode=fwd.png new file mode 100644 index 0000000..b1e12f3 Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=256-num_heads=8-blocksize_q=128-blocksize_k=64-mode=fwd.png differ diff --git a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=256-num_heads=8-blocksize_q=32-blocksize_k=128-mode=fwd.csv b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=256-num_heads=8-blocksize_q=32-blocksize_k=128-mode=fwd.csv new file mode 100644 index 0000000..74e60ef --- /dev/null +++ b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=256-num_heads=8-blocksize_q=32-blocksize_k=128-mode=fwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,0.762267,0.750147,0.422252 +2048.000000,1.245054,1.197897,1.364698 +4096.000000,3.181824,3.165645,5.221550 +6144.000000,5.581369,5.661127,11.625259 +8192.000000,17.980007,6.820250,20.612408 diff --git a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=256-num_heads=8-blocksize_q=32-blocksize_k=128-mode=fwd.png 
b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=256-num_heads=8-blocksize_q=32-blocksize_k=128-mode=fwd.png new file mode 100644 index 0000000..86be61f Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=256-num_heads=8-blocksize_q=32-blocksize_k=128-mode=fwd.png differ diff --git a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=256-num_heads=8-blocksize_q=32-blocksize_k=32-mode=fwd.csv b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=256-num_heads=8-blocksize_q=32-blocksize_k=32-mode=fwd.csv new file mode 100644 index 0000000..1e0d55e --- /dev/null +++ b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=256-num_heads=8-blocksize_q=32-blocksize_k=32-mode=fwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,0.464293,0.439179,0.398813 +2048.000000,1.169309,1.187471,1.366740 +4096.000000,3.069688,3.095618,5.305968 +6144.000000,6.436864,6.092663,11.511468 +8192.000000,13.475043,10.831758,20.265728 diff --git a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=256-num_heads=8-blocksize_q=32-blocksize_k=32-mode=fwd.png b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=256-num_heads=8-blocksize_q=32-blocksize_k=32-mode=fwd.png new file mode 100644 index 0000000..fba8037 Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=256-num_heads=8-blocksize_q=32-blocksize_k=32-mode=fwd.png differ diff --git a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=256-num_heads=8-blocksize_q=32-blocksize_k=64-mode=fwd.csv b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=256-num_heads=8-blocksize_q=32-blocksize_k=64-mode=fwd.csv new file mode 100644 index 0000000..2c6011b --- /dev/null +++ b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=256-num_heads=8-blocksize_q=32-blocksize_k=64-mode=fwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,0.478833,0.504712,0.399730 +2048.000000,1.208812,1.233014,1.376397 
+4096.000000,3.228426,2.949992,5.285520 +6144.000000,5.887843,5.949991,11.609508 +8192.000000,12.319597,12.367434,20.368456 diff --git a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=256-num_heads=8-blocksize_q=32-blocksize_k=64-mode=fwd.png b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=256-num_heads=8-blocksize_q=32-blocksize_k=64-mode=fwd.png new file mode 100644 index 0000000..12be61e Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=256-num_heads=8-blocksize_q=32-blocksize_k=64-mode=fwd.png differ diff --git a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=256-num_heads=8-blocksize_q=64-blocksize_k=128-mode=fwd.csv b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=256-num_heads=8-blocksize_q=64-blocksize_k=128-mode=fwd.csv new file mode 100644 index 0000000..f4f94d8 --- /dev/null +++ b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=256-num_heads=8-blocksize_q=64-blocksize_k=128-mode=fwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,2.944916,2.902195,0.402763 +2048.000000,3.817472,3.812523,1.367569 +4096.000000,3.394219,3.344384,5.322654 +6144.000000,0.003072,0.116736,11.571480 +8192.000000,0.003072,0.003072,20.383240 diff --git a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=256-num_heads=8-blocksize_q=64-blocksize_k=128-mode=fwd.png b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=256-num_heads=8-blocksize_q=64-blocksize_k=128-mode=fwd.png new file mode 100644 index 0000000..ed77e4b Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=256-num_heads=8-blocksize_q=64-blocksize_k=128-mode=fwd.png differ diff --git a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=256-num_heads=8-blocksize_q=64-blocksize_k=32-mode=fwd.csv b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=256-num_heads=8-blocksize_q=64-blocksize_k=32-mode=fwd.csv new file mode 100644 index 0000000..a443792 --- /dev/null +++ 
b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=256-num_heads=8-blocksize_q=64-blocksize_k=32-mode=fwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,0.605080,0.534350,0.417140 +2048.000000,1.549294,1.525876,1.371455 +4096.000000,3.491560,3.459157,5.380148 +6144.000000,6.330181,6.282985,11.807929 +8192.000000,7.697749,11.263489,20.547647 diff --git a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=256-num_heads=8-blocksize_q=64-blocksize_k=32-mode=fwd.png b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=256-num_heads=8-blocksize_q=64-blocksize_k=32-mode=fwd.png new file mode 100644 index 0000000..363b5bf Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=256-num_heads=8-blocksize_q=64-blocksize_k=32-mode=fwd.png differ diff --git a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=256-num_heads=8-blocksize_q=64-blocksize_k=64-mode=fwd.csv b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=256-num_heads=8-blocksize_q=64-blocksize_k=64-mode=fwd.csv new file mode 100644 index 0000000..7dbf0e8 --- /dev/null +++ b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=256-num_heads=8-blocksize_q=64-blocksize_k=64-mode=fwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,0.565877,0.630195,0.416326 +2048.000000,1.450925,1.375041,1.366572 +4096.000000,3.452975,3.360628,5.296370 +6144.000000,6.097715,6.055834,11.776964 +8192.000000,7.642965,7.594326,20.600969 diff --git a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=256-num_heads=8-blocksize_q=64-blocksize_k=64-mode=fwd.png b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=256-num_heads=8-blocksize_q=64-blocksize_k=64-mode=fwd.png new file mode 100644 index 0000000..3d07ae2 Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=256-num_heads=8-blocksize_q=64-blocksize_k=64-mode=fwd.png differ diff --git 
a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=16-blocksize_q=128-blocksize_k=128-mode=bwd.csv b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=16-blocksize_q=128-blocksize_k=128-mode=bwd.csv new file mode 100644 index 0000000..4b27dca --- /dev/null +++ b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=16-blocksize_q=128-blocksize_k=128-mode=bwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,62.163967,63.096832,1.729190 +2048.000000,60.447742,59.672577,4.540411 +4096.000000,50.568192,55.342079,16.777344 +6144.000000,44.912640,42.561535,37.622787 +8192.000000,51.596287,46.752769,66.763107 diff --git a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=16-blocksize_q=128-blocksize_k=128-mode=bwd.png b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=16-blocksize_q=128-blocksize_k=128-mode=bwd.png new file mode 100644 index 0000000..7f4fef3 Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=16-blocksize_q=128-blocksize_k=128-mode=bwd.png differ diff --git a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=16-blocksize_q=128-blocksize_k=128-mode=fwd.csv b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=16-blocksize_q=128-blocksize_k=128-mode=fwd.csv new file mode 100644 index 0000000..2e6c27b --- /dev/null +++ b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=16-blocksize_q=128-blocksize_k=128-mode=fwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,3.831901,3.735138,0.585130 +2048.000000,4.679241,4.660005,2.219186 +4096.000000,6.753024,6.374912,8.323977 +6144.000000,1.835008,1.773056,18.588097 +8192.000000,0.003072,0.004096,33.348949 diff --git a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=16-blocksize_q=128-blocksize_k=128-mode=fwd.png 
b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=16-blocksize_q=128-blocksize_k=128-mode=fwd.png new file mode 100644 index 0000000..612d7d9 Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=16-blocksize_q=128-blocksize_k=128-mode=fwd.png differ diff --git a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=16-blocksize_q=128-blocksize_k=32-mode=bwd.csv b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=16-blocksize_q=128-blocksize_k=32-mode=bwd.csv new file mode 100644 index 0000000..13acc23 --- /dev/null +++ b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=16-blocksize_q=128-blocksize_k=32-mode=bwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,43.308029,43.376640,1.689746 +2048.000000,42.080769,41.960960,4.486260 +4096.000000,35.450882,36.271614,16.780582 +6144.000000,24.922112,25.027073,39.112190 +8192.000000,8.932352,8.408064,68.649216 diff --git a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=16-blocksize_q=128-blocksize_k=32-mode=bwd.png b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=16-blocksize_q=128-blocksize_k=32-mode=bwd.png new file mode 100644 index 0000000..0d1b464 Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=16-blocksize_q=128-blocksize_k=32-mode=bwd.png differ diff --git a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=16-blocksize_q=128-blocksize_k=32-mode=fwd.csv b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=16-blocksize_q=128-blocksize_k=32-mode=fwd.csv new file mode 100644 index 0000000..642077b --- /dev/null +++ b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=16-blocksize_q=128-blocksize_k=32-mode=fwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,0.613093,0.599950,0.595466 +2048.000000,1.648303,1.665210,2.244574 
+4096.000000,5.320917,5.406336,8.255507 +6144.000000,10.563584,10.496256,18.506752 +8192.000000,13.491200,13.719844,32.638046 diff --git a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=16-blocksize_q=128-blocksize_k=32-mode=fwd.png b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=16-blocksize_q=128-blocksize_k=32-mode=fwd.png new file mode 100644 index 0000000..d666d1a Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=16-blocksize_q=128-blocksize_k=32-mode=fwd.png differ diff --git a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=16-blocksize_q=128-blocksize_k=64-mode=bwd.csv b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=16-blocksize_q=128-blocksize_k=64-mode=bwd.csv new file mode 100644 index 0000000..9b4cc3d --- /dev/null +++ b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=16-blocksize_q=128-blocksize_k=64-mode=bwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,53.300224,52.809727,1.753490 +2048.000000,52.349953,52.049919,4.460530 +4096.000000,54.665215,52.781055,16.821861 +6144.000000,37.277695,38.212608,37.349678 +8192.000000,42.204159,42.647552,66.748573 diff --git a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=16-blocksize_q=128-blocksize_k=64-mode=bwd.png b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=16-blocksize_q=128-blocksize_k=64-mode=bwd.png new file mode 100644 index 0000000..50e0ffc Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=16-blocksize_q=128-blocksize_k=64-mode=bwd.png differ diff --git a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=16-blocksize_q=128-blocksize_k=64-mode=fwd.csv b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=16-blocksize_q=128-blocksize_k=64-mode=fwd.csv new file mode 100644 index 0000000..b9f37c2 --- /dev/null +++ 
b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=16-blocksize_q=128-blocksize_k=64-mode=fwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,2.330271,2.296154,0.569824 +2048.000000,3.311849,3.394607,2.218350 +4096.000000,6.228821,5.881174,8.217309 +6144.000000,6.549504,1.879552,18.626011 +8192.000000,0.003072,0.003072,32.668377 diff --git a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=16-blocksize_q=128-blocksize_k=64-mode=fwd.png b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=16-blocksize_q=128-blocksize_k=64-mode=fwd.png new file mode 100644 index 0000000..a6816f0 Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=16-blocksize_q=128-blocksize_k=64-mode=fwd.png differ diff --git a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=16-blocksize_q=32-blocksize_k=128-mode=bwd.csv b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=16-blocksize_q=32-blocksize_k=128-mode=bwd.csv new file mode 100644 index 0000000..3e08402 --- /dev/null +++ b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=16-blocksize_q=32-blocksize_k=128-mode=bwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,42.927105,42.636803,1.738877 +2048.000000,42.538498,42.691071,4.392402 +4096.000000,38.150658,38.902275,17.399866 +6144.000000,33.630722,33.540096,38.142563 +8192.000000,31.664639,32.194046,66.590851 diff --git a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=16-blocksize_q=32-blocksize_k=128-mode=bwd.png b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=16-blocksize_q=32-blocksize_k=128-mode=bwd.png new file mode 100644 index 0000000..c8db53c Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=16-blocksize_q=32-blocksize_k=128-mode=bwd.png differ diff --git 
a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=16-blocksize_q=32-blocksize_k=128-mode=fwd.csv b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=16-blocksize_q=32-blocksize_k=128-mode=fwd.csv new file mode 100644 index 0000000..4cdf998 --- /dev/null +++ b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=16-blocksize_q=32-blocksize_k=128-mode=fwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,0.504862,0.494933,0.586700 +2048.000000,1.612068,1.638999,2.234193 +4096.000000,4.973824,5.077852,8.239625 +6144.000000,9.955157,10.071211,18.509209 +8192.000000,12.850615,13.077504,32.731594 diff --git a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=16-blocksize_q=32-blocksize_k=128-mode=fwd.png b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=16-blocksize_q=32-blocksize_k=128-mode=fwd.png new file mode 100644 index 0000000..6791bc1 Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=16-blocksize_q=32-blocksize_k=128-mode=fwd.png differ diff --git a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=16-blocksize_q=32-blocksize_k=32-mode=bwd.csv b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=16-blocksize_q=32-blocksize_k=32-mode=bwd.csv new file mode 100644 index 0000000..2e2c9f9 --- /dev/null +++ b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=16-blocksize_q=32-blocksize_k=32-mode=bwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,40.612865,40.147457,1.699963 +2048.000000,37.400574,37.740028,4.427035 +4096.000000,30.267904,30.043137,16.745203 +6144.000000,29.196800,28.625919,37.960255 +8192.000000,0.002048,0.001024,67.354622 diff --git a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=16-blocksize_q=32-blocksize_k=32-mode=bwd.png 
b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=16-blocksize_q=32-blocksize_k=32-mode=bwd.png new file mode 100644 index 0000000..a193725 Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=16-blocksize_q=32-blocksize_k=32-mode=bwd.png differ diff --git a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=16-blocksize_q=32-blocksize_k=32-mode=fwd.csv b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=16-blocksize_q=32-blocksize_k=32-mode=fwd.csv new file mode 100644 index 0000000..918f273 --- /dev/null +++ b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=16-blocksize_q=32-blocksize_k=32-mode=fwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,0.536712,0.555312,0.589842 +2048.000000,1.632353,1.616723,2.174691 +4096.000000,4.805317,4.841039,8.331352 +6144.000000,10.009770,10.014464,18.563822 +8192.000000,13.030985,13.155328,32.861183 diff --git a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=16-blocksize_q=32-blocksize_k=32-mode=fwd.png b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=16-blocksize_q=32-blocksize_k=32-mode=fwd.png new file mode 100644 index 0000000..b76c114 Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=16-blocksize_q=32-blocksize_k=32-mode=fwd.png differ diff --git a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=16-blocksize_q=32-blocksize_k=64-mode=bwd.csv b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=16-blocksize_q=32-blocksize_k=64-mode=bwd.csv new file mode 100644 index 0000000..e72f0af --- /dev/null +++ b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=16-blocksize_q=32-blocksize_k=64-mode=bwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,42.444290,41.509888,1.718728 +2048.000000,40.276993,40.166912,4.515706 
+4096.000000,35.680767,35.621887,16.926937 +6144.000000,29.127682,29.601791,38.062286 +8192.000000,29.700096,29.094398,66.920418 diff --git a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=16-blocksize_q=32-blocksize_k=64-mode=bwd.png b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=16-blocksize_q=32-blocksize_k=64-mode=bwd.png new file mode 100644 index 0000000..ce42f56 Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=16-blocksize_q=32-blocksize_k=64-mode=bwd.png differ diff --git a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=16-blocksize_q=32-blocksize_k=64-mode=fwd.csv b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=16-blocksize_q=32-blocksize_k=64-mode=fwd.csv new file mode 100644 index 0000000..cf1ea45 --- /dev/null +++ b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=16-blocksize_q=32-blocksize_k=64-mode=fwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,0.668058,0.517925,0.581954 +2048.000000,1.628086,1.623106,2.243758 +4096.000000,5.070480,5.082645,8.424363 +6144.000000,10.031787,9.945003,18.806171 +8192.000000,12.922880,13.030692,33.398266 diff --git a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=16-blocksize_q=32-blocksize_k=64-mode=fwd.png b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=16-blocksize_q=32-blocksize_k=64-mode=fwd.png new file mode 100644 index 0000000..02acead Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=16-blocksize_q=32-blocksize_k=64-mode=fwd.png differ diff --git a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=16-blocksize_q=64-blocksize_k=128-mode=bwd.csv b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=16-blocksize_q=64-blocksize_k=128-mode=bwd.csv new file mode 100644 index 0000000..6ab259e --- /dev/null +++ 
b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=16-blocksize_q=64-blocksize_k=128-mode=bwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,53.095425,53.192703,1.703559 +2048.000000,51.681278,51.619839,4.377183 +4096.000000,44.768257,45.523968,16.896826 +6144.000000,41.021439,45.626369,38.421501 +8192.000000,35.258369,36.006912,67.212288 diff --git a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=16-blocksize_q=64-blocksize_k=128-mode=bwd.png b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=16-blocksize_q=64-blocksize_k=128-mode=bwd.png new file mode 100644 index 0000000..55f0506 Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=16-blocksize_q=64-blocksize_k=128-mode=bwd.png differ diff --git a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=16-blocksize_q=64-blocksize_k=128-mode=fwd.csv b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=16-blocksize_q=64-blocksize_k=128-mode=fwd.csv new file mode 100644 index 0000000..b46a613 --- /dev/null +++ b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=16-blocksize_q=64-blocksize_k=128-mode=fwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,1.282506,1.256411,0.581872 +2048.000000,2.277905,2.304480,2.144217 +4096.000000,5.839872,5.845675,8.240283 +6144.000000,7.472896,7.116288,17.953644 +8192.000000,2.876416,2.995200,33.272491 diff --git a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=16-blocksize_q=64-blocksize_k=128-mode=fwd.png b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=16-blocksize_q=64-blocksize_k=128-mode=fwd.png new file mode 100644 index 0000000..c72a8b1 Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=16-blocksize_q=64-blocksize_k=128-mode=fwd.png differ diff --git 
a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=16-blocksize_q=64-blocksize_k=32-mode=bwd.csv b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=16-blocksize_q=64-blocksize_k=32-mode=bwd.csv new file mode 100644 index 0000000..bbc050b --- /dev/null +++ b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=16-blocksize_q=64-blocksize_k=32-mode=bwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,41.562111,41.410049,1.678567 +2048.000000,39.598591,39.300606,4.428848 +4096.000000,32.560127,33.178623,16.733448 +6144.000000,21.860865,22.489088,37.511681 +8192.000000,5.438464,6.223872,67.443710 diff --git a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=16-blocksize_q=64-blocksize_k=32-mode=bwd.png b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=16-blocksize_q=64-blocksize_k=32-mode=bwd.png new file mode 100644 index 0000000..e059177 Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=16-blocksize_q=64-blocksize_k=32-mode=bwd.png differ diff --git a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=16-blocksize_q=64-blocksize_k=32-mode=fwd.csv b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=16-blocksize_q=64-blocksize_k=32-mode=fwd.csv new file mode 100644 index 0000000..137a58e --- /dev/null +++ b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=16-blocksize_q=64-blocksize_k=32-mode=fwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,0.494958,0.520886,0.563141 +2048.000000,1.603912,1.595264,2.304609 +4096.000000,5.210269,5.081006,8.348887 +6144.000000,10.846806,10.662143,18.199968 +8192.000000,13.650065,13.740617,35.489792 diff --git a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=16-blocksize_q=64-blocksize_k=32-mode=fwd.png 
b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=16-blocksize_q=64-blocksize_k=32-mode=fwd.png new file mode 100644 index 0000000..02e9bd3 Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=16-blocksize_q=64-blocksize_k=32-mode=fwd.png differ diff --git a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=16-blocksize_q=64-blocksize_k=64-mode=bwd.csv b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=16-blocksize_q=64-blocksize_k=64-mode=bwd.csv new file mode 100644 index 0000000..9a394bf --- /dev/null +++ b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=16-blocksize_q=64-blocksize_k=64-mode=bwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,42.757118,42.887169,1.722554 +2048.000000,41.562111,42.292225,4.461763 +4096.000000,38.367233,38.758400,16.909210 +6144.000000,32.023552,32.134655,37.554176 +8192.000000,32.805889,32.974335,67.558403 diff --git a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=16-blocksize_q=64-blocksize_k=64-mode=bwd.png b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=16-blocksize_q=64-blocksize_k=64-mode=bwd.png new file mode 100644 index 0000000..fcc73c9 Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=16-blocksize_q=64-blocksize_k=64-mode=bwd.png differ diff --git a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=16-blocksize_q=64-blocksize_k=64-mode=fwd.csv b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=16-blocksize_q=64-blocksize_k=64-mode=fwd.csv new file mode 100644 index 0000000..2397920 --- /dev/null +++ b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=16-blocksize_q=64-blocksize_k=64-mode=fwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,0.538838,0.547866,0.577958 +2048.000000,1.562676,1.609285,2.230815 
+4096.000000,5.121925,5.172429,8.382421 +6144.000000,10.061194,10.032522,17.637869 +8192.000000,13.038300,13.092864,33.030888 diff --git a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=16-blocksize_q=64-blocksize_k=64-mode=fwd.png b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=16-blocksize_q=64-blocksize_k=64-mode=fwd.png new file mode 100644 index 0000000..b487728 Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=16-blocksize_q=64-blocksize_k=64-mode=fwd.png differ diff --git a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=32-blocksize_q=128-blocksize_k=128-mode=fwd.csv b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=32-blocksize_q=128-blocksize_k=128-mode=fwd.csv new file mode 100644 index 0000000..bcd41cb --- /dev/null +++ b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=32-blocksize_q=128-blocksize_k=128-mode=fwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,2.611364,2.712261,1.057549 +2048.000000,5.573888,5.672193,4.352677 +4096.000000,1.652224,1.624576,16.344543 +6144.000000,0.003072,0.003072,36.903564 +8192.000000,0.003072,0.003072,65.189888 diff --git a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=32-blocksize_q=128-blocksize_k=128-mode=fwd.png b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=32-blocksize_q=128-blocksize_k=128-mode=fwd.png new file mode 100644 index 0000000..8a197cb Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=32-blocksize_q=128-blocksize_k=128-mode=fwd.png differ diff --git a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=32-blocksize_q=128-blocksize_k=32-mode=fwd.csv b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=32-blocksize_q=128-blocksize_k=32-mode=fwd.csv new file mode 100644 index 0000000..6ebadc1 --- /dev/null +++ 
b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=32-blocksize_q=128-blocksize_k=32-mode=fwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,0.709110,0.709915,1.039278 +2048.000000,2.858897,2.728141,4.151845 +4096.000000,9.711064,9.728630,16.044199 +6144.000000,13.798229,13.682518,36.519535 +8192.000000,14.633301,14.592341,65.483772 diff --git a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=32-blocksize_q=128-blocksize_k=32-mode=fwd.png b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=32-blocksize_q=128-blocksize_k=32-mode=fwd.png new file mode 100644 index 0000000..207af97 Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=32-blocksize_q=128-blocksize_k=32-mode=fwd.png differ diff --git a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=32-blocksize_q=128-blocksize_k=64-mode=fwd.csv b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=32-blocksize_q=128-blocksize_k=64-mode=fwd.csv new file mode 100644 index 0000000..f9a65b7 --- /dev/null +++ b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=32-blocksize_q=128-blocksize_k=64-mode=fwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,1.864566,1.870797,1.053446 +2048.000000,4.448085,4.585814,4.284101 +4096.000000,5.730646,5.678421,16.478708 +6144.000000,0.004096,0.004096,36.579857 +8192.000000,0.003072,0.003072,65.200127 diff --git a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=32-blocksize_q=128-blocksize_k=64-mode=fwd.png b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=32-blocksize_q=128-blocksize_k=64-mode=fwd.png new file mode 100644 index 0000000..bce1068 Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=32-blocksize_q=128-blocksize_k=64-mode=fwd.png differ diff --git 
a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=32-blocksize_q=32-blocksize_k=128-mode=fwd.csv b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=32-blocksize_q=32-blocksize_k=128-mode=fwd.csv new file mode 100644 index 0000000..ba078b8 --- /dev/null +++ b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=32-blocksize_q=32-blocksize_k=128-mode=fwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,0.752174,0.710107,1.044432 +2048.000000,2.752262,2.601396,4.194279 +4096.000000,9.281141,9.248453,15.265447 +6144.000000,13.626369,13.578239,36.903984 +8192.000000,14.436353,14.499497,65.118210 diff --git a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=32-blocksize_q=32-blocksize_k=128-mode=fwd.png b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=32-blocksize_q=32-blocksize_k=128-mode=fwd.png new file mode 100644 index 0000000..17cd870 Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=32-blocksize_q=32-blocksize_k=128-mode=fwd.png differ diff --git a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=32-blocksize_q=32-blocksize_k=32-mode=bwd.csv b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=32-blocksize_q=32-blocksize_k=32-mode=bwd.csv new file mode 100644 index 0000000..b3afdb1 --- /dev/null +++ b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=32-blocksize_q=32-blocksize_k=32-mode=bwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,41.367554,41.297409,2.387162 +2048.000000,35.521027,36.499458,8.334222 +4096.000000,21.349377,21.473793,32.083626 +6144.000000,18.685440,19.066881,72.701950 +8192.000000,0.002048,0.002048,100.000000 diff --git a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=32-blocksize_q=32-blocksize_k=32-mode=bwd.png 
b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=32-blocksize_q=32-blocksize_k=32-mode=bwd.png new file mode 100644 index 0000000..5c81f0f Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=32-blocksize_q=32-blocksize_k=32-mode=bwd.png differ diff --git a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=32-blocksize_q=32-blocksize_k=32-mode=fwd.csv b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=32-blocksize_q=32-blocksize_k=32-mode=fwd.csv new file mode 100644 index 0000000..7dfc784 --- /dev/null +++ b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=32-blocksize_q=32-blocksize_k=32-mode=fwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,0.761515,0.705987,1.045806 +2048.000000,2.568465,2.561394,4.283787 +4096.000000,8.659530,8.630638,16.555706 +6144.000000,13.410305,13.320363,36.982784 +8192.000000,14.432256,14.739116,65.866753 diff --git a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=32-blocksize_q=32-blocksize_k=32-mode=fwd.png b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=32-blocksize_q=32-blocksize_k=32-mode=fwd.png new file mode 100644 index 0000000..dca100d Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=32-blocksize_q=32-blocksize_k=32-mode=fwd.png differ diff --git a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=32-blocksize_q=32-blocksize_k=64-mode=fwd.csv b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=32-blocksize_q=32-blocksize_k=64-mode=fwd.csv new file mode 100644 index 0000000..af78f9c --- /dev/null +++ b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=32-blocksize_q=32-blocksize_k=64-mode=fwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,0.750592,0.734716,1.056916 +2048.000000,2.732344,2.604477,4.171424 
+4096.000000,9.607169,9.035067,17.521664 +6144.000000,13.342037,13.356374,36.554237 +8192.000000,14.370816,14.657878,65.924286 diff --git a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=32-blocksize_q=32-blocksize_k=64-mode=fwd.png b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=32-blocksize_q=32-blocksize_k=64-mode=fwd.png new file mode 100644 index 0000000..fb164c0 Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=32-blocksize_q=32-blocksize_k=64-mode=fwd.png differ diff --git a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=32-blocksize_q=64-blocksize_k=128-mode=fwd.csv b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=32-blocksize_q=64-blocksize_k=128-mode=fwd.csv new file mode 100644 index 0000000..3e2f3b8 --- /dev/null +++ b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=32-blocksize_q=64-blocksize_k=128-mode=fwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,1.376570,1.346668,1.055802 +2048.000000,3.491719,3.624164,4.343775 +4096.000000,6.764032,6.389760,16.504036 +6144.000000,3.339776,3.219456,36.638721 +8192.000000,0.004096,0.003072,66.096130 diff --git a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=32-blocksize_q=64-blocksize_k=128-mode=fwd.png b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=32-blocksize_q=64-blocksize_k=128-mode=fwd.png new file mode 100644 index 0000000..9368a7d Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=32-blocksize_q=64-blocksize_k=128-mode=fwd.png differ diff --git a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=32-blocksize_q=64-blocksize_k=32-mode=fwd.csv b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=32-blocksize_q=64-blocksize_k=32-mode=fwd.csv new file mode 100644 index 0000000..57f22f2 --- /dev/null +++ 
b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=32-blocksize_q=64-blocksize_k=32-mode=fwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,0.708902,0.700407,1.044036 +2048.000000,2.740203,2.607936,4.228782 +4096.000000,9.416547,9.364795,17.520128 +6144.000000,13.611861,13.559636,36.651520 +8192.000000,19.022079,14.713515,65.922050 diff --git a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=32-blocksize_q=64-blocksize_k=32-mode=fwd.png b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=32-blocksize_q=64-blocksize_k=32-mode=fwd.png new file mode 100644 index 0000000..4e7beda Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=32-blocksize_q=64-blocksize_k=32-mode=fwd.png differ diff --git a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=32-blocksize_q=64-blocksize_k=64-mode=fwd.csv b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=32-blocksize_q=64-blocksize_k=64-mode=fwd.csv new file mode 100644 index 0000000..0175b1c --- /dev/null +++ b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=32-blocksize_q=64-blocksize_k=64-mode=fwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,0.697885,0.691014,1.054027 +2048.000000,2.747819,2.552127,4.263107 +4096.000000,9.142271,9.166776,16.167019 +6144.000000,13.440000,14.201709,36.599297 +8192.000000,19.050240,14.466048,65.029411 diff --git a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=32-blocksize_q=64-blocksize_k=64-mode=fwd.png b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=32-blocksize_q=64-blocksize_k=64-mode=fwd.png new file mode 100644 index 0000000..7dc4759 Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=32-blocksize_q=64-blocksize_k=64-mode=fwd.png differ diff --git 
a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=8-blocksize_q=128-blocksize_k=128-mode=bwd.csv b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=8-blocksize_q=128-blocksize_k=128-mode=bwd.csv new file mode 100644 index 0000000..60befb3 --- /dev/null +++ b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=8-blocksize_q=128-blocksize_k=128-mode=bwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,61.353985,62.225407,1.337359 +2048.000000,60.879871,61.496319,2.333625 +4096.000000,56.600574,56.689663,8.596188 +6144.000000,57.592831,60.437504,19.008633 +8192.000000,44.192768,41.035774,33.623520 diff --git a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=8-blocksize_q=128-blocksize_k=128-mode=bwd.png b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=8-blocksize_q=128-blocksize_k=128-mode=bwd.png new file mode 100644 index 0000000..34f3022 Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=8-blocksize_q=128-blocksize_k=128-mode=bwd.png differ diff --git a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=8-blocksize_q=128-blocksize_k=128-mode=fwd.csv b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=8-blocksize_q=128-blocksize_k=128-mode=fwd.csv new file mode 100644 index 0000000..1967d60 --- /dev/null +++ b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=8-blocksize_q=128-blocksize_k=128-mode=fwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,2.602459,2.490644,0.344058 +2048.000000,7.456237,7.448722,1.120765 +4096.000000,6.392685,6.435547,4.236637 +6144.000000,9.776128,9.794902,9.237079 +8192.000000,1.684480,1.734144,16.682266 diff --git a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=8-blocksize_q=128-blocksize_k=128-mode=fwd.png 
b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=8-blocksize_q=128-blocksize_k=128-mode=fwd.png new file mode 100644 index 0000000..739c47c Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=8-blocksize_q=128-blocksize_k=128-mode=fwd.png differ diff --git a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=8-blocksize_q=128-blocksize_k=32-mode=bwd.csv b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=8-blocksize_q=128-blocksize_k=32-mode=bwd.csv new file mode 100644 index 0000000..2b30b05 --- /dev/null +++ b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=8-blocksize_q=128-blocksize_k=32-mode=bwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,41.853951,41.547775,1.411192 +2048.000000,40.596992,40.745472,2.399296 +4096.000000,37.636608,38.649857,8.613574 +6144.000000,32.667648,32.721409,18.973766 +8192.000000,25.736193,26.100735,33.609726 diff --git a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=8-blocksize_q=128-blocksize_k=32-mode=bwd.png b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=8-blocksize_q=128-blocksize_k=32-mode=bwd.png new file mode 100644 index 0000000..54e73f7 Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=8-blocksize_q=128-blocksize_k=32-mode=bwd.png differ diff --git a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=8-blocksize_q=128-blocksize_k=32-mode=fwd.csv b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=8-blocksize_q=128-blocksize_k=32-mode=fwd.csv new file mode 100644 index 0000000..9ee3385 --- /dev/null +++ b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=8-blocksize_q=128-blocksize_k=32-mode=fwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,0.523776,0.503175,0.330198 +2048.000000,0.783487,0.805291,1.121285 
+4096.000000,3.229547,3.258554,4.321293 +6144.000000,5.891724,6.055610,9.273612 +8192.000000,9.949578,9.955723,16.893600 diff --git a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=8-blocksize_q=128-blocksize_k=32-mode=fwd.png b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=8-blocksize_q=128-blocksize_k=32-mode=fwd.png new file mode 100644 index 0000000..1fa177d Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=8-blocksize_q=128-blocksize_k=32-mode=fwd.png differ diff --git a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=8-blocksize_q=128-blocksize_k=64-mode=bwd.csv b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=8-blocksize_q=128-blocksize_k=64-mode=bwd.csv new file mode 100644 index 0000000..b521e1f --- /dev/null +++ b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=8-blocksize_q=128-blocksize_k=64-mode=bwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,50.859009,50.534401,1.410580 +2048.000000,51.761150,52.281345,2.357782 +4096.000000,48.677887,48.444416,8.476134 +6144.000000,53.513214,54.030334,18.859066 +8192.000000,36.483070,37.648384,33.277550 diff --git a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=8-blocksize_q=128-blocksize_k=64-mode=bwd.png b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=8-blocksize_q=128-blocksize_k=64-mode=bwd.png new file mode 100644 index 0000000..50f6db4 Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=8-blocksize_q=128-blocksize_k=64-mode=bwd.png differ diff --git a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=8-blocksize_q=128-blocksize_k=64-mode=fwd.csv b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=8-blocksize_q=128-blocksize_k=64-mode=fwd.csv new file mode 100644 index 0000000..3f19aa9 --- /dev/null +++ 
b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=8-blocksize_q=128-blocksize_k=64-mode=fwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,1.723079,1.697655,0.342277 +2048.000000,4.616704,4.653780,1.125893 +4096.000000,5.810269,5.742592,4.256475 +6144.000000,12.641278,12.547072,9.258146 +8192.000000,6.558720,6.664533,16.576769 diff --git a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=8-blocksize_q=128-blocksize_k=64-mode=fwd.png b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=8-blocksize_q=128-blocksize_k=64-mode=fwd.png new file mode 100644 index 0000000..55f87c8 Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=8-blocksize_q=128-blocksize_k=64-mode=fwd.png differ diff --git a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=8-blocksize_q=32-blocksize_k=128-mode=bwd.csv b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=8-blocksize_q=32-blocksize_k=128-mode=bwd.csv new file mode 100644 index 0000000..5579340 --- /dev/null +++ b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=8-blocksize_q=32-blocksize_k=128-mode=bwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,39.845871,40.007553,1.397342 +2048.000000,39.244286,93.883904,2.386327 +4096.000000,38.621185,39.122433,8.596893 +6144.000000,35.342339,35.373055,18.857094 +8192.000000,30.567425,30.608896,33.631248 diff --git a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=8-blocksize_q=32-blocksize_k=128-mode=bwd.png b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=8-blocksize_q=32-blocksize_k=128-mode=bwd.png new file mode 100644 index 0000000..a0cfc53 Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=8-blocksize_q=32-blocksize_k=128-mode=bwd.png differ diff --git 
a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=8-blocksize_q=32-blocksize_k=128-mode=fwd.csv b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=8-blocksize_q=32-blocksize_k=128-mode=fwd.csv new file mode 100644 index 0000000..531e2be --- /dev/null +++ b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=8-blocksize_q=32-blocksize_k=128-mode=fwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,0.520714,0.523802,0.344201 +2048.000000,0.815578,0.854716,1.119187 +4096.000000,2.947140,2.932112,4.329190 +6144.000000,5.419055,5.407166,9.223536 +8192.000000,9.406306,9.235140,16.996210 diff --git a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=8-blocksize_q=32-blocksize_k=128-mode=fwd.png b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=8-blocksize_q=32-blocksize_k=128-mode=fwd.png new file mode 100644 index 0000000..1789742 Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=8-blocksize_q=32-blocksize_k=128-mode=fwd.png differ diff --git a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=8-blocksize_q=32-blocksize_k=32-mode=bwd.csv b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=8-blocksize_q=32-blocksize_k=32-mode=bwd.csv new file mode 100644 index 0000000..96ed937 --- /dev/null +++ b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=8-blocksize_q=32-blocksize_k=32-mode=bwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,36.375038,36.915024,1.356797 +2048.000000,35.839996,36.000771,2.349570 +4096.000000,32.206337,32.279037,8.530293 +6144.000000,25.390591,25.362946,19.068314 +8192.000000,16.850945,16.951296,33.924812 diff --git a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=8-blocksize_q=32-blocksize_k=32-mode=bwd.png 
b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=8-blocksize_q=32-blocksize_k=32-mode=bwd.png new file mode 100644 index 0000000..58772e2 Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=8-blocksize_q=32-blocksize_k=32-mode=bwd.png differ diff --git a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=8-blocksize_q=32-blocksize_k=32-mode=fwd.csv b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=8-blocksize_q=32-blocksize_k=32-mode=fwd.csv new file mode 100644 index 0000000..d125447 --- /dev/null +++ b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=8-blocksize_q=32-blocksize_k=32-mode=fwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,0.647896,0.517666,0.341431 +2048.000000,0.804695,0.814930,1.122615 +4096.000000,2.859736,2.862675,4.369193 +6144.000000,5.633426,5.611075,9.307389 +8192.000000,9.312256,9.455694,15.958285 diff --git a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=8-blocksize_q=32-blocksize_k=32-mode=fwd.png b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=8-blocksize_q=32-blocksize_k=32-mode=fwd.png new file mode 100644 index 0000000..d97488f Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=8-blocksize_q=32-blocksize_k=32-mode=fwd.png differ diff --git a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=8-blocksize_q=32-blocksize_k=64-mode=bwd.csv b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=8-blocksize_q=32-blocksize_k=64-mode=bwd.csv new file mode 100644 index 0000000..2f3b35f --- /dev/null +++ b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=8-blocksize_q=32-blocksize_k=64-mode=bwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,38.586975,38.283329,1.378514 +2048.000000,37.617149,37.868031,2.390532 
+4096.000000,34.799103,35.452415,8.530737 +6144.000000,32.466431,32.334846,19.153547 +8192.000000,26.447872,26.618368,34.004990 diff --git a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=8-blocksize_q=32-blocksize_k=64-mode=bwd.png b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=8-blocksize_q=32-blocksize_k=64-mode=bwd.png new file mode 100644 index 0000000..933820d Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=8-blocksize_q=32-blocksize_k=64-mode=bwd.png differ diff --git a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=8-blocksize_q=32-blocksize_k=64-mode=fwd.csv b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=8-blocksize_q=32-blocksize_k=64-mode=fwd.csv new file mode 100644 index 0000000..ed4122d --- /dev/null +++ b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=8-blocksize_q=32-blocksize_k=64-mode=fwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,0.576113,0.543250,0.358758 +2048.000000,0.822253,0.827039,1.114569 +4096.000000,3.005690,2.990080,4.290877 +6144.000000,5.640425,5.588457,9.253213 +8192.000000,9.407960,9.420507,16.673454 diff --git a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=8-blocksize_q=32-blocksize_k=64-mode=fwd.png b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=8-blocksize_q=32-blocksize_k=64-mode=fwd.png new file mode 100644 index 0000000..8dc20b7 Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=8-blocksize_q=32-blocksize_k=64-mode=fwd.png differ diff --git a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=8-blocksize_q=64-blocksize_k=128-mode=bwd.csv b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=8-blocksize_q=64-blocksize_k=128-mode=bwd.csv new file mode 100644 index 0000000..2476b0a --- /dev/null +++ 
b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=8-blocksize_q=64-blocksize_k=128-mode=bwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,49.309696,50.021378,1.384689 +2048.000000,48.715775,50.462719,2.415959 +4096.000000,46.816257,47.163391,8.567736 +6144.000000,42.467327,43.418625,18.979416 +8192.000000,42.753025,46.178303,32.854687 diff --git a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=8-blocksize_q=64-blocksize_k=128-mode=bwd.png b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=8-blocksize_q=64-blocksize_k=128-mode=bwd.png new file mode 100644 index 0000000..fb2f811 Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=8-blocksize_q=64-blocksize_k=128-mode=bwd.png differ diff --git a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=8-blocksize_q=64-blocksize_k=128-mode=fwd.csv b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=8-blocksize_q=64-blocksize_k=128-mode=fwd.csv new file mode 100644 index 0000000..883ccd2 --- /dev/null +++ b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=8-blocksize_q=64-blocksize_k=128-mode=fwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,0.860796,0.832957,0.350747 +2048.000000,1.870602,1.900544,1.138429 +4096.000000,4.201408,4.197858,4.254879 +6144.000000,7.111424,7.113984,9.273764 +8192.000000,6.915840,6.910720,16.820019 diff --git a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=8-blocksize_q=64-blocksize_k=128-mode=fwd.png b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=8-blocksize_q=64-blocksize_k=128-mode=fwd.png new file mode 100644 index 0000000..08256c1 Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=8-blocksize_q=64-blocksize_k=128-mode=fwd.png differ diff --git 
a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=8-blocksize_q=64-blocksize_k=32-mode=bwd.csv b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=8-blocksize_q=64-blocksize_k=32-mode=bwd.csv new file mode 100644 index 0000000..17a8fe2 --- /dev/null +++ b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=8-blocksize_q=64-blocksize_k=32-mode=bwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,38.424496,38.535683,1.361614 +2048.000000,37.474815,37.586945,2.413392 +4096.000000,34.633728,34.557442,8.531098 +6144.000000,29.181953,29.447166,18.814770 +8192.000000,21.328896,21.188606,33.464352 diff --git a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=8-blocksize_q=64-blocksize_k=32-mode=bwd.png b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=8-blocksize_q=64-blocksize_k=32-mode=bwd.png new file mode 100644 index 0000000..45afab8 Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=8-blocksize_q=64-blocksize_k=32-mode=bwd.png differ diff --git a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=8-blocksize_q=64-blocksize_k=32-mode=fwd.csv b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=8-blocksize_q=64-blocksize_k=32-mode=fwd.csv new file mode 100644 index 0000000..93f9af1 --- /dev/null +++ b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=8-blocksize_q=64-blocksize_k=32-mode=fwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,0.513495,0.517661,0.350645 +2048.000000,0.826982,0.820568,1.118938 +4096.000000,3.070091,3.075726,4.363532 +6144.000000,5.725324,5.628037,9.183004 +8192.000000,9.397905,9.569061,16.451668 diff --git a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=8-blocksize_q=64-blocksize_k=32-mode=fwd.png 
b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=8-blocksize_q=64-blocksize_k=32-mode=fwd.png new file mode 100644 index 0000000..2508ac8 Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=8-blocksize_q=64-blocksize_k=32-mode=fwd.png differ diff --git a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=8-blocksize_q=64-blocksize_k=64-mode=bwd.csv b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=8-blocksize_q=64-blocksize_k=64-mode=bwd.csv new file mode 100644 index 0000000..ca61ee6 --- /dev/null +++ b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=8-blocksize_q=64-blocksize_k=64-mode=bwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,39.700096,40.035278,1.382476 +2048.000000,40.352768,40.001534,2.420520 +4096.000000,38.114815,37.777920,8.479310 +6144.000000,35.771393,35.728897,18.966349 +8192.000000,30.152704,31.441919,33.163776 diff --git a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=8-blocksize_q=64-blocksize_k=64-mode=bwd.png b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=8-blocksize_q=64-blocksize_k=64-mode=bwd.png new file mode 100644 index 0000000..e73ca30 Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=8-blocksize_q=64-blocksize_k=64-mode=bwd.png differ diff --git a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=8-blocksize_q=64-blocksize_k=64-mode=fwd.csv b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=8-blocksize_q=64-blocksize_k=64-mode=fwd.csv new file mode 100644 index 0000000..a721dc0 --- /dev/null +++ b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=8-blocksize_q=64-blocksize_k=64-mode=fwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,0.519706,0.541197,0.349492 +2048.000000,0.785792,0.792366,1.128300 
+4096.000000,2.982607,2.973717,4.243594 +6144.000000,5.667371,5.605760,9.169615 +8192.000000,9.369600,9.438426,16.558758 diff --git a/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=8-blocksize_q=64-blocksize_k=64-mode=fwd.png b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=8-blocksize_q=64-blocksize_k=64-mode=fwd.png new file mode 100644 index 0000000..b66400e Binary files /dev/null and b/benchmarks/mha/triton/batch_size=1-bias=True-headdim=64-num_heads=8-blocksize_q=64-blocksize_k=64-mode=fwd.png differ diff --git a/benchmarks/mha/triton/batch_size=2-bias=True-headdim=64-num_heads=16-blocksize_q=128-blocksize_k=128-mode=fwd.csv b/benchmarks/mha/triton/batch_size=2-bias=True-headdim=64-num_heads=16-blocksize_q=128-blocksize_k=128-mode=fwd.csv new file mode 100644 index 0000000..05336c7 --- /dev/null +++ b/benchmarks/mha/triton/batch_size=2-bias=True-headdim=64-num_heads=16-blocksize_q=128-blocksize_k=128-mode=fwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,3.016448,2.718542,1.108554 +2048.000000,4.776374,4.713472,4.471820 +4096.000000,1.587200,1.765888,16.994169 +6144.000000,0.004096,0.003072,39.430237 +8192.000000,0.609280,0.546816,70.395363 diff --git a/benchmarks/mha/triton/batch_size=2-bias=True-headdim=64-num_heads=16-blocksize_q=128-blocksize_k=128-mode=fwd.png b/benchmarks/mha/triton/batch_size=2-bias=True-headdim=64-num_heads=16-blocksize_q=128-blocksize_k=128-mode=fwd.png new file mode 100644 index 0000000..dfd4256 Binary files /dev/null and b/benchmarks/mha/triton/batch_size=2-bias=True-headdim=64-num_heads=16-blocksize_q=128-blocksize_k=128-mode=fwd.png differ diff --git a/benchmarks/mha/triton/batch_size=2-bias=True-headdim=64-num_heads=16-blocksize_q=128-blocksize_k=32-mode=fwd.csv b/benchmarks/mha/triton/batch_size=2-bias=True-headdim=64-num_heads=16-blocksize_q=128-blocksize_k=32-mode=fwd.csv new file mode 100644 index 0000000..e1eb741 --- /dev/null +++ 
b/benchmarks/mha/triton/batch_size=2-bias=True-headdim=64-num_heads=16-blocksize_q=128-blocksize_k=32-mode=fwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,0.815368,0.806767,1.081670 +2048.000000,2.774706,2.834364,4.351111 +4096.000000,9.429071,9.525248,16.589849 +6144.000000,14.187520,15.814313,36.553886 +8192.000000,18.838186,16.232790,69.380447 diff --git a/benchmarks/mha/triton/batch_size=2-bias=True-headdim=64-num_heads=16-blocksize_q=128-blocksize_k=32-mode=fwd.png b/benchmarks/mha/triton/batch_size=2-bias=True-headdim=64-num_heads=16-blocksize_q=128-blocksize_k=32-mode=fwd.png new file mode 100644 index 0000000..6e8ee09 Binary files /dev/null and b/benchmarks/mha/triton/batch_size=2-bias=True-headdim=64-num_heads=16-blocksize_q=128-blocksize_k=32-mode=fwd.png differ diff --git a/benchmarks/mha/triton/batch_size=2-bias=True-headdim=64-num_heads=16-blocksize_q=128-blocksize_k=64-mode=fwd.csv b/benchmarks/mha/triton/batch_size=2-bias=True-headdim=64-num_heads=16-blocksize_q=128-blocksize_k=64-mode=fwd.csv new file mode 100644 index 0000000..e2f5c52 --- /dev/null +++ b/benchmarks/mha/triton/batch_size=2-bias=True-headdim=64-num_heads=16-blocksize_q=128-blocksize_k=64-mode=fwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,1.993728,1.939996,1.132842 +2048.000000,4.663125,4.559034,4.576328 +4096.000000,5.938517,5.812907,16.555014 +6144.000000,0.004096,0.003072,38.962688 +8192.000000,0.565248,0.644096,69.729279 diff --git a/benchmarks/mha/triton/batch_size=2-bias=True-headdim=64-num_heads=16-blocksize_q=128-blocksize_k=64-mode=fwd.png b/benchmarks/mha/triton/batch_size=2-bias=True-headdim=64-num_heads=16-blocksize_q=128-blocksize_k=64-mode=fwd.png new file mode 100644 index 0000000..ab85fe0 Binary files /dev/null and b/benchmarks/mha/triton/batch_size=2-bias=True-headdim=64-num_heads=16-blocksize_q=128-blocksize_k=64-mode=fwd.png differ diff --git 
a/benchmarks/mha/triton/batch_size=2-bias=True-headdim=64-num_heads=16-blocksize_q=32-blocksize_k=128-mode=fwd.csv b/benchmarks/mha/triton/batch_size=2-bias=True-headdim=64-num_heads=16-blocksize_q=32-blocksize_k=128-mode=fwd.csv new file mode 100644 index 0000000..b5fa6f8 --- /dev/null +++ b/benchmarks/mha/triton/batch_size=2-bias=True-headdim=64-num_heads=16-blocksize_q=32-blocksize_k=128-mode=fwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,0.791671,0.786922,1.106653 +2048.000000,2.779336,2.588974,4.389463 +4096.000000,8.780406,9.031168,16.431162 +6144.000000,14.875134,14.510079,38.387711 +8192.000000,19.375788,15.981910,70.879906 diff --git a/benchmarks/mha/triton/batch_size=2-bias=True-headdim=64-num_heads=16-blocksize_q=32-blocksize_k=128-mode=fwd.png b/benchmarks/mha/triton/batch_size=2-bias=True-headdim=64-num_heads=16-blocksize_q=32-blocksize_k=128-mode=fwd.png new file mode 100644 index 0000000..76e6b38 Binary files /dev/null and b/benchmarks/mha/triton/batch_size=2-bias=True-headdim=64-num_heads=16-blocksize_q=32-blocksize_k=128-mode=fwd.png differ diff --git a/benchmarks/mha/triton/batch_size=2-bias=True-headdim=64-num_heads=16-blocksize_q=32-blocksize_k=32-mode=fwd.csv b/benchmarks/mha/triton/batch_size=2-bias=True-headdim=64-num_heads=16-blocksize_q=32-blocksize_k=32-mode=fwd.csv new file mode 100644 index 0000000..97ed856 --- /dev/null +++ b/benchmarks/mha/triton/batch_size=2-bias=True-headdim=64-num_heads=16-blocksize_q=32-blocksize_k=32-mode=fwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,0.767275,0.813129,1.086443 +2048.000000,2.711290,2.778112,4.553719 +4096.000000,9.293532,9.213781,17.360691 +6144.000000,14.882474,13.861888,36.494415 +8192.000000,18.230272,18.602325,70.261757 diff --git a/benchmarks/mha/triton/batch_size=2-bias=True-headdim=64-num_heads=16-blocksize_q=32-blocksize_k=32-mode=fwd.png 
b/benchmarks/mha/triton/batch_size=2-bias=True-headdim=64-num_heads=16-blocksize_q=32-blocksize_k=32-mode=fwd.png new file mode 100644 index 0000000..3f03d2c Binary files /dev/null and b/benchmarks/mha/triton/batch_size=2-bias=True-headdim=64-num_heads=16-blocksize_q=32-blocksize_k=32-mode=fwd.png differ diff --git a/benchmarks/mha/triton/batch_size=2-bias=True-headdim=64-num_heads=16-blocksize_q=32-blocksize_k=64-mode=fwd.csv b/benchmarks/mha/triton/batch_size=2-bias=True-headdim=64-num_heads=16-blocksize_q=32-blocksize_k=64-mode=fwd.csv new file mode 100644 index 0000000..5f71510 --- /dev/null +++ b/benchmarks/mha/triton/batch_size=2-bias=True-headdim=64-num_heads=16-blocksize_q=32-blocksize_k=64-mode=fwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,0.773839,0.778123,1.134479 +2048.000000,2.807117,2.708676,4.576355 +4096.000000,9.447424,9.421824,17.078535 +6144.000000,13.849427,14.191958,39.379471 +8192.000000,19.350868,14.856534,70.892700 diff --git a/benchmarks/mha/triton/batch_size=2-bias=True-headdim=64-num_heads=16-blocksize_q=32-blocksize_k=64-mode=fwd.png b/benchmarks/mha/triton/batch_size=2-bias=True-headdim=64-num_heads=16-blocksize_q=32-blocksize_k=64-mode=fwd.png new file mode 100644 index 0000000..9c0a328 Binary files /dev/null and b/benchmarks/mha/triton/batch_size=2-bias=True-headdim=64-num_heads=16-blocksize_q=32-blocksize_k=64-mode=fwd.png differ diff --git a/benchmarks/mha/triton/batch_size=2-bias=True-headdim=64-num_heads=16-blocksize_q=64-blocksize_k=128-mode=fwd.csv b/benchmarks/mha/triton/batch_size=2-bias=True-headdim=64-num_heads=16-blocksize_q=64-blocksize_k=128-mode=fwd.csv new file mode 100644 index 0000000..c090e0a --- /dev/null +++ b/benchmarks/mha/triton/batch_size=2-bias=True-headdim=64-num_heads=16-blocksize_q=64-blocksize_k=128-mode=fwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,1.446932,1.361579,1.164321 +2048.000000,3.849088,4.349473,4.559613 
+4096.000000,6.429440,6.547456,16.471294 +6144.000000,3.458048,3.225600,37.546127 +8192.000000,0.003072,0.722944,69.769539 diff --git a/benchmarks/mha/triton/batch_size=2-bias=True-headdim=64-num_heads=16-blocksize_q=64-blocksize_k=128-mode=fwd.png b/benchmarks/mha/triton/batch_size=2-bias=True-headdim=64-num_heads=16-blocksize_q=64-blocksize_k=128-mode=fwd.png new file mode 100644 index 0000000..ad269ce Binary files /dev/null and b/benchmarks/mha/triton/batch_size=2-bias=True-headdim=64-num_heads=16-blocksize_q=64-blocksize_k=128-mode=fwd.png differ diff --git a/benchmarks/mha/triton/batch_size=2-bias=True-headdim=64-num_heads=16-blocksize_q=64-blocksize_k=32-mode=fwd.csv b/benchmarks/mha/triton/batch_size=2-bias=True-headdim=64-num_heads=16-blocksize_q=64-blocksize_k=32-mode=fwd.csv new file mode 100644 index 0000000..d5f3f45 --- /dev/null +++ b/benchmarks/mha/triton/batch_size=2-bias=True-headdim=64-num_heads=16-blocksize_q=64-blocksize_k=32-mode=fwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,0.711567,0.702139,1.128044 +2048.000000,2.749167,2.708395,4.606463 +4096.000000,9.773449,9.490118,16.462849 +6144.000000,13.871104,14.724095,39.420418 +8192.000000,19.491840,18.470572,65.034241 diff --git a/benchmarks/mha/triton/batch_size=2-bias=True-headdim=64-num_heads=16-blocksize_q=64-blocksize_k=32-mode=fwd.png b/benchmarks/mha/triton/batch_size=2-bias=True-headdim=64-num_heads=16-blocksize_q=64-blocksize_k=32-mode=fwd.png new file mode 100644 index 0000000..093e5bd Binary files /dev/null and b/benchmarks/mha/triton/batch_size=2-bias=True-headdim=64-num_heads=16-blocksize_q=64-blocksize_k=32-mode=fwd.png differ diff --git a/benchmarks/mha/triton/batch_size=2-bias=True-headdim=64-num_heads=16-blocksize_q=64-blocksize_k=64-mode=fwd.csv b/benchmarks/mha/triton/batch_size=2-bias=True-headdim=64-num_heads=16-blocksize_q=64-blocksize_k=64-mode=fwd.csv new file mode 100644 index 0000000..524b635 --- /dev/null +++ 
b/benchmarks/mha/triton/batch_size=2-bias=True-headdim=64-num_heads=16-blocksize_q=64-blocksize_k=64-mode=fwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,0.703027,0.708415,1.102299 +2048.000000,2.692684,2.626432,4.585531 +4096.000000,8.840980,9.186743,16.590517 +6144.000000,14.992554,14.491649,36.601345 +8192.000000,18.592085,18.293760,70.544418 diff --git a/benchmarks/mha/triton/batch_size=2-bias=True-headdim=64-num_heads=16-blocksize_q=64-blocksize_k=64-mode=fwd.png b/benchmarks/mha/triton/batch_size=2-bias=True-headdim=64-num_heads=16-blocksize_q=64-blocksize_k=64-mode=fwd.png new file mode 100644 index 0000000..8c6517a Binary files /dev/null and b/benchmarks/mha/triton/batch_size=2-bias=True-headdim=64-num_heads=16-blocksize_q=64-blocksize_k=64-mode=fwd.png differ diff --git a/benchmarks/mha/triton/batch_size=2-bias=True-headdim=64-num_heads=8-blocksize_q=128-blocksize_k=128-mode=fwd.csv b/benchmarks/mha/triton/batch_size=2-bias=True-headdim=64-num_heads=8-blocksize_q=128-blocksize_k=128-mode=fwd.csv new file mode 100644 index 0000000..83f6b00 --- /dev/null +++ b/benchmarks/mha/triton/batch_size=2-bias=True-headdim=64-num_heads=8-blocksize_q=128-blocksize_k=128-mode=fwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,4.057088,4.245180,0.624358 +2048.000000,4.851318,4.914761,2.331070 +4096.000000,6.615040,6.592768,8.467351 +6144.000000,0.002048,0.004096,19.303846 +8192.000000,0.003072,0.004096,33.620193 diff --git a/benchmarks/mha/triton/batch_size=2-bias=True-headdim=64-num_heads=8-blocksize_q=128-blocksize_k=128-mode=fwd.png b/benchmarks/mha/triton/batch_size=2-bias=True-headdim=64-num_heads=8-blocksize_q=128-blocksize_k=128-mode=fwd.png new file mode 100644 index 0000000..f468052 Binary files /dev/null and b/benchmarks/mha/triton/batch_size=2-bias=True-headdim=64-num_heads=8-blocksize_q=128-blocksize_k=128-mode=fwd.png differ diff --git 
a/benchmarks/mha/triton/batch_size=2-bias=True-headdim=64-num_heads=8-blocksize_q=128-blocksize_k=32-mode=fwd.csv b/benchmarks/mha/triton/batch_size=2-bias=True-headdim=64-num_heads=8-blocksize_q=128-blocksize_k=32-mode=fwd.csv new file mode 100644 index 0000000..8b063a1 --- /dev/null +++ b/benchmarks/mha/triton/batch_size=2-bias=True-headdim=64-num_heads=8-blocksize_q=128-blocksize_k=32-mode=fwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,0.526656,0.467367,0.606131 +2048.000000,1.756235,1.753088,2.272848 +4096.000000,5.566166,5.677440,9.024755 +6144.000000,10.894710,10.917516,19.174747 +8192.000000,14.630740,14.779050,34.224640 diff --git a/benchmarks/mha/triton/batch_size=2-bias=True-headdim=64-num_heads=8-blocksize_q=128-blocksize_k=32-mode=fwd.png b/benchmarks/mha/triton/batch_size=2-bias=True-headdim=64-num_heads=8-blocksize_q=128-blocksize_k=32-mode=fwd.png new file mode 100644 index 0000000..71b4139 Binary files /dev/null and b/benchmarks/mha/triton/batch_size=2-bias=True-headdim=64-num_heads=8-blocksize_q=128-blocksize_k=32-mode=fwd.png differ diff --git a/benchmarks/mha/triton/batch_size=2-bias=True-headdim=64-num_heads=8-blocksize_q=128-blocksize_k=64-mode=fwd.csv b/benchmarks/mha/triton/batch_size=2-bias=True-headdim=64-num_heads=8-blocksize_q=128-blocksize_k=64-mode=fwd.csv new file mode 100644 index 0000000..7d50cf1 --- /dev/null +++ b/benchmarks/mha/triton/batch_size=2-bias=True-headdim=64-num_heads=8-blocksize_q=128-blocksize_k=64-mode=fwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,2.478391,2.474656,0.613245 +2048.000000,3.561742,3.650398,2.413515 +4096.000000,6.082150,6.158165,8.622825 +6144.000000,1.751552,1.756160,19.270195 +8192.000000,0.920576,0.003072,36.294174 diff --git a/benchmarks/mha/triton/batch_size=2-bias=True-headdim=64-num_heads=8-blocksize_q=128-blocksize_k=64-mode=fwd.png 
b/benchmarks/mha/triton/batch_size=2-bias=True-headdim=64-num_heads=8-blocksize_q=128-blocksize_k=64-mode=fwd.png new file mode 100644 index 0000000..a004714 Binary files /dev/null and b/benchmarks/mha/triton/batch_size=2-bias=True-headdim=64-num_heads=8-blocksize_q=128-blocksize_k=64-mode=fwd.png differ diff --git a/benchmarks/mha/triton/batch_size=2-bias=True-headdim=64-num_heads=8-blocksize_q=32-blocksize_k=128-mode=fwd.csv b/benchmarks/mha/triton/batch_size=2-bias=True-headdim=64-num_heads=8-blocksize_q=32-blocksize_k=128-mode=fwd.csv new file mode 100644 index 0000000..17025da --- /dev/null +++ b/benchmarks/mha/triton/batch_size=2-bias=True-headdim=64-num_heads=8-blocksize_q=32-blocksize_k=128-mode=fwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,0.481963,0.472162,0.635261 +2048.000000,1.737520,1.773109,2.338425 +4096.000000,5.370496,5.492647,9.277440 +6144.000000,11.488768,10.655185,19.684666 +8192.000000,14.547456,14.532608,34.567886 diff --git a/benchmarks/mha/triton/batch_size=2-bias=True-headdim=64-num_heads=8-blocksize_q=32-blocksize_k=128-mode=fwd.png b/benchmarks/mha/triton/batch_size=2-bias=True-headdim=64-num_heads=8-blocksize_q=32-blocksize_k=128-mode=fwd.png new file mode 100644 index 0000000..8cf9992 Binary files /dev/null and b/benchmarks/mha/triton/batch_size=2-bias=True-headdim=64-num_heads=8-blocksize_q=32-blocksize_k=128-mode=fwd.png differ diff --git a/benchmarks/mha/triton/batch_size=2-bias=True-headdim=64-num_heads=8-blocksize_q=32-blocksize_k=32-mode=fwd.csv b/benchmarks/mha/triton/batch_size=2-bias=True-headdim=64-num_heads=8-blocksize_q=32-blocksize_k=32-mode=fwd.csv new file mode 100644 index 0000000..9cd2c98 --- /dev/null +++ b/benchmarks/mha/triton/batch_size=2-bias=True-headdim=64-num_heads=8-blocksize_q=32-blocksize_k=32-mode=fwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,0.458480,0.464732,0.609373 +2048.000000,1.689842,1.773986,2.237333 
+4096.000000,5.108080,5.088000,8.947492 +6144.000000,10.509498,10.850305,20.380268 +8192.000000,14.760106,14.632447,34.841118 diff --git a/benchmarks/mha/triton/batch_size=2-bias=True-headdim=64-num_heads=8-blocksize_q=32-blocksize_k=32-mode=fwd.png b/benchmarks/mha/triton/batch_size=2-bias=True-headdim=64-num_heads=8-blocksize_q=32-blocksize_k=32-mode=fwd.png new file mode 100644 index 0000000..d1c5676 Binary files /dev/null and b/benchmarks/mha/triton/batch_size=2-bias=True-headdim=64-num_heads=8-blocksize_q=32-blocksize_k=32-mode=fwd.png differ diff --git a/benchmarks/mha/triton/batch_size=2-bias=True-headdim=64-num_heads=8-blocksize_q=32-blocksize_k=64-mode=fwd.csv b/benchmarks/mha/triton/batch_size=2-bias=True-headdim=64-num_heads=8-blocksize_q=32-blocksize_k=64-mode=fwd.csv new file mode 100644 index 0000000..1af234b --- /dev/null +++ b/benchmarks/mha/triton/batch_size=2-bias=True-headdim=64-num_heads=8-blocksize_q=32-blocksize_k=64-mode=fwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,0.551196,0.501013,0.631803 +2048.000000,1.677475,1.770744,2.353120 +4096.000000,5.462912,5.429115,9.106876 +6144.000000,10.608454,10.543011,19.638601 +8192.000000,14.735019,14.596607,35.468433 diff --git a/benchmarks/mha/triton/batch_size=2-bias=True-headdim=64-num_heads=8-blocksize_q=32-blocksize_k=64-mode=fwd.png b/benchmarks/mha/triton/batch_size=2-bias=True-headdim=64-num_heads=8-blocksize_q=32-blocksize_k=64-mode=fwd.png new file mode 100644 index 0000000..92b42bc Binary files /dev/null and b/benchmarks/mha/triton/batch_size=2-bias=True-headdim=64-num_heads=8-blocksize_q=32-blocksize_k=64-mode=fwd.png differ diff --git a/benchmarks/mha/triton/batch_size=2-bias=True-headdim=64-num_heads=8-blocksize_q=64-blocksize_k=128-mode=fwd.csv b/benchmarks/mha/triton/batch_size=2-bias=True-headdim=64-num_heads=8-blocksize_q=64-blocksize_k=128-mode=fwd.csv new file mode 100644 index 0000000..e79127d --- /dev/null +++ 
b/benchmarks/mha/triton/batch_size=2-bias=True-headdim=64-num_heads=8-blocksize_q=64-blocksize_k=128-mode=fwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,1.196295,1.214626,0.626305 +2048.000000,2.411008,2.412156,2.391395 +4096.000000,6.387328,6.035968,8.681062 +6144.000000,5.965141,5.595478,19.403021 +8192.000000,6.159360,5.892096,33.856514 diff --git a/benchmarks/mha/triton/batch_size=2-bias=True-headdim=64-num_heads=8-blocksize_q=64-blocksize_k=128-mode=fwd.png b/benchmarks/mha/triton/batch_size=2-bias=True-headdim=64-num_heads=8-blocksize_q=64-blocksize_k=128-mode=fwd.png new file mode 100644 index 0000000..b4956a6 Binary files /dev/null and b/benchmarks/mha/triton/batch_size=2-bias=True-headdim=64-num_heads=8-blocksize_q=64-blocksize_k=128-mode=fwd.png differ diff --git a/benchmarks/mha/triton/batch_size=2-bias=True-headdim=64-num_heads=8-blocksize_q=64-blocksize_k=32-mode=fwd.csv b/benchmarks/mha/triton/batch_size=2-bias=True-headdim=64-num_heads=8-blocksize_q=64-blocksize_k=32-mode=fwd.csv new file mode 100644 index 0000000..3556553 --- /dev/null +++ b/benchmarks/mha/triton/batch_size=2-bias=True-headdim=64-num_heads=8-blocksize_q=64-blocksize_k=32-mode=fwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,0.480738,0.467340,0.632091 +2048.000000,1.731760,1.773752,2.346859 +4096.000000,5.476651,5.613953,9.266672 +6144.000000,10.993106,10.725003,19.697664 +8192.000000,15.575624,14.935893,35.763264 diff --git a/benchmarks/mha/triton/batch_size=2-bias=True-headdim=64-num_heads=8-blocksize_q=64-blocksize_k=32-mode=fwd.png b/benchmarks/mha/triton/batch_size=2-bias=True-headdim=64-num_heads=8-blocksize_q=64-blocksize_k=32-mode=fwd.png new file mode 100644 index 0000000..d1d8353 Binary files /dev/null and b/benchmarks/mha/triton/batch_size=2-bias=True-headdim=64-num_heads=8-blocksize_q=64-blocksize_k=32-mode=fwd.png differ diff --git 
a/benchmarks/mha/triton/batch_size=2-bias=True-headdim=64-num_heads=8-blocksize_q=64-blocksize_k=64-mode=fwd.csv b/benchmarks/mha/triton/batch_size=2-bias=True-headdim=64-num_heads=8-blocksize_q=64-blocksize_k=64-mode=fwd.csv new file mode 100644 index 0000000..b5cf294 --- /dev/null +++ b/benchmarks/mha/triton/batch_size=2-bias=True-headdim=64-num_heads=8-blocksize_q=64-blocksize_k=64-mode=fwd.csv @@ -0,0 +1,6 @@ +seqlen,Triton-BlockPtr,Triton-PtrBlock,Jax +1024.000000,0.562165,0.536631,0.626351 +2048.000000,1.695971,1.699328,2.394719 +4096.000000,5.352694,5.439488,9.222279 +6144.000000,11.140933,10.992724,19.588991 +8192.000000,15.083666,15.595812,34.506432 diff --git a/benchmarks/mha/triton/batch_size=2-bias=True-headdim=64-num_heads=8-blocksize_q=64-blocksize_k=64-mode=fwd.png b/benchmarks/mha/triton/batch_size=2-bias=True-headdim=64-num_heads=8-blocksize_q=64-blocksize_k=64-mode=fwd.png new file mode 100644 index 0000000..7358d3b Binary files /dev/null and b/benchmarks/mha/triton/batch_size=2-bias=True-headdim=64-num_heads=8-blocksize_q=64-blocksize_k=64-mode=fwd.png differ diff --git a/benchmarks/mha/triton/results.html b/benchmarks/mha/triton/results.html new file mode 100644 index 0000000..1125ab5 --- /dev/null +++ b/benchmarks/mha/triton/results.html @@ -0,0 +1,154 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/jax_flash_attn2/flash_attention.py b/jax_flash_attn2/flash_attention.py index 07c39c6..ccd0db7 100644 --- a/jax_flash_attn2/flash_attention.py +++ b/jax_flash_attn2/flash_attention.py @@ -360,11 +360,11 @@ def _attn_refrence(query_states, key_states, value_states, bias): @lru_cache def get_cached_flash_attention( - backend: AVAILABLE_BACKENDS, - platform: 
AVAILABLE_FLASH_ATTENTION2_PLATFORMS,
-    blocksize_q: int,
-    blocksize_k: int,
-    softmax_scale: Optional[float],
+    backend: AVAILABLE_BACKENDS = None,
+    platform: AVAILABLE_FLASH_ATTENTION2_PLATFORMS = None,
+    blocksize_q: int = 128,
+    blocksize_k: int = 128,
+    softmax_scale: Optional[float] = None,
 ):
     return create_flash_attention(
         backend=backend,
diff --git a/pyproject.toml b/pyproject.toml
index 2f4daba..404c2c4 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -38,7 +38,7 @@ unfixable = ["B"]
 [tool.ruff.lint.per-file-ignores]
 "__init__.py" = ["E402", "F401"]
 "**/{tests,docs,tools}/*" = ["E402"]
-"python_test/*" = ["E402"]
+"tests/*" = ["E402", "E731"]
 "triton_*" = ["E741", "ISC001", "E501", "E731"]
 "pallas_*" = ["E741", "ISC001", "E501", "E731"]
 
diff --git a/tests/benchmark_mha_triton.py b/tests/benchmark_mha_triton.py
new file mode 100644
index 0000000..930619d
--- /dev/null
+++ b/tests/benchmark_mha_triton.py
@@ -0,0 +1,138 @@
+import itertools
+import os
+
+import jax
+import jaxlib
+import triton
+from jax import nn
+from jax import numpy as jnp
+from jax import random as jrnd
+
+from jax_flash_attn2 import get_cached_flash_attention
+
+# Sentinel timings (ms) recorded in place of a measurement when a config fails.
+RUNTIME_FAILURE_MS = 100.0  # the timed call raised XlaRuntimeError
+SETUP_FAILURE_MS = 300.0  # anything else failed (input creation, tracing, ...)
+
+# FLASH_ATTN_BLOCK_PTR value exported before building each Triton provider.
+BLOCK_PTR_ENV = {"triton-block-ptr": "1", "triton-ptr-block": "0"}
+
+# One plot/CSV per combination; the first factor varies slowest, exactly like
+# the nested loops this comprehension replaces.
+benchmark_configs = [
+    triton.testing.Benchmark(
+        x_names=["seqlen"],
+        x_vals=[1024, 2048, 4096, 6144, 8192],
+        line_arg="provider",
+        line_vals=["triton-block-ptr", "triton-ptr-block", "jax"],
+        line_names=["Triton-BlockPtr", "Triton-PtrBlock", "Jax"],
+        styles=[("green", "-"), ("blue", "-."), ("blue", ":")],
+        ylabel="MS",
+        plot_name=f"batch_size={batch_size}-bias={bias}-headdim={headdim}-num_heads={num_heads}-blocksize_q={blocksize_q}-blocksize_k={blocksize_k}-mode={mode}",
+        args={
+            "BATCH": batch_size,
+            "H": num_heads,
+            "HEAD_DIM": headdim,
+            "mode": mode,
+            "BIAS": bias,
+            "blocksize_k": blocksize_k,
+            "blocksize_q": blocksize_q,
+        },
+    )
+    for (
+        mode,
+        batch_size,
+        bias,
+        headdim,
+        num_heads,
+        blocksize_q,
+        blocksize_k,
+    ) in itertools.product(
+        ["bwd", "fwd"],
+        [1, 2, 4],
+        [True, False],
+        [64, 128, 256],
+        [8, 16, 32],
+        [32, 64, 128],
+        [32, 64, 128],
+    )
+]
+
+
+def _attention_fn(provider, blocksize_q, blocksize_k):
+    """Return the forward attention callable benchmarked for *provider*."""
+    if provider == "jax":
+        return jax.jit(nn.dot_product_attention)
+    # NOTE(review): get_cached_flash_attention is lru_cache'd on its arguments
+    # only, so if FLASH_ATTN_BLOCK_PTR is read while the callable is built, both
+    # Triton providers share one cached kernel -- confirm where it is read.
+    os.environ["FLASH_ATTN_BLOCK_PTR"] = BLOCK_PTR_ENV[provider]
+    return get_cached_flash_attention(
+        blocksize_k=blocksize_k,
+        blocksize_q=blocksize_q,
+    )
+
+
+@triton.testing.perf_report(benchmark_configs)
+def mha_attention_benchmark(
+    seqlen,
+    H,
+    BATCH,
+    HEAD_DIM,
+    mode,
+    BIAS,
+    blocksize_k,
+    blocksize_q,
+    provider,
+):
+    """Time one attention call in milliseconds for a single configuration.
+
+    Returns the ``triton.testing.do_bench`` result, ``RUNTIME_FAILURE_MS`` when
+    the timed call raises ``XlaRuntimeError``, or ``SETUP_FAILURE_MS`` when
+    anything else fails.
+    """
+    try:
+        q_key, k_key, v_key = jrnd.split(jrnd.PRNGKey(8), 3)
+        shape = (BATCH, seqlen, H, HEAD_DIM)
+        init = jax.nn.initializers.normal(2)
+        query = init(q_key, shape, dtype=jnp.float16)
+        key = init(k_key, shape, dtype=jnp.float16)
+        value = init(v_key, shape, dtype=jnp.float16)
+        bias = (
+            jnp.where(
+                jrnd.randint(v_key, (BATCH, 1, seqlen, seqlen), 0, 4) > 2,
+                jnp.finfo(jnp.float16).min,
+                0,
+            )
+            if BIAS
+            else None
+        )
+        attn = _attention_fn(provider, blocksize_q, blocksize_k)
+        if mode == "fwd":
+            step = attn
+        elif mode == "bwd":
+            # Gradient w.r.t. the query only (jax.grad's default argnums=0).
+            step = jax.grad(lambda *x: attn(*x).sum())
+        else:
+            raise ValueError(f"unknown mode: {mode!r}")
+
+        def fn():
+            # JAX dispatches asynchronously; block so every provider is timed to
+            # completion (the Triton paths previously returned without waiting).
+            return jax.block_until_ready(step(query, key, value, bias))
+
+        try:
+            return triton.testing.do_bench(fn)
+        except jaxlib.xla_extension.XlaRuntimeError:
+            return RUNTIME_FAILURE_MS
+    except Exception:
+        # Best-effort sweep: record the failure and move on, but let
+        # KeyboardInterrupt/SystemExit abort the run (a bare except did not).
+        return SETUP_FAILURE_MS
+
+
+if __name__ == "__main__":
+    mha_attention_benchmark.run(
+        print_data=True,
+        save_path="jax-flash-attn2/benchmarks/mha/triton",
+    )