## 📊 Benchmarks

- [Triton GPU MHA](https://github.com/erfanzar/jax-flash-attn2/tree/main/benchmarks/mha/triton)
- Triton GPU MQA (coming soon...)
- [Triton GPU (MHA/GQA) vs JAX SDPA CUDNN](https://github.com/erfanzar/jax-flash-attn2/tree/main/benchmarks/triton-vs-jax-sdpa-cudnn)
- [Triton GPU (MHA/GQA) vs JAX SDPA](https://github.com/erfanzar/jax-flash-attn2/tree/main/benchmarks/triton-vs-jax-sdpa)
- Pallas GPU MHA (coming soon...)
- Pallas TPU MHA (coming soon...)
- XLA CPU MHA (coming soon...)

## Supported Configurations

### Backends
- `gpu`: CUDA/AMD-capable GPUs
- `tpu`: Google Cloud TPUs
- `cpu`: CPU fallback

### Platforms
- `triton`: Optimized for NVIDIA/AMD GPUs
- `pallas`: Optimized for TPUs and supported on GPUs
- `jax`: Universal fallback, supports all backends

### Valid Backend-Platform Combinations

| Backend | Supported Platforms |
| ---------------- | ------------------- |
| GPU - AMD/NVIDIA | Triton, JAX |
| GPU - NVIDIA | Triton, Pallas, JAX |
| TPU | Pallas, JAX |
| CPU | JAX |
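
As a quick illustration of the combinations above, the sketch below requests one attention instance per device type. It assumes the `jax_flash_attn2` import path, that the remaining parameters have sensible defaults, and that an unsupported pairing is rejected rather than silently remapped.

```python
# Sketch: one valid backend/platform pair per row of the table above.
# Import path and default parameters are assumptions, not confirmed API details.
from jax_flash_attn2 import get_cached_flash_attention

gpu_attention = get_cached_flash_attention(backend="gpu", platform="triton")  # NVIDIA/AMD GPUs
tpu_attention = get_cached_flash_attention(backend="tpu", platform="pallas")  # TPUs
cpu_attention = get_cached_flash_attention(backend="cpu", platform="jax")     # CPU fallback

# A pairing outside the table (e.g. backend="cpu", platform="triton") is
# expected to raise an error rather than fall back silently.
```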

## Advanced Configuration

```python
attention = get_cached_flash_attention(
    backend="gpu",
    platform="triton",
    blocksize_q=128,    # Customize query block size (ignored for the Triton platform)
    blocksize_k=128,    # Customize key block size (ignored for the Triton platform)
    softmax_scale=1.0,  # Custom softmax scaling
)
```
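
Once built, the returned object is used as the attention callable. The call below is only a sketch: the positional `(q, k, v)` signature, the `(batch, seq_len, num_heads, head_dim)` layout, and the fp16 dtype are assumptions rather than documented guarantees.

```python
# Sketch: calling the cached attention function on dummy tensors.
# Signature and tensor layout are assumed, not taken from the library docs.
import jax
import jax.numpy as jnp

batch, seq_len, num_heads, head_dim = 2, 1024, 8, 64
rng = jax.random.PRNGKey(0)
q = jax.random.normal(rng, (batch, seq_len, num_heads, head_dim), dtype=jnp.float16)
k = jax.random.normal(rng, (batch, seq_len, num_heads, head_dim), dtype=jnp.float16)
v = jax.random.normal(rng, (batch, seq_len, num_heads, head_dim), dtype=jnp.float16)

outputs = attention(q, k, v)  # `attention` is the instance built above
print(outputs.shape)          # expected to match the query shape
```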

### Environment Variables

- `FORCE_MHA`: Set to "true", "1", or "on" to force the MHA implementation even for GQA cases.
- `FLASH_ATTN_BLOCK_PTR`: Set to "1" to use `tl.make_block_ptr` for pointer access in the forward pass (better for H100/H200 GPUs).
- `GPU_IDX_FLASH_ATTN`: Force the GPU index used for Triton attention computation.
- `CACHE_TRITON_KERNELS`: Whether to cache Triton kernels (default: true).
- `_JAX_TRITON_DUMP_DIR`: Path where Triton kernels are dumped.
- `BLOCKSIZE_M_FLASH_ATTN`: Block size along the query sequence length for the backward pass.
- `BLOCKSIZE_N_FLASH_ATTN`: Block size along the key/value sequence length for the backward pass.

## Performance Tips

1. **Platform Selection**:
   - For NVIDIA/AMD GPUs: prefer `triton`
   - For TPUs: prefer `pallas`
   - For CPU or fallback: use `jax`

2. **Caching**: The `get_cached_flash_attention` function automatically caches instances based on parameters. No need to manage caching manually.
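
A small sketch of what the caching behavior implies in practice; the identity check assumes the cache keys on the exact parameter values and returns the same instance:

```python
# Sketch: identical parameters should hit the cache (identity semantics assumed).
from jax_flash_attn2 import get_cached_flash_attention  # import path assumed

attn_a = get_cached_flash_attention(backend="gpu", platform="triton", softmax_scale=1.0)
attn_b = get_cached_flash_attention(backend="gpu", platform="triton", softmax_scale=1.0)

assert attn_a is attn_b  # same cached instance, no manual cache management needed
```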

## Requirements

## Citation

If you use this implementation in your research, please cite:

```bibtex
url = {https://github.com/erfanzar/jax-flash-attn2}
}
```
### Reference Citations

```bibtex
@inproceedings{dao2022flashattention,
  title={Flash{A}ttention: Fast and Memory-Efficient Exact Attention with {IO}-Awareness},
  author={Dao, Tri and Fu, Daniel Y. and Ermon, Stefano and Rudra, Atri and R{\'e}, Christopher},
  booktitle={Advances in Neural Information Processing Systems (NeurIPS)},
  year={2022}
}

@inproceedings{dao2023flashattention2,
  title={Flash{A}ttention-2: Faster Attention with Better Parallelism and Work Partitioning},
  author={Dao, Tri},
  booktitle={International Conference on Learning Representations (ICLR)},
  year={2024}
}
```

## Acknowledgments and References

1. This implementation (MHA) is based on: