Hi! Thanks for making this.

I understand this package is still experimental, but I wanted to document this constraint (a head dim of 64 works, while 128 fails to find a groupsize that fits in shared memory), and also to ask whether anyone knows if supporting e.g. a head dim of 128 would be difficult to implement.
```julia
julia> x = CUDA.rand(64, 4096, 4, 4); Jjama3.NNop.flash_attention(x, x, x; causal=true);

julia> x = CUDA.rand(128, 4096, 4, 4); NNop.flash_attention(x, x, x; causal=true);
ERROR: Failed to find groupsize for Flash Attention that satisfies Shared Memory constraint.
Stacktrace:
  [1] error(s::String)
    @ Base ./error.jl:35
  [2] #flash_attention_groupsize#8
    @ ~/.julia/packages/NNop/cRsoL/src/attention.jl:186 [inlined]
  [3] flash_attention_groupsize
    @ ~/.julia/packages/NNop/cRsoL/src/attention.jl:175 [inlined]
  [4] _flash_attention(q::CuArray{Float32, 4, CUDA.DeviceMemory}, k::CuArray{Float32, 4, CUDA.DeviceMemory}, v::CuArray{Float32, 4, CUDA.DeviceMemory}; causal::Bool)
    @ NNop ~/.julia/packages/NNop/cRsoL/src/attention.jl:127
  [5] _flash_attention
    @ ~/.julia/packages/NNop/cRsoL/src/attention.jl:113 [inlined]
  [6] #flash_attention#18
    @ ~/.julia/packages/NNop/cRsoL/src/attention_crc.jl:5 [inlined]
  [7] top-level scope
    @ REPL[121]:1
```
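
To make the constraint concrete, here is a minimal back-of-the-envelope sketch of the shared-memory arithmetic. The tile layout I assume (Q, K, V tiles of `groupsize × head_dim` Float32s plus per-row softmax statistics, with groupsize no smaller than one warp) is purely illustrative and not necessarily NNop's actual kernel layout in `attention.jl`. Under those assumptions, head dim 64 still has a fitting groupsize while head dim 128 has none within a typical 48 KB budget:

```julia
# Illustrative shared-memory estimate for a flash-attention tile.
# ASSUMPTION: the kernel keeps Q, K, V tiles of groupsize × head_dim
# Float32s plus per-row softmax statistics in shared memory; this is
# NOT necessarily NNop's actual layout.
smem_bytes(groupsize, head_dim) =
    3 * groupsize * head_dim * sizeof(Float32) +  # Q, K, V tiles
    2 * groupsize * sizeof(Float32)               # running row max + row sum

const SMEM_LIMIT = 48 * 1024  # typical per-block static shared-memory budget

for head_dim in (64, 128)
    # candidate groupsizes down to one warp (32 threads)
    fits = [g for g in (128, 64, 32) if smem_bytes(g, head_dim) <= SMEM_LIMIT]
    println("head_dim = $head_dim: fitting groupsizes = $fits")
end
# head_dim = 64:  fitting groupsizes = [32]
# head_dim = 128: fitting groupsizes = []   (nothing down to one warp fits)
```

If that picture is roughly right, supporting a head dim of 128 would presumably mean either splitting the head dimension across tiles or opting into the larger dynamic shared-memory carveout that newer GPUs allow, but I may be missing something about the actual kernel.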