diff --git a/lib/nnc/mfa/v2/AttentionDescriptor.cpp b/lib/nnc/mfa/v2/AttentionDescriptor.cpp index 7f67abdfd..5f98864fa 100644 --- a/lib/nnc/mfa/v2/AttentionDescriptor.cpp +++ b/lib/nnc/mfa/v2/AttentionDescriptor.cpp @@ -9,6 +9,8 @@ bool AttentionDescriptor::operator==(const AttentionDescriptor& rhs) const { batchDimension == rhs.batchDimension && Hq == rhs.Hq && Hk == rhs.Hk && + scale == rhs.scale && + type == rhs.type && (lowPrecisionInputs == rhs.lowPrecisionInputs) && (lowPrecisionIntermediates == rhs.lowPrecisionIntermediates) && simd_all(leadingDimensions.value_or(simd::uint4(UINT32_MAX)) == rhs.leadingDimensions.value_or(simd::uint4(UINT32_MAX))) && diff --git a/lib/nnc/mfa/v2/AttentionKernelDescriptor.cpp b/lib/nnc/mfa/v2/AttentionKernelDescriptor.cpp index d15385739..adaaae98b 100644 --- a/lib/nnc/mfa/v2/AttentionKernelDescriptor.cpp +++ b/lib/nnc/mfa/v2/AttentionKernelDescriptor.cpp @@ -16,7 +16,8 @@ bool AttentionKernelDescriptor::operator==(const AttentionKernelDescriptor& rhs) registerPrecisions == rhs.registerPrecisions && transposeState == rhs.transposeState && leadingDimensions == rhs.leadingDimensions && - type == rhs.type; + type == rhs.type &&; + scale == rhs.scale; } std::size_t std::hash::operator()(const AttentionKernelDescriptor& hash) const noexcept { diff --git a/lib/nnc/mfa/v2/AttentionOperand.hpp b/lib/nnc/mfa/v2/AttentionOperand.hpp index b62e586e3..7f49896c8 100644 --- a/lib/nnc/mfa/v2/AttentionOperand.hpp +++ b/lib/nnc/mfa/v2/AttentionOperand.hpp @@ -138,10 +138,52 @@ struct AttentionOperands { constexpr AttentionOperands() : bitmap(0) {} constexpr bool operator==(const AttentionOperands& rhs) const { - return Q == rhs.Q && K == rhs.K && S == rhs.S && P == rhs.P && V == rhs.V && O == rhs.O && - L == rhs.L && D == rhs.D && - dO == rhs.dO && dV == rhs.dV && dP == rhs.dP && dS == rhs.dS && dK == rhs.dK && dQ == rhs.dQ && - bitmap == bitmap; + if (bitmap != rhs.bitmap) { + return false; + } + if (bitmap & (1 << (AttentionOperand::Q)) && Q != rhs.Q) { + return false; + } + if (bitmap & (1 << (AttentionOperand::K)) && K != rhs.K) { + return false; + } + if (bitmap & (1 << (AttentionOperand::S)) && S != rhs.S) { + return false; + } + if (bitmap & (1 << (AttentionOperand::P)) && P != rhs.P) { + return false; + } + if (bitmap & (1 << (AttentionOperand::V)) && V != rhs.V) { + return false; + } + if (bitmap & (1 << (AttentionOperand::O)) && O != rhs.O) { + return false; + } + if (bitmap & (1 << (AttentionOperand::L)) && L != rhs.L) { + return false; + } + if (bitmap & (1 << (AttentionOperand::D)) && D != rhs.D) { + return false; + } + if (bitmap & (1 << (AttentionOperand::dO)) && dO != rhs.dO) { + return false; + } + if (bitmap & (1 << (AttentionOperand::dV)) && dV != rhs.dV) { + return false; + } + if (bitmap & (1 << (AttentionOperand::dP)) && dP != rhs.dP) { + return false; + } + if (bitmap & (1 << (AttentionOperand::dS)) && dS != rhs.dS) { + return false; + } + if (bitmap & (1 << (AttentionOperand::dK)) && dK != rhs.dK) { + return false; + } + if (bitmap & (1 << (AttentionOperand::dQ)) && dQ != rhs.dQ) { + return false; + } + return true; } class Reference {