
Sageattention1.0 runs slower than FA2 on A100. #70

Open
foreverpiano opened this issue Dec 13, 2024 · 12 comments

Comments

@foreverpiano

foreverpiano commented Dec 13, 2024

benchmark code:

import torch
import torch.nn as nn
import torch.nn.functional as F
import time
from typing import Optional
from torch.nn.attention import sdpa_kernel, SDPBackend

backends = []
backends.append(SDPBackend.CUDNN_ATTENTION)
backends.append(SDPBackend.EFFICIENT_ATTENTION)
backends.append(SDPBackend.MATH)

# import attention implementation
try:
    from flash_attn import flash_attn_func
    FLASH_ATTN_AVAILABLE = True
except ImportError:
    FLASH_ATTN_AVAILABLE = False
    print("flash_attn not available")

try:
    from sageattention import sageattn
    SAGE_ATTN_AVAILABLE = True
except ImportError:
    SAGE_ATTN_AVAILABLE = False
    print("sage_attn not available")

def run_attention(
    query: torch.Tensor,  # (B, H, N, D)
    key: torch.Tensor,    # (B, H, N, D) 
    value: torch.Tensor,  # (B, H, N, D)
    attention_mode: str = "sdpa",
    device: Optional[torch.device] = None,
    softmax_scale: Optional[float] = None
) -> torch.Tensor:
    """
    Run attention using specified backend.
    
    Args:
        query: Query tensor (batch_size, num_heads, seq_len, head_dim)
        key: Key tensor (batch_size, num_heads, seq_len, head_dim)
        value: Value tensor (batch_size, num_heads, seq_len, head_dim)
        attention_mode: One of ["flash_attn", "sdpa", "sage_attn"]
        device: torch device
        softmax_scale: Scale factor for attention scores
        
    Returns:
        out: Output tensor (batch_size, seq_len, hidden_dim)
    """
    b, num_heads, seq_len, dim_head = query.shape
    
    
    with torch.autocast(device.type if device else "cuda", enabled=False):
        if attention_mode == "flash_attn" and FLASH_ATTN_AVAILABLE:
            # Flash Attention expects (B, N, H, D) format
            query = query.permute(0, 2, 1, 3)
            key = key.permute(0, 2, 1, 3)
            value = value.permute(0, 2, 1, 3)
            
            out = flash_attn_func(
                query, key, value,
                dropout_p=0.0,
                softmax_scale=softmax_scale,
            )
            out = out.permute(0, 2, 1, 3)
            
        elif attention_mode == "sdpa":
            with sdpa_kernel(backends):
                out = F.scaled_dot_product_attention(
                    query, key, value,
                    attn_mask=None,
                    dropout_p=0.0,
                    is_causal=False
                )
                
        elif attention_mode == "sage_attn" and SAGE_ATTN_AVAILABLE:
            out = sageattn(
                query, key, value,
                attn_mask=None,
                dropout_p=0.0,
                is_causal=False
            )
        else:
            raise ValueError(f"Attention mode {attention_mode} not available")
            
        # Reshape output to (B, N, H*D) format
        out = out.transpose(1, 2).reshape(b, seq_len, num_heads * dim_head)
            
        return out

def profile_attention(batch_size=2, seq_len=8192, num_heads=24, head_dim=128, num_rounds=20):
    """Profile different attention implementations"""
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    shape = (batch_size, num_heads, seq_len, head_dim)
    query = torch.randn(shape, device=device, dtype=torch.float16)
    key = torch.randn(shape, device=device, dtype=torch.float16)
    value = torch.randn(shape, device=device, dtype=torch.float16)
    
    # Calculate theoretical memory usage
    element_size = query.element_size()  
    total_elements = 3 * batch_size * num_heads * seq_len * head_dim 
    attention_matrix_size = batch_size * num_heads * seq_len * seq_len 
    memory_qkv = total_elements * element_size / (1024 ** 3) 
    memory_attention = attention_matrix_size * element_size / (1024 ** 3)
    
    print(f"\nTheoretical Memory Usage:")
    print(f"QKV Tensors: {memory_qkv:.2f} GB")
    print(f"Attention Matrix (for non-memory-efficient implementations): {memory_attention:.2f} GB")
    
    # Warmup
    print("\nWarming up...")
    for _ in range(3):
        for mode in ["sdpa", "flash_attn", "sage_attn"]:
            try:
                _ = run_attention(query, key, value, 
                                attention_mode=mode,
                                device=device)
                torch.cuda.synchronize()
            except Exception as e:
                print(f"Warmup {mode} failed: {str(e)}")
                continue
    
    # Profile each implementation
    print(f"\nProfiling with shape {shape}...")
    results = {}
    
    for mode in ["sdpa", "flash_attn", "sage_attn"]:
        try:
            torch.cuda.reset_peak_memory_stats()
            torch.cuda.empty_cache()
            
            times = []
            torch.cuda.synchronize()
            
            for _ in range(num_rounds):
                start = time.perf_counter()
                _ = run_attention(query, key, value,
                                attention_mode=mode,
                                device=device)
                torch.cuda.synchronize()
                times.append(time.perf_counter() - start)
            
            avg_time = sum(times) / len(times)
            std_time = torch.tensor(times).std().item()
            peak_memory = torch.cuda.max_memory_allocated() / (1024 ** 3)  # Convert to GB
            
            results[mode] = {
                "avg_time": avg_time * 1000,  # Convert to ms
                "std_time": std_time * 1000,
                "peak_memory": peak_memory
            }
            
        except Exception as e:
            print(f"{mode} failed: {str(e)}")
            continue
    
    # Print results
    print("\nResults:")
    print("-" * 70)
    print(f"{'Implementation':<15} {'Avg Time (ms)':>15} {'Std Dev (ms)':>15} {'Peak Memory (GB)':>20}")
    print("-" * 70)
    for mode, stats in results.items():
        print(f"{mode:<15} {stats['avg_time']:>15.3f} {stats['std_time']:>15.3f} {stats['peak_memory']:>20.3f}")

if __name__ == "__main__":
    batch_size, seq_len, num_heads, head_dim = 2, 8192, 24, 128
    print(f"\nTesting batch_size={batch_size}, seq_len={seq_len}, "
          f"num_heads={num_heads}, head_dim={head_dim}")
    profile_attention(batch_size, seq_len, num_heads, head_dim)

result:

Results:
----------------------------------------------------------------------
Implementation    Avg Time (ms)    Std Dev (ms)     Peak Memory (GB)
----------------------------------------------------------------------
sdpa                      8.741           0.348                0.469
flash_attn                7.634           0.063                0.469
sage_attn                 8.869           0.082                0.469

version

torch==2.5.0+cu121
sageattention==1.0.6
flash-attn==2.7.0.post2
cuda-12.1
python 3.10.15
A100-SXM-80GB.

cc @jt-zhang @jason-huang03
Is there anything wrong with my code, or anything I could improve?
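
One possible refinement, not in the script above: timing with torch.cuda.Event instead of time.perf_counter excludes host-side launch overhead from the measurement. A minimal sketch (the helper name time_cuda_ms is illustrative):

import torch

def time_cuda_ms(fn, num_rounds=20, warmup=3):
    """Average per-call time in ms, measured with CUDA events."""
    for _ in range(warmup):
        fn()
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(num_rounds):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / num_rounds

# e.g.:
# time_cuda_ms(lambda: run_attention(query, key, value, attention_mode="sage_attn", device=device))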

@jt-zhang
Copy link
Member

Thank you for reaching out. You need to use SageAttention2.0 on A100 GPUs.

@jt-zhang
Member

Also, SageAttention2.0 requires the CUDA version to be >= 12.4.
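
For reference, one quick sanity check of the CUDA toolkit your PyTorch build sees (the nvcc version used to build SageAttention may differ, so check that separately):

import torch
print(torch.version.cuda)                  # CUDA version PyTorch was built with, e.g. "12.4"
print(torch.cuda.get_device_capability())  # (8, 0) on A100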

@foreverpiano
Author

Got it.

@nighting0le01

@jt-zhang I still don't see a significant speedup with CUDA 12.4 on A100.

@jt-zhang
Member

jt-zhang commented Dec 15, 2024

@nighting0le01 Could you please elaborate on the details, such as the code, the version of sageattention, and the shape of q, k, v?

@jt-zhang jt-zhang changed the title Sageattention runs much slower than FA2 on A100. Sageattention runs slower than FA2 on A100. Dec 15, 2024
@jt-zhang jt-zhang changed the title Sageattention runs slower than FA2 on A100. Sageattention1.0 runs slower than FA2 on A100. Dec 16, 2024
@SJTU-yys

SJTU-yys commented Dec 21, 2024

@nighting0le01 Could you please elaborate on the details, such as the code, the version of sageattention, and the shape of q, k, v?

Hi, I am trying to reproduce the speed improvements of sageattention2 on A100.
I followed the README, installed sageattention2 in my test workspace, and ran the profiling script given by @foreverpiano, but the speed of sage attention is still not as fast as flash attention.
My environment:

  • GPU A100-SXM-80GB
  • cuda 12.4
  • python 3.11
  • torch 2.3.1
  • triton 3.0.0
  • flash-attn 2.7.0.post2

Details of the input:
batch_size, seq_len, num_heads, head_dim = 8, 1024, 28, 128

And the result is:
[screenshot: benchmark results]
@jt-zhang keen to hear your feedback.

@jason-huang03
Member

@SJTU-yys the sequence length in your test case is 1024, which is quite small. Can you try longer sequence lengths like 8k or 16k?
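
For example, a sequence-length sweep reusing the profile_attention helper from the script above might look like this (illustrative only; batch/head settings follow the post above):

# Sweep sequence lengths with the same batch, heads, and head_dim as above.
for seq_len in (1024, 4096, 8192, 16384):
    print(f"\n=== seq_len={seq_len} ===")
    profile_attention(batch_size=8, seq_len=seq_len, num_heads=28, head_dim=128)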

@SJTU-yys

@SJTU-yys the sequence length in your test case is 1024, which is quite small. Can you try longer sequence lengths like 8k or 16k?

I tried both 8192 and 16384 seq_len and see some improvements now.
Below is the result for 8192:
[screenshot: benchmark results for seq_len=8192]

Below is the result for 16384:
[screenshot: benchmark results for seq_len=16384]

Do these results meet your expectations?

@jason-huang03
Member

@SJTU-yys can you measure the result in TFlops?
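
For reference, forward-pass FLOPs for non-causal attention are commonly estimated as 4·B·H·N²·D (2·B·H·N²·D each for Q·Kᵀ and P·V), so a rough TFLOPS figure can be derived from the timings above. A minimal sketch (the helper name attention_tflops is illustrative):

def attention_tflops(batch, heads, seq_len, head_dim, time_ms):
    # 2*N^2*D FLOPs for Q @ K^T plus 2*N^2*D for P @ V, per batch and head
    flops = 4 * batch * heads * seq_len * seq_len * head_dim
    return flops / (time_ms / 1e3) / 1e12

# e.g. the first post's flash_attn timing (batch 2, 24 heads, seq 8192, dim 128, ~7.6 ms)
print(attention_tflops(2, 24, 8192, 128, 7.634))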

@foreverpiano
Author

@jason-huang03 @SJTU-yys I'm curious about Sage performance under large batch sizes and variable-length (varlen) conditions. Do you have any relevant benchmarks for this?

@lauthu

lauthu commented Jan 13, 2025

@SJTU-yys the sequence length in your test case is 1024, which is quite small. Can you try longer sequence lengths like 8k or 16k?

Hi @jason-huang03, I was wondering if Sage Attention 2 is consistently faster than Flash Attention 2. We have a scenario involving a 0.5B model running on an A100 MIG instance, with a relatively small sequence length and batch size (approximately batch size 8 and sequence length 384). Would it be possible to use Sage Attention to achieve better performance in this case?

Thank you for your assistance!

@jason-huang03
Member

jason-huang03 commented Jan 13, 2025

@SJTU-yys the sequence length in your test case is 1024, which is quite small. Can you try longer sequence lengths like 8k or 16k?

Hi @jason-huang03, I was wondering if Sage Attention 2 is consistently faster than Flash Attention 2. We have a scenario involving a 0.5B model running on an A100 MIG instance, with a relatively small sequence length and batch size (approximately batch size 8 and sequence length 384). Would it be possible to use Sage Attention to achieve better performance in this case?

Thank you for your assistance!

@lauthu
I don't think it will give a measurable benefit. In your use case, the sequence length is so small that the latency is dominated by the linear layers and other parts of the model.
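
To make that concrete, here is a rough per-layer FLOP comparison for the scenario above (batch 8, seq 384); the 0.5B-scale dimensions (hidden size 896, 14 heads of dim 64, a 4x MLP) are assumptions for illustration only:

batch, seq = 8, 384
hidden, heads, head_dim = 896, 14, 64   # assumed 0.5B-scale dims
attn_flops = 4 * batch * heads * seq * seq * head_dim    # Q @ K^T + P @ V
proj_flops = 2 * batch * seq * 4 * hidden * hidden       # Q/K/V/O projections
mlp_flops = 2 * batch * seq * 2 * hidden * (4 * hidden)  # assumed 4x MLP, up + down
print(attn_flops / (proj_flops + mlp_flops))  # attention is a small fraction of the layer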
