Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

hotfix(message/validation): optimize signer state memory usage #1874

Open
wants to merge 20 commits into
base: stage
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions message/validation/common_checks.go
Original file line number Diff line number Diff line change
Expand Up @@ -140,3 +140,16 @@ func (mv *messageValidator) validateBeaconDuty(

return nil
}

func (mv *messageValidator) signerIndexInCommittee(signer spectypes.OperatorID, committee []spectypes.OperatorID) int {
// Probably converting committee to a map would be faster, but since committee size is <13, this should be okay.
// Although, we need to change it if it appears in pprof output.
for i, oid := range committee {
if oid == signer {
return i
}
}

mv.logger.Panic(fmt.Sprintf("signer %v must be in committee %v", signer, committee))
panic("unreachable") // fix compilation issue
nkryuchkov marked this conversation as resolved.
Show resolved Hide resolved
}
17 changes: 5 additions & 12 deletions message/validation/consensus_state.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,31 +4,24 @@ import (
"sync"

"github.com/attestantio/go-eth2-client/spec/phase0"
spectypes "github.com/ssvlabs/ssv-spec/types"
)

// consensusID uniquely identifies a public key and role pair to keep track of state.
type consensusID struct {
DutyExecutorID string
Role spectypes.RunnerRole
}

// consensusState keeps track of the signers for a given public key and role.
type consensusState struct {
state map[spectypes.OperatorID]*OperatorState
state []*OperatorState
storedSlotCount phase0.Slot
mu sync.Mutex
}

func (cs *consensusState) GetOrCreate(signer spectypes.OperatorID) *OperatorState {
func (cs *consensusState) GetOrCreate(idx int) *OperatorState {
cs.mu.Lock()
defer cs.mu.Unlock()

if _, ok := cs.state[signer]; !ok {
cs.state[signer] = newOperatorState(cs.storedSlotCount)
if cs.state[idx] == nil {
cs.state[idx] = newOperatorState(cs.storedSlotCount)
}

return cs.state[signer]
return cs.state[idx]
}

type OperatorState struct {
Expand Down
91 changes: 47 additions & 44 deletions message/validation/consensus_validation.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ package validation
import (
"bytes"
"crypto/sha256"
"encoding/binary"
"encoding/hex"
"fmt"
"time"
Expand Down Expand Up @@ -44,17 +43,17 @@ func (mv *messageValidator) validateConsensusMessage(

mv.metrics.ConsensusMsgType(consensusMessage.MsgType, len(signedSSVMessage.OperatorIDs))

if err := mv.validateConsensusMessageSemantics(signedSSVMessage, consensusMessage, committeeInfo.operatorIDs); err != nil {
if err := mv.validateConsensusMessageSemantics(signedSSVMessage, consensusMessage, committeeInfo.committee); err != nil {
return consensusMessage, err
}

state := mv.consensusState(signedSSVMessage.SSVMessage.GetID())
state := mv.consensusState(signedSSVMessage.SSVMessage.GetID(), committeeInfo.committee)

if err := mv.validateQBFTLogic(signedSSVMessage, consensusMessage, committeeInfo.operatorIDs, receivedAt, state); err != nil {
if err := mv.validateQBFTLogic(signedSSVMessage, consensusMessage, committeeInfo.committee, receivedAt, state); err != nil {
return consensusMessage, err
}

if err := mv.validateQBFTMessageByDutyLogic(signedSSVMessage, consensusMessage, committeeInfo.indices, receivedAt, state); err != nil {
if err := mv.validateQBFTMessageByDutyLogic(signedSSVMessage, consensusMessage, committeeInfo, receivedAt, state); err != nil {
return consensusMessage, err
}

Expand All @@ -69,7 +68,7 @@ func (mv *messageValidator) validateConsensusMessage(
}
}

if err := mv.updateConsensusState(signedSSVMessage, consensusMessage, state); err != nil {
if err := mv.updateConsensusState(signedSSVMessage, consensusMessage, committeeInfo.committee, state); err != nil {
return consensusMessage, err
}

Expand Down Expand Up @@ -192,7 +191,7 @@ func (mv *messageValidator) validateQBFTLogic(

msgSlot := phase0.Slot(consensusMessage.Height)
for _, signer := range signedSSVMessage.OperatorIDs {
signerStateBySlot := state.GetOrCreate(signer)
signerStateBySlot := state.GetOrCreate(mv.signerIndexInCommittee(signer, committee))
signerState := signerStateBySlot.Get(msgSlot)
if signerState == nil {
continue
Expand All @@ -212,26 +211,28 @@ func (mv *messageValidator) validateQBFTLogic(

if consensusMessage.Round == signerState.Round {
// Rule: Peer must not send two proposals with different data
if len(signedSSVMessage.FullData) != 0 && signerState.ProposalData != nil && !bytes.Equal(signerState.ProposalData, signedSSVMessage.FullData) {
return ErrDifferentProposalData
if len(signedSSVMessage.FullData) != 0 && signerState.HashedProposalData != nil {
if *signerState.HashedProposalData != sha256.Sum256(signedSSVMessage.FullData) {
return ErrDifferentProposalData
}
}

// Rule: Peer must send only 1 proposal, 1 prepare, 1 commit, and 1 round-change per round
limits := maxMessageCounts()
if err := signerState.MessageCounts.ValidateConsensusMessage(signedSSVMessage, consensusMessage, limits); err != nil {
if err := signerState.SeenMsgTypes.ValidateConsensusMessage(signedSSVMessage, consensusMessage); err != nil {
return err
}
}
} else if len(signedSSVMessage.OperatorIDs) > 1 {
// Rule: Decided msg can't have the same signers as previously sent before for the same duty
encodedOperators, err := encodeOperators(signedSSVMessage.OperatorIDs)
if err != nil {
return err
quorum := Quorum{
Signers: signedSSVMessage.OperatorIDs,
Committee: committee,
}

// Rule: Decided msg can't have the same signers as previously sent before for the same duty
if _, ok := signerState.SeenSigners[encodedOperators]; ok {
return ErrDecidedWithSameSigners
if signerState.SeenSigners != nil {
if _, ok := signerState.SeenSigners[quorum.ToBitMask()]; ok {
return ErrDecidedWithSameSigners
}
}
}
}
Expand All @@ -249,7 +250,7 @@ func (mv *messageValidator) validateQBFTLogic(
func (mv *messageValidator) validateQBFTMessageByDutyLogic(
signedSSVMessage *spectypes.SignedSSVMessage,
consensusMessage *specqbft.Message,
validatorIndices []phase0.ValidatorIndex,
committeeInfo CommitteeInfo,
receivedAt time.Time,
state *consensusState,
) error {
Expand All @@ -258,7 +259,7 @@ func (mv *messageValidator) validateQBFTMessageByDutyLogic(
// Rule: Height must not be "old". I.e., signer must not have already advanced to a later slot.
if role != spectypes.RoleCommittee { // Rule only for validator runners
for _, signer := range signedSSVMessage.OperatorIDs {
signerStateBySlot := state.GetOrCreate(signer)
signerStateBySlot := state.GetOrCreate(mv.signerIndexInCommittee(signer, committeeInfo.committee))
if maxSlot := signerStateBySlot.MaxSlot(); maxSlot > phase0.Slot(consensusMessage.Height) {
e := ErrSlotAlreadyAdvanced
e.got = consensusMessage.Height
Expand All @@ -269,7 +270,7 @@ func (mv *messageValidator) validateQBFTMessageByDutyLogic(
}

msgSlot := phase0.Slot(consensusMessage.Height)
if err := mv.validateBeaconDuty(role, msgSlot, validatorIndices); err != nil {
if err := mv.validateBeaconDuty(role, msgSlot, committeeInfo.indices); err != nil {
return err
}

Expand All @@ -285,21 +286,26 @@ func (mv *messageValidator) validateQBFTMessageByDutyLogic(
// - 2*V for Committee duty (where V is the number of validators in the cluster) (if no validator is doing sync committee in this epoch)
// - else, accept
for _, signer := range signedSSVMessage.OperatorIDs {
signerStateBySlot := state.GetOrCreate(signer)
if err := mv.validateDutyCount(signedSSVMessage.SSVMessage.GetID(), msgSlot, validatorIndices, signerStateBySlot); err != nil {
signerStateBySlot := state.GetOrCreate(mv.signerIndexInCommittee(signer, committeeInfo.committee))
if err := mv.validateDutyCount(signedSSVMessage.SSVMessage.GetID(), msgSlot, committeeInfo.indices, signerStateBySlot); err != nil {
return err
}
}

return nil
}

func (mv *messageValidator) updateConsensusState(signedSSVMessage *spectypes.SignedSSVMessage, consensusMessage *specqbft.Message, consensusState *consensusState) error {
func (mv *messageValidator) updateConsensusState(
signedSSVMessage *spectypes.SignedSSVMessage,
consensusMessage *specqbft.Message,
committee []spectypes.OperatorID,
consensusState *consensusState,
) error {
msgSlot := phase0.Slot(consensusMessage.Height)
msgEpoch := mv.netCfg.Beacon.EstimatedEpochAtSlot(msgSlot)

for _, signer := range signedSSVMessage.OperatorIDs {
stateBySlot := consensusState.GetOrCreate(signer)
stateBySlot := consensusState.GetOrCreate(mv.signerIndexInCommittee(signer, committee))
signerState := stateBySlot.Get(msgSlot)
if signerState == nil {
signerState = NewSignerState(phase0.Slot(consensusMessage.Height), consensusMessage.Round)
Expand All @@ -310,31 +316,39 @@ func (mv *messageValidator) updateConsensusState(signedSSVMessage *spectypes.Sig
}
}

if err := mv.processSignerState(signedSSVMessage, consensusMessage, signerState); err != nil {
if err := mv.processSignerState(signedSSVMessage, consensusMessage, committee, signerState); err != nil {
return err
}
}

return nil
}

func (mv *messageValidator) processSignerState(signedSSVMessage *spectypes.SignedSSVMessage, consensusMessage *specqbft.Message, signerState *SignerState) error {
func (mv *messageValidator) processSignerState(
signedSSVMessage *spectypes.SignedSSVMessage,
consensusMessage *specqbft.Message,
committee []spectypes.OperatorID,
signerState *SignerState,
) error {
if len(signedSSVMessage.FullData) != 0 && consensusMessage.MsgType == specqbft.ProposalMsgType {
signerState.ProposalData = signedSSVMessage.FullData
fullDataHash := sha256.Sum256(signedSSVMessage.FullData)
signerState.HashedProposalData = &fullDataHash
}

signerCount := len(signedSSVMessage.OperatorIDs)
if signerCount > 1 {
encodedOperators, err := encodeOperators(signedSSVMessage.OperatorIDs)
if err != nil {
// encodeOperators must never return an error
return ErrEncodeOperators
quorum := Quorum{
Signers: signedSSVMessage.OperatorIDs,
Committee: committee,
}

signerState.SeenSigners[encodedOperators] = struct{}{}
if signerState.SeenSigners == nil {
signerState.SeenSigners = make(map[SignersBitMask]struct{}) // lazy init on demand to reduce mem consumption
}
signerState.SeenSigners[quorum.ToBitMask()] = struct{}{}
}

return signerState.MessageCounts.RecordConsensusMessage(signedSSVMessage, consensusMessage)
return signerState.SeenMsgTypes.RecordConsensusMessage(signedSSVMessage, consensusMessage)
}

func (mv *messageValidator) validateJustifications(message *specqbft.Message) error {
Expand Down Expand Up @@ -451,14 +465,3 @@ func (mv *messageValidator) roundRobinProposer(height specqbft.Height, round spe
index := (firstRoundIndex + uint64(round) - uint64(specqbft.FirstRound)) % uint64(len(committee))
return committee[index]
}

func encodeOperators(operators []spectypes.OperatorID) ([sha256.Size]byte, error) {
buf := new(bytes.Buffer)
for _, operator := range operators {
if err := binary.Write(buf, binary.LittleEndian, operator); err != nil {
return [sha256.Size]byte{}, err
}
}
hash := sha256.Sum256(buf.Bytes())
return hash, nil
}
1 change: 0 additions & 1 deletion message/validation/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,6 @@ var (
ErrDuplicatedMessage = Error{text: "message is duplicated", reject: true}
ErrInvalidPartialSignatureTypeCount = Error{text: "sent more partial signature messages of a certain type than allowed", reject: true}
ErrTooManyPartialSignatureMessages = Error{text: "too many partial signature messages", reject: true}
ErrEncodeOperators = Error{text: "encode operators", reject: true}
)

func (mv *messageValidator) handleValidationError(peerID peer.ID, decodedMessage *queue.SSVMessage, err error) pubsub.ValidationResult {
Expand Down
14 changes: 7 additions & 7 deletions message/validation/partial_validation.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ func (mv *messageValidator) validatePartialSignatureMessage(
}

msgID := ssvMessage.GetID()
state := mv.consensusState(msgID)
state := mv.consensusState(msgID, committeeInfo.committee)
if err := mv.validatePartialSigMessagesByDutyLogic(signedSSVMessage, partialSignatureMessages, committeeInfo, receivedAt, state); err != nil {
return nil, err
}
Expand All @@ -55,7 +55,7 @@ func (mv *messageValidator) validatePartialSignatureMessage(
return partialSignatureMessages, e
}

if err := mv.updatePartialSignatureState(partialSignatureMessages, state, signer); err != nil {
if err := mv.updatePartialSignatureState(partialSignatureMessages, state, signer, committeeInfo.committee); err != nil {
return nil, err
}

Expand Down Expand Up @@ -142,7 +142,7 @@ func (mv *messageValidator) validatePartialSigMessagesByDutyLogic(
role := signedSSVMessage.SSVMessage.GetID().GetRoleType()
messageSlot := partialSignatureMessages.Slot
signer := signedSSVMessage.OperatorIDs[0]
signerStateBySlot := state.GetOrCreate(signer)
signerStateBySlot := state.GetOrCreate(mv.signerIndexInCommittee(signer, committeeInfo.committee))

// Rule: Height must not be "old". I.e., signer must not have already advanced to a later slot.
if signedSSVMessage.SSVMessage.MsgID.GetRoleType() != types.RoleCommittee { // Rule only for validator runners
Expand All @@ -167,8 +167,7 @@ func (mv *messageValidator) validatePartialSigMessagesByDutyLogic(
// - 1 SelectionProofPartialSig and 1 PostConsensusPartialSig for Sync committee contribution
// - 1 ValidatorRegistrationPartialSig for Validator Registration
// - 1 VoluntaryExitPartialSig for Voluntary Exit
limits := maxMessageCounts()
if err := signerState.MessageCounts.ValidatePartialSignatureMessage(partialSignatureMessages, limits); err != nil {
if err := signerState.SeenMsgTypes.ValidatePartialSignatureMessage(partialSignatureMessages); err != nil {
return err
}
}
Expand Down Expand Up @@ -228,8 +227,9 @@ func (mv *messageValidator) updatePartialSignatureState(
partialSignatureMessages *spectypes.PartialSignatureMessages,
state *consensusState,
signer spectypes.OperatorID,
committee []spectypes.OperatorID,
) error {
stateBySlot := state.GetOrCreate(signer)
stateBySlot := state.GetOrCreate(mv.signerIndexInCommittee(signer, committee))
messageSlot := partialSignatureMessages.Slot
messageEpoch := mv.netCfg.Beacon.EstimatedEpochAtSlot(messageSlot)

Expand All @@ -239,7 +239,7 @@ func (mv *messageValidator) updatePartialSignatureState(
stateBySlot.Set(messageSlot, messageEpoch, signerState)
}

return signerState.MessageCounts.RecordPartialSignatureMessage(partialSignatureMessages)
return signerState.SeenMsgTypes.RecordPartialSignatureMessage(partialSignatureMessages)
}

func (mv *messageValidator) validPartialSigMsgType(msgType spectypes.PartialSigMsgType) bool {
Expand Down
58 changes: 58 additions & 0 deletions message/validation/quorum.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
package validation

import (
"fmt"
"math/bits"

spectypes "github.com/ssvlabs/ssv-spec/types"
)

// TODO: Take all of these from https://github.com/ssvlabs/ssv/pull/1867 once it's merged.
// This file is temporary to avoid the need to be based on another PR, hence there are no tests.

// maxCommitteeSize is the largest signers/committee length accepted by
// Quorum.ToBitMask and SignersBitMask.SignersList; both panic above it.
const maxCommitteeSize = 13

// Quorum pairs the signers of a message with the committee they are checked
// against, so that the signers can be encoded as a bitmask over committee
// positions (see ToBitMask).
type Quorum struct {
// Signers holds the operator IDs that signed the message.
Signers []spectypes.OperatorID
// Committee holds the operator IDs of the committee; bit positions in the
// resulting mask are indices into this slice.
Committee []spectypes.OperatorID
}

// ToBitMask encodes the signers as a bitmask over committee positions: bit c
// is set when Committee[c] is matched by an entry of Signers.
//
// The two slices are walked in lockstep, always advancing the side holding the
// smaller ID. This only finds every common operator when both slices are in
// ascending order (presumably guaranteed by callers; verify before relying on
// it). A signer smaller than the current committee member is skipped without
// setting any bit.
//
// It panics when either slice is longer than maxCommitteeSize or when there
// are more signers than committee members.
func (q *Quorum) ToBitMask() SignersBitMask {
	signerCount, committeeSize := len(q.Signers), len(q.Committee)
	if signerCount > maxCommitteeSize || committeeSize > maxCommitteeSize || signerCount > committeeSize {
		panic(fmt.Sprintf("invalid signers/committee size: %d/%d", signerCount, committeeSize))
	}

	var mask SignersBitMask
	for s, c := 0, 0; s < signerCount && c < committeeSize; {
		signer, member := q.Signers[s], q.Committee[c]
		switch {
		case signer == member:
			mask |= 1 << uint(c) // #nosec G115 -- c cannot exceed maxCommitteeSize
			s++
			c++
		case signer < member:
			s++
		default: // signer > member
			c++
		}
	}

	return mask
}

// SignersBitMask is a compact set of committee members: bit j refers to the
// operator at index j of the committee slice the mask was built from or is
// decoded against.
type SignersBitMask uint16

// SignersList decodes the bitmask against committee and returns, in committee
// order, every committee[j] whose bit j is set in the mask. Bits at positions
// at or beyond len(committee) are ignored.
//
// It panics when committee is longer than maxCommitteeSize.
func (obm SignersBitMask) SignersList(committee []spectypes.OperatorID) []spectypes.OperatorID {
	if size := len(committee); size > maxCommitteeSize {
		panic(fmt.Sprintf("invalid committee size: %d", size))
	}

	// Capacity is the number of set bits, which is an upper bound on the result length.
	result := make([]spectypes.OperatorID, 0, bits.OnesCount16(uint16(obm)))
	for pos, operatorID := range committee {
		// #nosec G115 -- pos cannot exceed maxCommitteeSize
		if obm&(1<<uint(pos)) == 0 {
			continue
		}
		result = append(result, operatorID)
	}

	return result
}
Loading
Loading