flexattn with qwen2 #81

Open
NonvolatileMemory opened this issue Nov 18, 2024 · 4 comments

Comments

@NonvolatileMemory

It seems FlexAttention can't handle num_heads=28?

@drisspg
Contributor

drisspg commented Nov 18, 2024

Do you have a repro? I just tried this and it appears to work for me. Note that I'm on a nightly version of PyTorch:

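```python
import torch

from torch.nn.attention.flex_attention import flex_attention, create_block_mask


# Causal mask_mod: each query attends to itself and all earlier keys
def causal_mask(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx


# 28 heads, the head count in question
b, h, s, d = 1, 28, 256, 64
tens = torch.rand(b, h, s, d, device="cuda")

flex = torch.compile(flex_attention)

# B=None, H=None: the same mask is broadcast over batch and heads
bm = create_block_mask(causal_mask, None, None, s, s)

print(flex(tens, tens, tens, block_mask=bm))
```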
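For intuition about what `create_block_mask` builds here, the following is a simplified pure-Python model, not PyTorch's implementation. It assumes the default 128×128 tile size and labels each tile of the `causal_mask` pattern as full, partial, or empty. FlexAttention can skip empty tiles entirely and only needs to evaluate the mask_mod inside partial ones. Because the repro passes `B=None, H=None`, a single tile layout is broadcast across all 28 heads. `classify_tiles` and `BLOCK` are hypothetical names.

```python
# Simplified model (NOT PyTorch's implementation) of how a mask_mod is summarized
# into tiles; assumes the default 128x128 block size of create_block_mask.
BLOCK = 128


def causal_mask(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx


def classify_tiles(mask_mod, q_len, kv_len, block=BLOCK):
    """Label each (q_block, kv_block) tile as 'full', 'partial', or 'empty'."""
    n_q = -(-q_len // block)   # ceil division: last tile may be ragged
    n_kv = -(-kv_len // block)
    tiles = {}
    for qb in range(n_q):
        for kb in range(n_kv):
            vals = [
                mask_mod(0, 0, q, k)
                for q in range(qb * block, min((qb + 1) * block, q_len))
                for k in range(kb * block, min((kb + 1) * block, kv_len))
            ]
            if all(vals):
                tiles[(qb, kb)] = "full"      # mask_mod not needed inside the tile
            elif any(vals):
                tiles[(qb, kb)] = "partial"   # mask_mod applied element-wise
            else:
                tiles[(qb, kb)] = "empty"     # tile can be skipped entirely
    return tiles


print(classify_tiles(causal_mask, 256, 256))
```

For `s = 256` this yields two partial tiles on the diagonal, one full tile below it, and one empty tile above it.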

@NonvolatileMemory
Author

Hi!

Here is my code:

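```python
import torch
import torch.nn.functional as F
from torch.nn.attention.flex_attention import flex_attention, create_block_mask

# Assumed: `flex_attn` (undefined in the original snippet) is the compiled flex_attention
flex_attn = torch.compile(flex_attention)


def block_mask(b, h, q_idx, kv_idx):
    # Block-causal mask_mod: queries attend only to strictly earlier 4-token blocks
    q_block = q_idx // 4
    kv_block = kv_idx // 4
    return q_block > kv_block


def diff(bsz=4, seq_len=1024, d_head=128, num_heads=28, block_size=4):
    # Qwen2-style GQA: num_heads query heads share 4 key/value heads
    Q = torch.randn(bsz, num_heads, seq_len, d_head, device="cuda")
    K = torch.randn(bsz, 4, seq_len, d_head, device="cuda")
    V = torch.randn(bsz, 4, seq_len, d_head, device="cuda")

    # torch_attn (dense reference): expand K/V to num_heads first; a plain matmul
    # of 28 query heads against 4 KV heads fails with a shape mismatch
    group = num_heads // K.size(1)
    K_rep = K.repeat_interleave(group, dim=1)
    V_rep = V.repeat_interleave(group, dim=1)
    scores = torch.matmul(Q, K_rep.transpose(-2, -1)) / (Q.size(-1) ** 0.5)

    # Dense equivalent of block_mask (stands in for the undefined torch_mask helper)
    q_idx = torch.arange(seq_len, device="cuda").view(-1, 1)
    kv_idx = torch.arange(seq_len, device="cuda").view(1, -1)
    mask = ((q_idx // block_size) > (kv_idx // block_size))[None, None, :, :]

    # scores = scores.masked_fill(~mask, float('-inf'))
    # attn_weights = F.softmax(scores, dim=-1)
    # torch_out = torch.matmul(attn_weights, V_rep)
    sub_block_mask = create_block_mask(
        block_mask, B=None, H=None, Q_LEN=seq_len, KV_LEN=seq_len, _compile=True
    )
    flex_out = flex_attn(Q, K, V, block_mask=sub_block_mask, enable_gqa=True)
    return flex_out
    # return (flex_out[:, :, 16:] - torch_out[:, :, 16:]).max()
```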
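Two properties of this repro can be checked without a GPU. The pure-Python sketch below uses hypothetical helper names (`kv_head_for`, `fully_masked_rows`). It assumes GQA groups query heads the way `repeat_interleave` does, so with 28 query heads and 4 KV heads each KV head serves 7 consecutive query heads. GQA also requires the query head count to be a multiple of the KV head count. The sketch additionally lists the query positions for which `block_mask` allows no key at all.

```python
# Pure-Python sketch (no torch) of two properties of the repro above.
NUM_Q_HEADS, NUM_KV_HEADS = 28, 4  # Qwen2-style GQA head counts from the repro


def kv_head_for(q_head, hq=NUM_Q_HEADS, hkv=NUM_KV_HEADS):
    """Map a query head to its shared KV head (assumes repeat_interleave grouping)."""
    if hq % hkv != 0:
        raise ValueError("GQA needs num_q_heads to be a multiple of num_kv_heads")
    return q_head // (hq // hkv)


def block_mask(b, h, q_idx, kv_idx):
    # Same mask_mod as the repro: only strictly earlier 4-token blocks are visible
    return q_idx // 4 > kv_idx // 4


def fully_masked_rows(mask_mod, seq_len):
    """Query positions that are allowed to attend to no key at all."""
    return [
        q for q in range(seq_len)
        if not any(mask_mod(0, 0, q, k) for k in range(seq_len))
    ]


print([kv_head_for(h) for h in range(NUM_Q_HEADS)])
print(fully_masked_rows(block_mask, 1024))
```

Here positions 0 to 3 (the first 4-token block) are fully masked. A dense softmax over an all `-inf` row gives NaN, while FlexAttention is expected to return zeros for such rows. This is a plausible reason the commented-out comparison skips the leading positions.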

@NonvolatileMemory
Author

> Notably, I'm on Nightly version of pytorch

Maybe it's because I'm using torch 2.5.0 instead of nightly?

@drisspg
Contributor

drisspg commented Nov 20, 2024

Yeah, potentially. Would you mind trying nightly?
