diff --git a/message/validation/committee_info.go b/message/validation/committee_info.go new file mode 100644 index 0000000000..3a859dc8eb --- /dev/null +++ b/message/validation/committee_info.go @@ -0,0 +1,36 @@ +package validation + +import ( + "github.com/attestantio/go-eth2-client/spec/phase0" + spectypes "github.com/ssvlabs/ssv-spec/types" +) + +type CommitteeInfo struct { + committeeID spectypes.CommitteeID + committee []spectypes.OperatorID + signerIndices map[spectypes.OperatorID]int + validatorIndices []phase0.ValidatorIndex +} + +func newCommitteeInfo( + committeeID spectypes.CommitteeID, + operators []spectypes.OperatorID, + validatorIndices []phase0.ValidatorIndex, +) CommitteeInfo { + signerIndices := make(map[spectypes.OperatorID]int) + for i, operator := range operators { + signerIndices[operator] = i + } + + return CommitteeInfo{ + committeeID: committeeID, + committee: operators, + signerIndices: signerIndices, + validatorIndices: validatorIndices, + } +} + +// keeping the method for readability and the comment +func (ci *CommitteeInfo) signerIndex(signer spectypes.OperatorID) int { + return ci.signerIndices[signer] // existence must be checked by ErrSignerNotInCommittee +} diff --git a/message/validation/common_checks.go b/message/validation/common_checks.go index 35c0ae59fe..ec88c9bc60 100644 --- a/message/validation/common_checks.go +++ b/message/validation/common_checks.go @@ -40,7 +40,7 @@ func (mv *messageValidator) messageLateness(slot phase0.Slot, role spectypes.Run case spectypes.RoleProposer, spectypes.RoleSyncCommitteeContribution: ttl = 1 + lateSlotAllowance case spectypes.RoleCommittee, spectypes.RoleAggregator: - ttl = phase0.Slot(mv.netCfg.Beacon.SlotsPerEpoch()) + lateSlotAllowance + ttl = MaxStoredSlots(mv.netCfg) case spectypes.RoleValidatorRegistration, spectypes.RoleVoluntaryExit: return 0 } diff --git a/message/validation/consensus_state.go b/message/validation/consensus_state.go index f5d4b2c121..9cc7c7d44b 100644 --- a/message/validation/consensus_state.go +++ b/message/validation/consensus_state.go @@ -4,36 +4,29 @@ import ( "sync" "github.com/attestantio/go-eth2-client/spec/phase0" - spectypes "github.com/ssvlabs/ssv-spec/types" ) -// consensusID uniquely identifies a public key and role pair to keep track of state. -type consensusID struct { - DutyExecutorID string - Role spectypes.RunnerRole -} - -// consensusState keeps track of the signers for a given public key and role. -type consensusState struct { - state map[spectypes.OperatorID]*OperatorState +// ValidatorState keeps track of the signers for a given public key and role. +type ValidatorState struct { + operators []*OperatorState storedSlotCount phase0.Slot mu sync.Mutex } -func (cs *consensusState) GetOrCreate(signer spectypes.OperatorID) *OperatorState { +func (cs *ValidatorState) Signer(idx int) *OperatorState { cs.mu.Lock() defer cs.mu.Unlock() - if _, ok := cs.state[signer]; !ok { - cs.state[signer] = newOperatorState(cs.storedSlotCount) + if cs.operators[idx] == nil { + cs.operators[idx] = newOperatorState(cs.storedSlotCount) } - return cs.state[signer] + return cs.operators[idx] } type OperatorState struct { - mu sync.RWMutex - state []*SignerState // the slice index is slot % storedSlotCount + mu sync.Mutex + signers []*SignerState // the slice index is slot % storedSlotCount maxSlot phase0.Slot maxEpoch phase0.Epoch lastEpochDuties uint64 @@ -42,15 +35,15 @@ type OperatorState struct { func newOperatorState(size phase0.Slot) *OperatorState { return &OperatorState{ - state: make([]*SignerState, size), + signers: make([]*SignerState, size), } } func (os *OperatorState) Get(slot phase0.Slot) *SignerState { - os.mu.RLock() - defer os.mu.RUnlock() + os.mu.Lock() + defer os.mu.Unlock() - s := os.state[(uint64(slot) % uint64(len(os.state)))] + s := os.signers[(uint64(slot) % uint64(len(os.signers)))] if s == nil || s.Slot != slot { return nil } @@ -62,7 +55,7 @@ func (os *OperatorState) Set(slot phase0.Slot, epoch phase0.Epoch, state *Signer os.mu.Lock() defer os.mu.Unlock() - os.state[uint64(slot)%uint64(len(os.state))] = state + os.signers[uint64(slot)%uint64(len(os.signers))] = state if slot > os.maxSlot { os.maxSlot = slot } @@ -76,15 +69,15 @@ func (os *OperatorState) Set(slot phase0.Slot, epoch phase0.Epoch, state *Signer } func (os *OperatorState) MaxSlot() phase0.Slot { - os.mu.RLock() - defer os.mu.RUnlock() + os.mu.Lock() + defer os.mu.Unlock() return os.maxSlot } func (os *OperatorState) DutyCount(epoch phase0.Epoch) uint64 { - os.mu.RLock() - defer os.mu.RUnlock() + os.mu.Lock() + defer os.mu.Unlock() if epoch == os.maxEpoch { return os.lastEpochDuties diff --git a/message/validation/consensus_state_test.go b/message/validation/consensus_state_test.go index e8008fd804..eb8faca0eb 100644 --- a/message/validation/consensus_state_test.go +++ b/message/validation/consensus_state_test.go @@ -12,7 +12,7 @@ func TestOperatorState(t *testing.T) { size := phase0.Slot(10) os := newOperatorState(size) require.NotNil(t, os) - require.Equal(t, len(os.state), int(size)) + require.Equal(t, len(os.signers), int(size)) }) t.Run("TestGetAndSet", func(t *testing.T) { @@ -58,9 +58,9 @@ func TestOperatorState(t *testing.T) { slot := phase0.Slot(5) epoch := phase0.Epoch(1) - signerState := &SignerState{Slot: slot} + signerState1 := &SignerState{Slot: slot} - os.Set(slot, epoch, signerState) + os.Set(slot, epoch, signerState1) require.Equal(t, os.DutyCount(epoch), uint64(1)) require.Equal(t, os.DutyCount(epoch-1), uint64(0)) @@ -82,9 +82,9 @@ func TestOperatorState(t *testing.T) { slot := phase0.Slot(5) epoch := phase0.Epoch(1) - signerState := &SignerState{Slot: slot} + signerState1 := &SignerState{Slot: slot} - os.Set(slot, epoch, signerState) + os.Set(slot, epoch, signerState1) require.Equal(t, os.DutyCount(epoch), uint64(1)) slot2 := phase0.Slot(6) diff --git a/message/validation/consensus_validation.go b/message/validation/consensus_validation.go index 258c6d77a4..9df6bdfc9d 100644 --- a/message/validation/consensus_validation.go +++ b/message/validation/consensus_validation.go @@ -5,7 +5,6 @@ package validation import ( "bytes" "crypto/sha256" - "encoding/binary" "encoding/hex" "fmt" "time" @@ -43,17 +42,17 @@ func (mv *messageValidator) validateConsensusMessage( mv.metrics.ConsensusMsgType(consensusMessage.MsgType, len(signedSSVMessage.OperatorIDs)) - if err := mv.validateConsensusMessageSemantics(signedSSVMessage, consensusMessage, committeeInfo.operatorIDs); err != nil { + if err := mv.validateConsensusMessageSemantics(signedSSVMessage, consensusMessage, committeeInfo.committee); err != nil { return consensusMessage, err } - state := mv.consensusState(signedSSVMessage.SSVMessage.GetID()) + state := mv.validatorState(signedSSVMessage.SSVMessage.GetID(), committeeInfo.committee) - if err := mv.validateQBFTLogic(signedSSVMessage, consensusMessage, committeeInfo.operatorIDs, receivedAt, state); err != nil { + if err := mv.validateQBFTLogic(signedSSVMessage, consensusMessage, committeeInfo, receivedAt, state); err != nil { return consensusMessage, err } - if err := mv.validateQBFTMessageByDutyLogic(signedSSVMessage, consensusMessage, committeeInfo.indices, receivedAt, state); err != nil { + if err := mv.validateQBFTMessageByDutyLogic(signedSSVMessage, consensusMessage, committeeInfo, receivedAt, state); err != nil { return consensusMessage, err } @@ -68,7 +67,7 @@ func (mv *messageValidator) validateConsensusMessage( } } - if err := mv.updateConsensusState(signedSSVMessage, consensusMessage, state); err != nil { + if err := mv.updateConsensusState(signedSSVMessage, consensusMessage, committeeInfo, state); err != nil { return consensusMessage, err } @@ -174,13 +173,13 @@ func (mv *messageValidator) validateConsensusMessageSemantics( func (mv *messageValidator) validateQBFTLogic( signedSSVMessage *spectypes.SignedSSVMessage, consensusMessage *specqbft.Message, - committee []spectypes.OperatorID, + committeeInfo CommitteeInfo, receivedAt time.Time, - state *consensusState, + state *ValidatorState, ) error { if consensusMessage.MsgType == specqbft.ProposalMsgType { // Rule: Signer must be the leader - leader := mv.roundRobinProposer(consensusMessage.Height, consensusMessage.Round, committee) + leader := mv.roundRobinProposer(consensusMessage.Height, consensusMessage.Round, committeeInfo.committee) if signedSSVMessage.OperatorIDs[0] != leader { err := ErrSignerNotLeader err.got = signedSSVMessage.OperatorIDs[0] @@ -191,7 +190,7 @@ func (mv *messageValidator) validateQBFTLogic( msgSlot := phase0.Slot(consensusMessage.Height) for _, signer := range signedSSVMessage.OperatorIDs { - signerStateBySlot := state.GetOrCreate(signer) + signerStateBySlot := state.Signer(committeeInfo.signerIndex(signer)) signerState := signerStateBySlot.Get(msgSlot) if signerState == nil { continue @@ -211,26 +210,28 @@ func (mv *messageValidator) validateQBFTLogic( if consensusMessage.Round == signerState.Round { // Rule: Peer must not send two proposals with different data - if len(signedSSVMessage.FullData) != 0 && signerState.ProposalData != nil && !bytes.Equal(signerState.ProposalData, signedSSVMessage.FullData) { - return ErrDifferentProposalData + if len(signedSSVMessage.FullData) != 0 && signerState.HashedProposalData != nil { + if *signerState.HashedProposalData != sha256.Sum256(signedSSVMessage.FullData) { + return ErrDifferentProposalData + } } // Rule: Peer must send only 1 proposal, 1 prepare, 1 commit, and 1 round-change per round - limits := maxMessageCounts() - if err := signerState.MessageCounts.ValidateConsensusMessage(signedSSVMessage, consensusMessage, limits); err != nil { + if err := signerState.SeenMsgTypes.ValidateConsensusMessage(signedSSVMessage, consensusMessage); err != nil { return err } } } else if len(signedSSVMessage.OperatorIDs) > 1 { - // Rule: Decided msg can't have the same signers as previously sent before for the same duty - encodedOperators, err := encodeOperators(signedSSVMessage.OperatorIDs) - if err != nil { - return err + quorum := Quorum{ + Signers: signedSSVMessage.OperatorIDs, + Committee: committeeInfo.committee, } // Rule: Decided msg can't have the same signers as previously sent before for the same duty - if _, ok := signerState.SeenSigners[encodedOperators]; ok { - return ErrDecidedWithSameSigners + if signerState.SeenSigners != nil { + if _, ok := signerState.SeenSigners[quorum.ToBitMask()]; ok { + return ErrDecidedWithSameSigners + } } } } @@ -248,16 +249,16 @@ func (mv *messageValidator) validateQBFTLogic( func (mv *messageValidator) validateQBFTMessageByDutyLogic( signedSSVMessage *spectypes.SignedSSVMessage, consensusMessage *specqbft.Message, - validatorIndices []phase0.ValidatorIndex, + committeeInfo CommitteeInfo, receivedAt time.Time, - state *consensusState, + state *ValidatorState, ) error { role := signedSSVMessage.SSVMessage.GetID().GetRoleType() // Rule: Height must not be "old". I.e., signer must not have already advanced to a later slot. if role != spectypes.RoleCommittee { // Rule only for validator runners for _, signer := range signedSSVMessage.OperatorIDs { - signerStateBySlot := state.GetOrCreate(signer) + signerStateBySlot := state.Signer(committeeInfo.signerIndex(signer)) if maxSlot := signerStateBySlot.MaxSlot(); maxSlot > phase0.Slot(consensusMessage.Height) { e := ErrSlotAlreadyAdvanced e.got = consensusMessage.Height @@ -269,7 +270,7 @@ func (mv *messageValidator) validateQBFTMessageByDutyLogic( msgSlot := phase0.Slot(consensusMessage.Height) randaoMsg := false - if err := mv.validateBeaconDuty(signedSSVMessage.SSVMessage.GetID().GetRoleType(), msgSlot, validatorIndices, randaoMsg); err != nil { + if err := mv.validateBeaconDuty(signedSSVMessage.SSVMessage.GetID().GetRoleType(), msgSlot, committeeInfo.validatorIndices, randaoMsg); err != nil { return err } @@ -285,8 +286,8 @@ func (mv *messageValidator) validateQBFTMessageByDutyLogic( // - 2*V for Committee duty (where V is the number of validators in the cluster) (if no validator is doing sync committee in this epoch) // - else, accept for _, signer := range signedSSVMessage.OperatorIDs { - signerStateBySlot := state.GetOrCreate(signer) - if err := mv.validateDutyCount(signedSSVMessage.SSVMessage.GetID(), msgSlot, validatorIndices, signerStateBySlot); err != nil { + signerStateBySlot := state.Signer(committeeInfo.signerIndex(signer)) + if err := mv.validateDutyCount(signedSSVMessage.SSVMessage.GetID(), msgSlot, committeeInfo.validatorIndices, signerStateBySlot); err != nil { return err } } @@ -294,15 +295,20 @@ func (mv *messageValidator) validateQBFTMessageByDutyLogic( return nil } -func (mv *messageValidator) updateConsensusState(signedSSVMessage *spectypes.SignedSSVMessage, consensusMessage *specqbft.Message, consensusState *consensusState) error { +func (mv *messageValidator) updateConsensusState( + signedSSVMessage *spectypes.SignedSSVMessage, + consensusMessage *specqbft.Message, + committeeInfo CommitteeInfo, + consensusState *ValidatorState, +) error { msgSlot := phase0.Slot(consensusMessage.Height) msgEpoch := mv.netCfg.Beacon.EstimatedEpochAtSlot(msgSlot) for _, signer := range signedSSVMessage.OperatorIDs { - stateBySlot := consensusState.GetOrCreate(signer) + stateBySlot := consensusState.Signer(committeeInfo.signerIndex(signer)) signerState := stateBySlot.Get(msgSlot) if signerState == nil { - signerState = NewSignerState(phase0.Slot(consensusMessage.Height), consensusMessage.Round) + signerState = newSignerState(phase0.Slot(consensusMessage.Height), consensusMessage.Round) stateBySlot.Set(msgSlot, msgEpoch, signerState) } else { if consensusMessage.Round > signerState.Round { @@ -310,7 +316,7 @@ func (mv *messageValidator) updateConsensusState(signedSSVMessage *spectypes.Sig } } - if err := mv.processSignerState(signedSSVMessage, consensusMessage, signerState); err != nil { + if err := mv.processSignerState(signedSSVMessage, consensusMessage, committeeInfo.committee, signerState); err != nil { return err } } @@ -318,23 +324,31 @@ func (mv *messageValidator) updateConsensusState(signedSSVMessage *spectypes.Sig return nil } -func (mv *messageValidator) processSignerState(signedSSVMessage *spectypes.SignedSSVMessage, consensusMessage *specqbft.Message, signerState *SignerState) error { +func (mv *messageValidator) processSignerState( + signedSSVMessage *spectypes.SignedSSVMessage, + consensusMessage *specqbft.Message, + committee []spectypes.OperatorID, + signerState *SignerState, +) error { if len(signedSSVMessage.FullData) != 0 && consensusMessage.MsgType == specqbft.ProposalMsgType { - signerState.ProposalData = signedSSVMessage.FullData + fullDataHash := sha256.Sum256(signedSSVMessage.FullData) + signerState.HashedProposalData = &fullDataHash } signerCount := len(signedSSVMessage.OperatorIDs) if signerCount > 1 { - encodedOperators, err := encodeOperators(signedSSVMessage.OperatorIDs) - if err != nil { - // encodeOperators must never re - return ErrEncodeOperators + quorum := Quorum{ + Signers: signedSSVMessage.OperatorIDs, + Committee: committee, } - signerState.SeenSigners[encodedOperators] = struct{}{} + if signerState.SeenSigners == nil { + signerState.SeenSigners = make(map[SignersBitMask]struct{}) // lazy init on demand to reduce mem consumption + } + signerState.SeenSigners[quorum.ToBitMask()] = struct{}{} } - return signerState.MessageCounts.RecordConsensusMessage(signedSSVMessage, consensusMessage) + return signerState.SeenMsgTypes.RecordConsensusMessage(signedSSVMessage, consensusMessage) } func (mv *messageValidator) validateJustifications(message *specqbft.Message) error { @@ -451,14 +465,3 @@ func (mv *messageValidator) roundRobinProposer(height specqbft.Height, round spe index := (firstRoundIndex + uint64(round) - uint64(specqbft.FirstRound)) % uint64(len(committee)) return committee[index] } - -func encodeOperators(operators []spectypes.OperatorID) ([sha256.Size]byte, error) { - buf := new(bytes.Buffer) - for _, operator := range operators { - if err := binary.Write(buf, binary.LittleEndian, operator); err != nil { - return [sha256.Size]byte{}, err - } - } - hash := sha256.Sum256(buf.Bytes()) - return hash, nil -} diff --git a/message/validation/const.go b/message/validation/const.go index 339556fc4e..2c2fec58de 100644 --- a/message/validation/const.go +++ b/message/validation/const.go @@ -2,6 +2,8 @@ package validation import ( "time" + + "github.com/attestantio/go-eth2-client/spec/phase0" ) // To add some encoding overhead for ssz, we use (N + N/encodingOverheadDivisor + 4) for a structure with expected size N @@ -13,7 +15,7 @@ const ( clockErrorTolerance = time.Millisecond * 50 allowedRoundsInFuture = 1 allowedRoundsInPast = 2 - lateSlotAllowance = 2 + lateSlotAllowance = phase0.Slot(2) syncCommitteeSize = 512 rsaSignatureSize = 256 operatorIDSize = 8 // uint64 diff --git a/message/validation/errors.go b/message/validation/errors.go index c1a90997f3..5e23d8a7db 100644 --- a/message/validation/errors.go +++ b/message/validation/errors.go @@ -115,7 +115,6 @@ var ( ErrDuplicatedMessage = Error{text: "message is duplicated", reject: true} ErrInvalidPartialSignatureTypeCount = Error{text: "sent more partial signature messages of a certain type than allowed", reject: true} ErrTooManyPartialSignatureMessages = Error{text: "too many partial signature messages", reject: true} - ErrEncodeOperators = Error{text: "encode operators", reject: true} ) func (mv *messageValidator) handleValidationError(peerID peer.ID, decodedMessage *queue.SSVMessage, err error) pubsub.ValidationResult { diff --git a/message/validation/message_counts.go b/message/validation/message_counts.go deleted file mode 100644 index bb54852510..0000000000 --- a/message/validation/message_counts.go +++ /dev/null @@ -1,137 +0,0 @@ -package validation - -// message_counts.go contains code for counting and validating messages per validator-slot-round. - -import ( - "fmt" - - specqbft "github.com/ssvlabs/ssv-spec/qbft" - spectypes "github.com/ssvlabs/ssv-spec/types" -) - -// MessageCounts tracks the number of various message types received for validation. -type MessageCounts struct { - PreConsensus int - Proposal int - Prepare int - Commit int - RoundChange int - PostConsensus int -} - -// String provides a formatted representation of the MessageCounts. -func (c *MessageCounts) String() string { - return fmt.Sprintf("pre-consensus: %v, proposal: %v, prepare: %v, commit: %v, round change: %v, post-consensus: %v", - c.PreConsensus, - c.Proposal, - c.Prepare, - c.Commit, - c.RoundChange, - c.PostConsensus, - ) -} - -// ValidateConsensusMessage checks if the provided consensus message exceeds the set limits. -// Returns an error if the message type exceeds its respective count limit. -func (c *MessageCounts) ValidateConsensusMessage(signedSSVMessage *spectypes.SignedSSVMessage, msg *specqbft.Message, limits MessageCounts) error { - switch msg.MsgType { - case specqbft.ProposalMsgType: - if c.Proposal >= limits.Proposal { - err := ErrDuplicatedMessage - err.got = fmt.Sprintf("proposal, having %v", c.String()) - return err - } - case specqbft.PrepareMsgType: - if c.Prepare >= limits.Prepare { - err := ErrDuplicatedMessage - err.got = fmt.Sprintf("prepare, having %v", c.String()) - return err - } - case specqbft.CommitMsgType: - if len(signedSSVMessage.OperatorIDs) == 1 { - if c.Commit >= limits.Commit { - err := ErrDuplicatedMessage - err.got = fmt.Sprintf("commit, having %v", c.String()) - return err - } - } - case specqbft.RoundChangeMsgType: - if c.RoundChange >= limits.RoundChange { - err := ErrDuplicatedMessage - - err.got = fmt.Sprintf("round change, having %v", c.String()) - return err - } - default: - return fmt.Errorf("unexpected signed message type") // should be checked before - } - - return nil -} - -// ValidatePartialSignatureMessage checks if the provided partial signature message exceeds the set limits. -// Returns an error if the message type exceeds its respective count limit. -func (c *MessageCounts) ValidatePartialSignatureMessage(m *spectypes.PartialSignatureMessages, limits MessageCounts) error { - switch m.Type { - case spectypes.RandaoPartialSig, spectypes.SelectionProofPartialSig, spectypes.ContributionProofs, spectypes.ValidatorRegistrationPartialSig, spectypes.VoluntaryExitPartialSig: - if c.PreConsensus >= limits.PreConsensus { - err := ErrInvalidPartialSignatureTypeCount - err.got = fmt.Sprintf("pre-consensus, having %v", c.String()) - return err - } - case spectypes.PostConsensusPartialSig: - if c.PostConsensus >= limits.PostConsensus { - err := ErrInvalidPartialSignatureTypeCount - err.got = fmt.Sprintf("post-consensus, having %v", c.String()) - return err - } - default: - return fmt.Errorf("unexpected partial signature message type") // should be checked before - } - - return nil -} - -// RecordConsensusMessage updates the counts based on the provided consensus message type. -func (c *MessageCounts) RecordConsensusMessage(signedSSVMessage *spectypes.SignedSSVMessage, msg *specqbft.Message) error { - switch msg.MsgType { - case specqbft.ProposalMsgType: - c.Proposal++ - case specqbft.PrepareMsgType: - c.Prepare++ - case specqbft.CommitMsgType: - if len(signedSSVMessage.OperatorIDs) == 1 { - c.Commit++ - } - case specqbft.RoundChangeMsgType: - c.RoundChange++ - default: - return fmt.Errorf("unexpected signed message type") // should be checked before - } - return nil -} - -// RecordPartialSignatureMessage updates the counts based on the provided partial signature message type. -func (c *MessageCounts) RecordPartialSignatureMessage(messages *spectypes.PartialSignatureMessages) error { - switch messages.Type { - case spectypes.RandaoPartialSig, spectypes.SelectionProofPartialSig, spectypes.ContributionProofs, spectypes.ValidatorRegistrationPartialSig, spectypes.VoluntaryExitPartialSig: - c.PreConsensus++ - case spectypes.PostConsensusPartialSig: - c.PostConsensus++ - default: - return fmt.Errorf("unexpected partial signature message type") // should be checked before - } - return nil -} - -// maxMessageCounts is the maximum number of acceptable messages from a signer within a slot & round. -func maxMessageCounts() MessageCounts { - return MessageCounts{ - PreConsensus: 1, - Proposal: 1, - Prepare: 1, - Commit: 1, - RoundChange: 1, - PostConsensus: 1, - } -} diff --git a/message/validation/partial_validation.go b/message/validation/partial_validation.go index 3df628a933..c1a5fee3e6 100644 --- a/message/validation/partial_validation.go +++ b/message/validation/partial_validation.go @@ -9,7 +9,6 @@ import ( "github.com/attestantio/go-eth2-client/spec/phase0" specqbft "github.com/ssvlabs/ssv-spec/qbft" - "github.com/ssvlabs/ssv-spec/types" spectypes "github.com/ssvlabs/ssv-spec/types" ) @@ -37,12 +36,12 @@ func (mv *messageValidator) validatePartialSignatureMessage( return nil, e } - if err := mv.validatePartialSignatureMessageSemantics(signedSSVMessage, partialSignatureMessages, committeeInfo.indices); err != nil { + if err := mv.validatePartialSignatureMessageSemantics(signedSSVMessage, partialSignatureMessages, committeeInfo.validatorIndices); err != nil { return nil, err } msgID := ssvMessage.GetID() - state := mv.consensusState(msgID) + state := mv.validatorState(msgID, committeeInfo.committee) if err := mv.validatePartialSigMessagesByDutyLogic(signedSSVMessage, partialSignatureMessages, committeeInfo, receivedAt, state); err != nil { return nil, err } @@ -55,7 +54,7 @@ func (mv *messageValidator) validatePartialSignatureMessage( return partialSignatureMessages, e } - if err := mv.updatePartialSignatureState(partialSignatureMessages, state, signer); err != nil { + if err := mv.updatePartialSignatureState(partialSignatureMessages, state, signer, committeeInfo); err != nil { return nil, err } @@ -137,15 +136,15 @@ func (mv *messageValidator) validatePartialSigMessagesByDutyLogic( partialSignatureMessages *spectypes.PartialSignatureMessages, committeeInfo CommitteeInfo, receivedAt time.Time, - state *consensusState, + state *ValidatorState, ) error { role := signedSSVMessage.SSVMessage.GetID().GetRoleType() messageSlot := partialSignatureMessages.Slot signer := signedSSVMessage.OperatorIDs[0] - signerStateBySlot := state.GetOrCreate(signer) + signerStateBySlot := state.Signer(committeeInfo.signerIndex(signer)) // Rule: Height must not be "old". I.e., signer must not have already advanced to a later slot. - if role != types.RoleCommittee { // Rule only for validator runners + if role != spectypes.RoleCommittee { // Rule only for validator runners maxSlot := signerStateBySlot.MaxSlot() if maxSlot != 0 && maxSlot > partialSignatureMessages.Slot { e := ErrSlotAlreadyAdvanced @@ -156,7 +155,7 @@ func (mv *messageValidator) validatePartialSigMessagesByDutyLogic( } randaoMsg := partialSignatureMessages.Type == spectypes.RandaoPartialSig - if err := mv.validateBeaconDuty(signedSSVMessage.SSVMessage.GetID().GetRoleType(), messageSlot, committeeInfo.indices, randaoMsg); err != nil { + if err := mv.validateBeaconDuty(signedSSVMessage.SSVMessage.GetID().GetRoleType(), messageSlot, committeeInfo.validatorIndices, randaoMsg); err != nil { return err } @@ -168,8 +167,7 @@ func (mv *messageValidator) validatePartialSigMessagesByDutyLogic( // - 1 SelectionProofPartialSig and 1 PostConsensusPartialSig for Sync committee contribution // - 1 ValidatorRegistrationPartialSig for Validator Registration // - 1 VoluntaryExitPartialSig for Voluntary Exit - limits := maxMessageCounts() - if err := signerState.MessageCounts.ValidatePartialSignatureMessage(partialSignatureMessages, limits); err != nil { + if err := signerState.SeenMsgTypes.ValidatePartialSignatureMessage(partialSignatureMessages); err != nil { return err } } @@ -185,11 +183,11 @@ func (mv *messageValidator) validatePartialSigMessagesByDutyLogic( // - 2 for aggregation, voluntary exit and validator registration // - 2*V for Committee duty (where V is the number of validators in the cluster) (if no validator is doing sync committee in this epoch) // - else, accept - if err := mv.validateDutyCount(signedSSVMessage.SSVMessage.GetID(), messageSlot, committeeInfo.indices, signerStateBySlot); err != nil { + if err := mv.validateDutyCount(signedSSVMessage.SSVMessage.GetID(), messageSlot, committeeInfo.validatorIndices, signerStateBySlot); err != nil { return err } - clusterValidatorCount := len(committeeInfo.indices) + clusterValidatorCount := len(committeeInfo.validatorIndices) partialSignatureMessageCount := len(partialSignatureMessages.Messages) if signedSSVMessage.SSVMessage.MsgID.GetRoleType() == spectypes.RoleCommittee { @@ -206,7 +204,7 @@ func (mv *messageValidator) validatePartialSigMessagesByDutyLogic( return ErrTripleValidatorIndexInPartialSignatures } } - } else if signedSSVMessage.SSVMessage.MsgID.GetRoleType() == types.RoleSyncCommitteeContribution { + } else if signedSSVMessage.SSVMessage.MsgID.GetRoleType() == spectypes.RoleSyncCommitteeContribution { // Rule: The number of signatures must be <= MaxSignaturesInSyncCommitteeContribution for the sync comittee contribution duty if partialSignatureMessageCount > maxSignatures { e := ErrTooManyPartialSignatureMessages @@ -227,20 +225,21 @@ func (mv *messageValidator) validatePartialSigMessagesByDutyLogic( func (mv *messageValidator) updatePartialSignatureState( partialSignatureMessages *spectypes.PartialSignatureMessages, - state *consensusState, + state *ValidatorState, signer spectypes.OperatorID, + committeeInfo CommitteeInfo, ) error { - stateBySlot := state.GetOrCreate(signer) + stateBySlot := state.Signer(committeeInfo.signerIndex(signer)) messageSlot := partialSignatureMessages.Slot messageEpoch := mv.netCfg.Beacon.EstimatedEpochAtSlot(messageSlot) signerState := stateBySlot.Get(messageSlot) if signerState == nil || signerState.Slot != messageSlot { - signerState = NewSignerState(messageSlot, specqbft.FirstRound) + signerState = newSignerState(messageSlot, specqbft.FirstRound) stateBySlot.Set(messageSlot, messageEpoch, signerState) } - return signerState.MessageCounts.RecordPartialSignatureMessage(partialSignatureMessages) + return signerState.SeenMsgTypes.RecordPartialSignatureMessage(partialSignatureMessages) } func (mv *messageValidator) validPartialSigMsgType(msgType spectypes.PartialSigMsgType) bool { diff --git a/message/validation/quorum.go b/message/validation/quorum.go new file mode 100644 index 0000000000..2079ac551e --- /dev/null +++ b/message/validation/quorum.go @@ -0,0 +1,58 @@ +package validation + +import ( + "fmt" + "math/bits" + + spectypes "github.com/ssvlabs/ssv-spec/types" +) + +// TODO: Take all of these from https://github.com/ssvlabs/ssv/pull/1867 once it's merged. +// This file is temporary to avoid the need to be based on another PR, hence there are no tests. + +const maxCommitteeSize = 13 + +type Quorum struct { + Signers []spectypes.OperatorID + Committee []spectypes.OperatorID +} + +func (q *Quorum) ToBitMask() SignersBitMask { + if len(q.Signers) > maxCommitteeSize || len(q.Committee) > maxCommitteeSize || len(q.Signers) > len(q.Committee) { + panic(fmt.Sprintf("invalid signers/committee size: %d/%d", len(q.Signers), len(q.Committee))) + } + + bitmask := SignersBitMask(0) + i, j := 0, 0 + for i < len(q.Signers) && j < len(q.Committee) { + if q.Signers[i] == q.Committee[j] { + bitmask |= 1 << uint(j) // #nosec G115 -- j cannot exceed maxCommitteeSize + i++ + j++ + } else if q.Signers[i] < q.Committee[j] { + i++ + } else { // A[i] > B[j] + j++ + } + } + + return bitmask +} + +type SignersBitMask uint16 + +func (obm SignersBitMask) SignersList(committee []spectypes.OperatorID) []spectypes.OperatorID { + if len(committee) > maxCommitteeSize { + panic(fmt.Sprintf("invalid committee size: %d", len(committee))) + } + + signers := make([]spectypes.OperatorID, 0, bits.OnesCount16(uint16(obm))) + for j := 0; j < len(committee); j++ { + // #nosec G115 -- j cannot exceed maxCommitteeSize + if obm&(1<