Skip to content

Commit

Permalink
Merge pull request #28 from RockX-SG/keysign-protocol-simple
Browse files Browse the repository at this point in the history
Keysign protocol simple
  • Loading branch information
calvinzhou-rockx authored Jun 23, 2023
2 parents 5b25a86 + 8f9b19e commit 2424b91
Show file tree
Hide file tree
Showing 35 changed files with 1,272 additions and 518 deletions.
103 changes: 103 additions & 0 deletions dkg/common/message_container.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
package common

import (
"fmt"
"sync"

"github.com/bloxapp/ssv-spec/dkg"
"github.com/pkg/errors"
)

type IMsgContainer interface {
SaveMsg(round ProtocolRound, msg *dkg.SignedMessage) (existingMessage *dkg.SignedMessage, err error)
GetSignedMsg(round ProtocolRound, operatorID uint32) (*dkg.SignedMessage, error)
AllMessagesForRound(round ProtocolRound) map[uint32]*dkg.SignedMessage
AllMessagesReceivedFor(round ProtocolRound, operators []uint32) bool
AllMessagesReceivedUpto(round ProtocolRound, operators []uint32, threshold uint64) bool
}

type MsgContainer struct {
mu *sync.Mutex
msgs map[ProtocolRound]map[uint32]*dkg.SignedMessage
}

func NewMsgContainer() IMsgContainer {
m := make(map[ProtocolRound]map[uint32]*dkg.SignedMessage)
for _, round := range rounds {
m[round] = make(map[uint32]*dkg.SignedMessage)
}
return &MsgContainer{msgs: m, mu: new(sync.Mutex)}
}

func (msgContainer *MsgContainer) SaveMsg(round ProtocolRound, msg *dkg.SignedMessage) (existingMessage *dkg.SignedMessage, err error) {
msgContainer.mu.Lock()
defer msgContainer.mu.Unlock()

existingMessage, exists := msgContainer.msgs[round][uint32(msg.Signer)]
if exists {
return existingMessage, errors.New("msg already exists")
}
msgContainer.msgs[round][uint32(msg.Signer)] = msg
return nil, nil
}

func (msgContainer *MsgContainer) GetSignedMsg(round ProtocolRound, operatorID uint32) (*dkg.SignedMessage, error) {
msgContainer.mu.Lock()
defer msgContainer.mu.Unlock()

signedMsg, exist := msgContainer.msgs[round][operatorID]
if !exist {
return nil, ErrMsgNotFound{Round: round, OperatorID: operatorID}
}
return signedMsg, nil
}

func (msgContainer *MsgContainer) AllMessagesForRound(round ProtocolRound) map[uint32]*dkg.SignedMessage {
msgContainer.mu.Lock()
defer msgContainer.mu.Unlock()

return msgContainer.msgs[round]
}

func (msgContainer *MsgContainer) AllMessagesReceivedFor(round ProtocolRound, operators []uint32) bool {
msgContainer.mu.Lock()
defer msgContainer.mu.Unlock()

for _, operatorID := range operators {
if _, ok := msgContainer.msgs[round][uint32(operatorID)]; !ok {
return false
}
}
return true
}

func (msgContainer *MsgContainer) AllMessagesReceivedUpto(round ProtocolRound, operators []uint32, threshold uint64) bool {
msgContainer.mu.Lock()
defer msgContainer.mu.Unlock()

totalMsgsRecieved := uint64(0)
for _, operatorID := range operators {
if _, ok := msgContainer.msgs[round][uint32(operatorID)]; ok {
totalMsgsRecieved += 1
}
}
return totalMsgsRecieved >= threshold
}

type ErrMsgNotFound struct {
Round ProtocolRound
OperatorID uint32
}

func (e ErrMsgNotFound) Error() string {
return fmt.Sprintf("message for operatorID %d and round %d not found\n", e.OperatorID, e.Round)
}

type ErrMsgNil struct {
Round ProtocolRound
OperatorID uint32
}

func (e ErrMsgNil) Error() string {
return fmt.Sprintf("message for operatorID %d and round %d is nil\n", e.OperatorID, e.Round)
}
40 changes: 40 additions & 0 deletions dkg/common/protocol_rounds.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
package common

// ProtocolRound is enum for all the rounds in the protocol
type ProtocolRound int

const (
Uninitialized ProtocolRound = iota
Preparation
Round1
Round2
KeygenOutput
Blame
Timeout
KeysignOutput
)

var rounds = []ProtocolRound{
Uninitialized,
Preparation,
Round1,
Round2,
KeygenOutput,
Blame,
Timeout,
KeysignOutput,
}

func (round ProtocolRound) String() string {
m := map[ProtocolRound]string{
Uninitialized: "Uninitialized",
Preparation: "Preparation",
Round1: "Round1",
Round2: "Round2",
KeygenOutput: "KeygenOutput",
Blame: "Blame",
Timeout: "Timeout",
KeysignOutput: "KeysignOutput",
}
return m[round]
}
19 changes: 10 additions & 9 deletions dkg/frost/blame.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"bytes"

"github.com/bloxapp/ssv-spec/dkg"
"github.com/bloxapp/ssv-spec/dkg/common"
"github.com/coinbase/kryptology/pkg/sharing"
ecies "github.com/ecies/go/v2"
"github.com/pkg/errors"
Expand All @@ -13,7 +14,7 @@ import (
// Message to blame an inconsistent message, including the existing message and
// the new message as blame data.
func (fr *Instance) createAndBroadcastBlameOfInconsistentMessage(existingMessage, newMessage *dkg.SignedMessage) (bool, *dkg.ProtocolOutcome, error) {
fr.state.SetCurrentRound(Blame)
fr.state.SetCurrentRound(common.Blame)

existingMessageBytes, err := existingMessage.Encode()
if err != nil {
Expand All @@ -25,7 +26,7 @@ func (fr *Instance) createAndBroadcastBlameOfInconsistentMessage(existingMessage
}

msg := &ProtocolMsg{
Round: Blame,
Round: common.Blame,
BlameMessage: &BlameMessage{
Type: InconsistentMessage,
TargetOperatorID: uint32(newMessage.Signer),
Expand Down Expand Up @@ -54,9 +55,9 @@ func (fr *Instance) createAndBroadcastBlameOfInconsistentMessage(existingMessage
// Message to blame an invalid share, including the round 1 message from culprit
// operator
func (fr *Instance) createAndBroadcastBlameOfInvalidShare(culpritOID uint32) (bool, *dkg.ProtocolOutcome, error) {
fr.state.SetCurrentRound(Blame)
fr.state.SetCurrentRound(common.Blame)

round1Msg, err := fr.state.msgContainer.GetSignedMsg(Round1, culpritOID)
round1Msg, err := fr.state.msgContainer.GetSignedMsg(common.Round1, culpritOID)
if err != nil {
return false, nil, err
}
Expand All @@ -66,7 +67,7 @@ func (fr *Instance) createAndBroadcastBlameOfInvalidShare(culpritOID uint32) (bo
}

msg := &ProtocolMsg{
Round: Blame,
Round: common.Blame,
BlameMessage: &BlameMessage{
Type: InvalidShare,
TargetOperatorID: culpritOID,
Expand Down Expand Up @@ -95,15 +96,15 @@ func (fr *Instance) createAndBroadcastBlameOfInvalidShare(culpritOID uint32) (bo
// blame an invalid message, including the operatorID of the culprit and the
// received signed message.
func (fr *Instance) createAndBroadcastBlameOfInvalidMessage(culpritOID uint32, message *dkg.SignedMessage) (bool, *dkg.ProtocolOutcome, error) {
fr.state.SetCurrentRound(Blame)
fr.state.SetCurrentRound(common.Blame)

bytes, err := message.Encode()
if err != nil {
return false, nil, err
}

msg := &ProtocolMsg{
Round: Blame,
Round: common.Blame,
BlameMessage: &BlameMessage{
Type: InvalidMessage,
TargetOperatorID: culpritOID,
Expand All @@ -130,7 +131,7 @@ func (fr *Instance) createAndBroadcastBlameOfInvalidMessage(culpritOID uint32, m

// checkBlame checks validity of the blame message as per its blame type
func (fr *Instance) checkBlame(blamerOID uint32, protocolMessage *ProtocolMsg, signedMessage *dkg.SignedMessage) (finished bool, protocolOutcome *dkg.ProtocolOutcome, err error) {
fr.state.SetCurrentRound(Blame)
fr.state.SetCurrentRound(common.Blame)

var valid bool
switch protocolMessage.BlameMessage.Type {
Expand Down Expand Up @@ -167,7 +168,7 @@ func (fr *Instance) processBlameTypeInvalidShare(blamerOID uint32, blameMessage
return false, errors.Wrap(err, "failed to Validate signature for blame data")
}

blamerPrepMsg, err := fr.state.msgContainer.GetPreparationMsg(blamerOID)
blamerPrepMsg, err := GetPreparationMsg(fr.state.msgContainer, blamerOID)
if err != nil {
return false, errors.New("unable to retrieve blamer's PreparationMessage")
}
Expand Down
9 changes: 5 additions & 4 deletions dkg/frost/helper_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,21 @@ package frost
import (
"testing"

"github.com/bloxapp/ssv-spec/dkg/common"
"github.com/stretchr/testify/require"
)

func Test_haveSameRoot(t *testing.T) {
t.Run("true case", func(t *testing.T) {
msg := testSignedMessage(Preparation, 1)
msg2 := testSignedMessage(Preparation, 1)
msg := testSignedMessage(common.Preparation, 1)
msg2 := testSignedMessage(common.Preparation, 1)
actual := haveSameRoot(msg, msg2)
require.EqualValues(t, true, actual)
})

t.Run("false case", func(t *testing.T) {
msg := testSignedMessage(Preparation, 1)
msg2 := testSignedMessage(Preparation, 2)
msg := testSignedMessage(common.Preparation, 1)
msg2 := testSignedMessage(common.Preparation, 2)
actual := haveSameRoot(msg, msg2)
require.EqualValues(t, false, actual)
})
Expand Down
27 changes: 14 additions & 13 deletions dkg/frost/instance.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,12 @@ import (
"math/rand"

"github.com/bloxapp/ssv-spec/dkg"
"github.com/bloxapp/ssv-spec/dkg/common"
"github.com/bloxapp/ssv-spec/types"
"github.com/coinbase/kryptology/pkg/core/curves"
"github.com/coinbase/kryptology/pkg/dkg/frost"
ecies "github.com/ecies/go/v2"
"github.com/ethereum/go-ethereum/common"
ethcommon "github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/crypto"
"github.com/pkg/errors"
)
Expand Down Expand Up @@ -87,7 +88,7 @@ func newProtocol(config dkg.IConfig, instanceParams InstanceParams) dkg.Protocol
// starting this process.
func (fr *Instance) Start() error {
fr.state.roundTimer.OnTimeout(fr.UponRoundTimeout)
fr.state.SetCurrentRound(Preparation)
fr.state.SetCurrentRound(common.Preparation)
fr.state.roundTimer.StartRoundTimeoutTimer(fr.state.GetCurrentRound())

// create a new dkg participant
Expand All @@ -114,7 +115,7 @@ func (fr *Instance) Start() error {

// create and broadcast PreparationMessage
msg := &ProtocolMsg{
Round: Preparation,
Round: common.Preparation,
PreparationMessage: &PreparationMessage{
SessionPk: k.PublicKey.Bytes(true),
},
Expand Down Expand Up @@ -153,17 +154,17 @@ func (fr *Instance) ProcessMsg(msg *dkg.SignedMessage) (finished bool, protocolO

// process message based on their round
switch protocolMessage.Round {
case Preparation:
case common.Preparation:
return fr.processRound1()
case Round1:
case common.Round1:
return fr.processRound2()
case Round2:
case common.Round2:
return fr.processKeygenOutput()
case Blame:
case common.Blame:
// here we are checking blame right away unlike other rounds where
// we wait to receive messages from all the operators in the protocol
return fr.checkBlame(uint32(msg.Signer), protocolMessage, msg)
case Timeout:
case common.Timeout:
return fr.ProcessTimeoutMessage()
default:
return true, nil, dkg.ErrInvalidRound{}
Expand All @@ -172,8 +173,8 @@ func (fr *Instance) ProcessMsg(msg *dkg.SignedMessage) (finished bool, protocolO

func (fr *Instance) canProceedThisRound() bool {
// Note: for Resharing, Preparation (New Committee) -> Round1 (Old Committee) -> Round2 (New Committee)
if fr.instanceParams.isResharing() && fr.state.GetCurrentRound() == Round1 {
return fr.state.msgContainer.AllMessagesReceivedFor(Round1, fr.instanceParams.operatorsOld)
if fr.instanceParams.isResharing() && fr.state.GetCurrentRound() == common.Round1 {
return fr.state.msgContainer.AllMessagesReceivedFor(common.Round1, fr.instanceParams.operatorsOld)
}
return fr.state.msgContainer.AllMessagesReceivedFor(fr.state.GetCurrentRound(), fr.instanceParams.operators)
}
Expand All @@ -183,9 +184,9 @@ func (fr *Instance) needToRunCurrentRound() bool {
return true // always run for new keygen
}
switch fr.state.GetCurrentRound() {
case Preparation, Round2, KeygenOutput:
case common.Preparation, common.Round2, common.KeygenOutput:
return fr.instanceParams.inNewCommittee()
case Round1:
case common.Round1:
return fr.instanceParams.inOldCommittee()
default:
return false
Expand Down Expand Up @@ -215,7 +216,7 @@ func (fr *Instance) validateSignedMessage(msg *dkg.SignedMessage) error {
return errors.Wrap(err, "unable to recover public key")
}

addr := common.BytesToAddress(crypto.Keccak256(pk[1:])[12:])
addr := ethcommon.BytesToAddress(crypto.Keccak256(pk[1:])[12:])
if addr != operator.ETHAddress {
return errors.New("invalid signature")
}
Expand Down
7 changes: 4 additions & 3 deletions dkg/frost/keygen_output.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"bytes"

"github.com/bloxapp/ssv-spec/dkg"
"github.com/bloxapp/ssv-spec/dkg/common"
"github.com/bloxapp/ssv-spec/types"

"github.com/herumi/bls-eth-go-binary/bls"
Expand All @@ -20,7 +21,7 @@ func (fr *Instance) processKeygenOutput() (finished bool, protocolOutcome *dkg.P
if !fr.canProceedThisRound() {
return false, nil, nil
}
fr.state.SetCurrentRound(KeygenOutput)
fr.state.SetCurrentRound(common.KeygenOutput)
fr.state.roundTimer.StartRoundTimeoutTimer(fr.state.GetCurrentRound())

if !fr.needToRunCurrentRound() {
Expand All @@ -42,7 +43,7 @@ func (fr *Instance) processKeygenOutput() (finished bool, protocolOutcome *dkg.P

operatorPubKeys := make(map[types.OperatorID]*bls.PublicKey)
for _, operatorID := range fr.instanceParams.operators {
msg, err := fr.state.msgContainer.GetRound2Msg(operatorID)
msg, err := GetRound2Msg(fr.state.msgContainer, operatorID)
if err != nil {
return false, nil, errors.Wrap(err, "failed to retrieve round2 msg")
}
Expand Down Expand Up @@ -131,7 +132,7 @@ func (fr *Instance) getXVec(operators []uint32) ([]bls.Fr, error) {
func (fr *Instance) getYVec(operators []uint32) ([]bls.G1, error) {
yVec := make([]bls.G1, 0)
for _, operatorID := range operators {
msg, err := fr.state.msgContainer.GetRound2Msg(operatorID)
msg, err := GetRound2Msg(fr.state.msgContainer, operatorID)
if err != nil {
return nil, errors.Wrap(err, "failed to retrieve round2 msg")
}
Expand Down
Loading

0 comments on commit 2424b91

Please sign in to comment.