Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix stride issues in flash_attn_interface #58

Open
wants to merge 2 commits into
base: flash_attention_for_rocm
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions flash_attn/flash_attn_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,15 @@ def _flash_attn_varlen_forward(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q

def _flash_attn_backward(dout, q, k, v, out, softmax_lse, dq, dk, dv,
dropout_p, softmax_scale, causal, rng_state=None):
maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
maybe_contiguous = lambda x: x.contiguous() if not x.is_contiguous() else x
# dq, dk, dv are allocated by us so they should already be contiguous
dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]

if out.stride() != dout.stride():
out = out.as_strided(dout.size(),dout.stride())
if dq.stride() != q.stride():
dq = dq.as_strided(q.size(),q.stride())

dq, dk, dv, softmax_d, = flash_attn_cuda.bwd(
dout, q, k, v, out, softmax_lse, dq, dk, dv, dropout_p,
softmax_scale, causal, None, rng_state
Expand All @@ -73,7 +79,7 @@ def _flash_attn_backward(dout, q, k, v, out, softmax_lse, dq, dk, dv,
def _flash_attn_varlen_backward(dout, q, k, v, out, softmax_lse, dq, dk, dv,
cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
dropout_p, softmax_scale, causal, rng_state=None):
maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
maybe_contiguous = lambda x: x.contiguous() if not x.is_contiguous() else x
# dq, dk, dv are allocated by us so they should already be contiguous
dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
dq, dk, dv, softmax_d, = flash_attn_cuda.varlen_bwd(
Expand Down