
How to implement Bidirectional Alibi with padding using flex attention? #74

Open

sphmel opened this issue Nov 7, 2024 · 4 comments


sphmel commented Nov 7, 2024

Hi, I want to use FlexAttention for ALiBi with padding (no bias on the padding position).

If seq_len is 5, I want to build the ALiBi tensor below, where the last position is not penalized:

0 -1 -2 -3 0
-1 0 -1 -2 0
-2 -1 0 -1 0
-3 -2 -1 0 0
0 0 0 0 0

How can I implement a score mod like this? seq_len can be different on every forward pass. This kind of ALiBi is used in the Voicebox paper. I'm new to the BatchedTensor/vmap API and don't know how to implement this at all. Can you help me?


sphmel commented Nov 7, 2024

q_idx - kv_idx will make the tensor below, but I want the last row and column to not be biased:

0 -1 -2 -3 -4
-1 0 -1 -2 -3
-2 -1 0 -1 -2
-3 -2 -1 0 -1
-4 -3 -2 -1 0


Chillee commented Nov 11, 2024

@sphmel I think this should work.

def score_mod(score, b, h, q_idx, kv_idx):
    # Relative-position bias (positive below the diagonal, assuming the
    # lower-triangular part was meant to be positive; see the later comments).
    bias = q_idx - kv_idx
    # Zero the bias whenever the query or key is the last, unpenalized position.
    bias = torch.where((q_idx == seq_len - 1) | (kv_idx == seq_len - 1), 0, bias)
    return score + bias
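
For reference, here's a minimal end-to-end sketch of how this score mod could be used (the shapes here are made up, and seq_len is re-derived from the query each forward pass since it can vary; flex_attention is exported from torch.nn.attention.flex_attention):

import torch
from torch.nn.attention.flex_attention import flex_attention

B, H, S, D = 2, 4, 5, 64  # hypothetical batch, heads, sequence length, head dim

q = torch.randn(B, H, S, D)
k = torch.randn(B, H, S, D)
v = torch.randn(B, H, S, D)

seq_len = q.shape[-2]  # captured by the closure; recompute per forward pass

def score_mod(score, b, h, q_idx, kv_idx):
    bias = q_idx - kv_idx
    bias = torch.where((q_idx == seq_len - 1) | (kv_idx == seq_len - 1), 0, bias)
    return score + bias

# Runs in eager mode; for speed you'd typically wrap flex_attention
# with torch.compile.
out = flex_attention(q, k, v, score_mod=score_mod)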

@chinsengi

> q_idx - kv_idx will make the tensor below, but I want the last row and column to not be biased:
>
> 0 -1 -2 -3 -4
> -1 0 -1 -2 -3
> -2 -1 0 -1 -2
> -3 -2 -1 0 -1
> -4 -3 -2 -1 0

I don't understand, shouldn't q_idx - kv_idx give you a skew-symmetric matrix rather than a symmetric one?


Chillee commented Nov 25, 2024

@chinsengi I assumed it was just a typo and they meant for the lower-triangular part to be positive.
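
If the symmetric matrix from the original post was actually what was intended (a bias of -|q_idx - kv_idx|, with the last row and column left unbiased), a variant of the score mod under that assumption would be:

def score_mod_symmetric(score, b, h, q_idx, kv_idx):
    # Negative absolute distance, matching the symmetric matrix in the question.
    bias = -(q_idx - kv_idx).abs()
    # Leave the last (padding) position unbiased.
    bias = torch.where((q_idx == seq_len - 1) | (kv_idx == seq_len - 1), 0, bias)
    return score + bias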
