Conversation

@AntonOresten
Contributor

We encountered NaNs when kpad_mask had a lot of zeros. The solution in this PR was somewhat vibe-coded, so may or may not be the simplest solution, but it seems to work.

@AntonOresten
Contributor Author

AntonOresten commented Nov 8, 2025

Not sure about this one, tbh. Preferably, there'd be a sequence-lengths vector of length B passed instead, so we could skip trailing tiles entirely. That's fairly straightforward for keys; not sure about queries.
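To make the idea concrete, here's a minimal sketch of what I mean (names like `key_lens` are hypothetical, not the NNop API): derive per-batch key lengths from `kpad_mask` so the kernel could skip fully padded trailing key tiles.

```julia
# Hypothetical sketch, assuming kpad_mask is an L×B Bool matrix with
# `true` marking valid key positions and padding only at the tail.
kpad_mask = [trues(32); falses(32);;]        # L = 64, B = 1
key_lens = vec(sum(Int, kpad_mask; dims=1))  # per-batch valid key count: [32]

# Number of key tiles a kernel would actually have to visit per batch element:
groupsize = 32
tiles = cld.(key_lens, groupsize)            # [1] instead of cld(64, 32) = 2
```

A fully masked tile then never gets touched, so the masked-softmax path never has to handle an all-`-Inf` row for trailing padding.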

@AntonOresten AntonOresten marked this pull request as draft November 8, 2025 11:04
@AntonOresten
Contributor Author

Just for the record: we saw NaNs again, specifically on a branch without these changes.

@pxl-th
Member

pxl-th commented Dec 4, 2025

Can you share an MWE?

@AntonOresten
Contributor Author

AntonOresten commented Dec 4, 2025

julia> @eval NNop begin
       function flash_attention_groupsize(::Type{T}; emb_dim::Int, target_shmem::UInt64) where T
           # TODO
           # - return `qk_fp16` to configure kernel
           # - optional qk_fp16
           # qk_fp16s = (false, true)
           # TODO prefer bigger groupsize?
           qk_fp16s = (true,)
           for qk_fp16 in qk_fp16s, groupsize in (256, 128, 64, 32, 16)
               shmem = flash_attention_shmem_bwd(T; emb_dim, groupsize, qk_fp16)
           shmem ≤ target_shmem && begin @show groupsize; return groupsize end
           end
           error("Failed to find groupsize for Flash Attention that satisfies Shared Memory constraint.")
       end
       end
flash_attention_groupsize (generic function with 1 method)
julia> begin
           H, L = 64, 64
           pad = 32
           x = CUDA.rand(H, L, 1, 1);
           kpad_mask = CuArray([trues(L-pad); falses(pad);;]);
           any(isnan, NNop.flash_attention(x, x, x; causal=false, kpad_mask))
       end
groupsize = 32
true

julia> begin
           H, L = 64, 64
           pad = 31
           x = CUDA.rand(H, L, 1, 1);
           kpad_mask = CuArray([trues(L-pad); falses(pad);;]);
           any(isnan, NNop.flash_attention(x, x, x; causal=false, kpad_mask))
       end
groupsize = 32
false

NaNs also appear when L is not a multiple of groupsize: with L = 65, both pad = 0 and pad = 1 give NaNs.
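My working theory for the failure mode (an assumption on my part, not confirmed by tracing the kernel): when every key position covered by a tile is masked out, the row max of the scores is `-Inf`, and the online-softmax rescaling computes `exp(-Inf - (-Inf)) = exp(NaN)`, which propagates. A plain-Julia sketch of that arithmetic:

```julia
# Sketch of the suspected NaN source: a key tile whose positions are all
# masked to -Inf before the softmax.
scores = fill(-Inf32, 4)   # fully masked key tile
m = maximum(scores)        # -Inf
p = exp.(scores .- m)      # -Inf - (-Inf) = NaN, so exp gives NaN
all(isnan, p)              # true
```

That would also explain the L = 65 case: the last tile covers positions 65..groupsize-boundary, and with pad = 0 or 1 its in-bounds portion is tiny, so out-of-bounds masking can leave the tile effectively all `-Inf`.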
