diff --git a/Project.toml b/Project.toml index 9733424..548eeb4 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "NNop" uuid = "eeb6ee5c-f953-4f60-8482-00c4fb7bc198" authors = ["Anton Smirnov "] -version = "0.2.0" +version = "0.2.1" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" diff --git a/src/attention.jl b/src/attention.jl index 270cd7d..dbe5815 100644 --- a/src/attention.jl +++ b/src/attention.jl @@ -82,21 +82,40 @@ end l_ij = zero(T) - @unroll for i in 1:gsz - (in_seq_bounds || k_offset + i ≤ size(k, 2)) || break - tmp = exp(s_shm[tidx, i] - m_ij) - l_ij += tmp - s_shm[tidx, i] = tmp + if m_ij == typemin(T) + # If all considered logits are effectively -Inf for this row, + # avoid exp(-Inf - -Inf) = NaN and produce a zero-probability tile. + @unroll for i in 1:gsz + (in_seq_bounds || k_offset + i ≤ size(k, 2)) || break + s_shm[tidx, i] = zero(T) + end + else + @unroll for i in 1:gsz + (in_seq_bounds || k_offset + i ≤ size(k, 2)) || break + tmp = exp(s_shm[tidx, i] - m_ij) + l_ij += tmp + s_shm[tidx, i] = tmp + end end @synchronize() m_i_new = max(m_i, m_ij) - α = exp(m_i - m_i_new) - β = exp(m_ij - m_i_new) - l_i_new = α * l_i + β * l_ij + if l_i == zero(T) && l_ij == zero(T) + # No valid keys encountered so far and none in this tile; + # keep running sums unchanged without introducing NaNs. + α = one(T) + β = zero(T) + l_i_new = zero(T) + else + α = exp(m_i - m_i_new) + β = exp(m_ij - m_i_new) + l_i_new = α * l_i + β * l_ij + end - p_scale = β / l_i_new - o_scale = l_i / l_i_new * α + # Guard against 0/0 when there are no valid keys seen so far. + den = l_i_new + p_scale = den == zero(T) ? zero(T) : β / den + o_scale = den == zero(T) ? one(T) : (l_i / den * α) @unroll for i in 1:gsz s_shm[tidx, i] *= p_scale diff --git a/test/Project.toml b/test/Project.toml index 3602059..6dbc057 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,5 +1,4 @@ [deps] -AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"