diff --git a/flash_attn/flash_attn_interface.py b/flash_attn/flash_attn_interface.py index 4f55f0c57..6241d7bf0 100644 --- a/flash_attn/flash_attn_interface.py +++ b/flash_attn/flash_attn_interface.py @@ -60,9 +60,15 @@ def _flash_attn_varlen_forward(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q def _flash_attn_backward(dout, q, k, v, out, softmax_lse, dq, dk, dv, dropout_p, softmax_scale, causal, rng_state=None): - maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x + maybe_contiguous = lambda x: x.contiguous() if not x.is_contiguous() else x # dq, dk, dv are allocated by us so they should already be contiguous dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)] + + if out.stride() != dout.stride(): + out = out.as_strided(dout.size(),dout.stride()) + if dq.stride() != q.stride(): + dq = dq.as_strided(q.size(),q.stride()) + dq, dk, dv, softmax_d, = flash_attn_cuda.bwd( dout, q, k, v, out, softmax_lse, dq, dk, dv, dropout_p, softmax_scale, causal, None, rng_state @@ -73,7 +79,7 @@ def _flash_attn_backward(dout, q, k, v, out, softmax_lse, dq, dk, dv, def _flash_attn_varlen_backward(dout, q, k, v, out, softmax_lse, dq, dk, dv, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale, causal, rng_state=None): - maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x + maybe_contiguous = lambda x: x.contiguous() if not x.is_contiguous() else x # dq, dk, dv are allocated by us so they should already be contiguous dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)] dq, dk, dv, softmax_d, = flash_attn_cuda.varlen_bwd(