From 7e314f390fa1c5143a84f7069d56101b59349d98 Mon Sep 17 00:00:00 2001 From: edwardmack Date: Tue, 6 Aug 2024 17:25:19 -0400 Subject: [PATCH] add validation logic to workers --- dot/parachain/backing/integration_test.go | 13 +-- .../candidate_validation.go | 45 ++++----- dot/parachain/pvf/host.go | 93 ++++++++++++++++++- dot/parachain/pvf/host_test.go | 2 +- dot/parachain/pvf/worker.go | 92 ++++++++++++++---- dot/parachain/pvf/worker_pool.go | 81 ++++++---------- dot/parachain/pvf/worker_pool_test.go | 2 +- dot/parachain/pvf/worker_test.go | 13 ++- 8 files changed, 235 insertions(+), 106 deletions(-) diff --git a/dot/parachain/backing/integration_test.go b/dot/parachain/backing/integration_test.go index 9cdc09af0d..df6e5ac431 100644 --- a/dot/parachain/backing/integration_test.go +++ b/dot/parachain/backing/integration_test.go @@ -4,6 +4,7 @@ package backing_test import ( + "github.com/ChainSafe/gossamer/dot/parachain/pvf" "testing" "time" @@ -275,9 +276,9 @@ func TestSecondsValidCandidate(t *testing.T) { return false } - badReturn := candidatevalidation.BadReturn - validateFromExhaustive.Ch <- parachaintypes.OverseerFuncRes[candidatevalidation.ValidationResult]{ - Data: candidatevalidation.ValidationResult{ + badReturn := pvf.BadReturn + validateFromExhaustive.Ch <- parachaintypes.OverseerFuncRes[pvf.ValidationResult]{ + Data: pvf.ValidationResult{ InvalidResult: &badReturn, }, } @@ -339,9 +340,9 @@ func TestSecondsValidCandidate(t *testing.T) { return false } - validateFromExhaustive.Ch <- parachaintypes.OverseerFuncRes[candidatevalidation.ValidationResult]{ - Data: candidatevalidation.ValidationResult{ - ValidResult: &candidatevalidation.ValidValidationResult{ + validateFromExhaustive.Ch <- parachaintypes.OverseerFuncRes[pvf.ValidationResult]{ + Data: pvf.ValidationResult{ + ValidResult: &pvf.ValidValidationResult{ CandidateCommitments: parachaintypes.CandidateCommitments{ UpwardMessages: []parachaintypes.UpwardMessage{}, HorizontalMessages: []parachaintypes.OutboundHrmpMessage{}, diff --git a/dot/parachain/candidate-validation/candidate_validation.go b/dot/parachain/candidate-validation/candidate_validation.go index 25cce11368..ac03e208f4 100644 --- a/dot/parachain/candidate-validation/candidate_validation.go +++ b/dot/parachain/candidate-validation/candidate_validation.go @@ -89,33 +89,36 @@ func (cv *CandidateValidation) processMessages(wg *sync.WaitGroup) { case ValidateFromExhaustive: // This is the skeleton to hook up the PVF host to the candidate validation subsystem // This is currently WIP, pending moving the validation logic to the PVF host - validationCodeHash := msg.ValidationCode.Hash() taskResult := make(chan *pvf.ValidationTaskResult) validationTask := &pvf.ValidationTask{ - PersistedValidationData: parachaintypes.PersistedValidationData{}, - WorkerID: &validationCodeHash, - CandidateReceipt: &msg.CandidateReceipt, - PoV: msg.PoV, - ExecutorParams: nil, - PvfExecTimeoutKind: parachaintypes.PvfExecTimeoutKind{}, - ResultCh: taskResult, + PersistedValidationData: msg.PersistedValidationData, + //WorkerID: &validationCodeHash, + ValidationCode: &msg.ValidationCode, + CandidateReceipt: &msg.CandidateReceipt, + PoV: msg.PoV, + ExecutorParams: nil, + PvfExecTimeoutKind: parachaintypes.PvfExecTimeoutKind{}, + ResultCh: taskResult, } - cv.pvfHost.Validate(validationTask) - fmt.Printf("Validation result: %v", <-taskResult) + go cv.pvfHost.Validate(validationTask) + + result := <-taskResult + fmt.Printf("Validation result: %v", result) + // TODO(ed): determine how to handle this error and result // WIP: This is the current implementation of the validation logic, it will be replaced by the PVF host // when the validation logic is moved to the PVF host - result, err := validateFromExhaustive(cv.ValidationHost, msg.PersistedValidationData, - msg.ValidationCode, msg.CandidateReceipt, msg.PoV) - if err != nil { - logger.Errorf("failed to validate from exhaustive: %w", err) - msg.Ch <- parachaintypes.OverseerFuncRes[pvf.ValidationResult]{ - Err: err, - } - } else { - msg.Ch <- parachaintypes.OverseerFuncRes[pvf.ValidationResult]{ - Data: *result, - } + //result, err := validateFromExhaustive(cv.ValidationHost, msg.PersistedValidationData, + // msg.ValidationCode, msg.CandidateReceipt, msg.PoV) + //if err != nil { + // logger.Errorf("failed to validate from exhaustive: %w", err) + // msg.Ch <- parachaintypes.OverseerFuncRes[pvf.ValidationResult]{ + // Err: err, + // } + //} else { + msg.Ch <- parachaintypes.OverseerFuncRes[pvf.ValidationResult]{ + Data: *result.Result, + //} } case PreCheck: diff --git a/dot/parachain/pvf/host.go b/dot/parachain/pvf/host.go index 0c335bb5c7..c2a44f5429 100644 --- a/dot/parachain/pvf/host.go +++ b/dot/parachain/pvf/host.go @@ -2,6 +2,9 @@ package pvf import ( "fmt" + parachainruntime "github.com/ChainSafe/gossamer/dot/parachain/runtime" + parachaintypes "github.com/ChainSafe/gossamer/dot/parachain/types" + "github.com/ChainSafe/gossamer/pkg/scale" "sync" "github.com/ChainSafe/gossamer/internal/log" @@ -40,10 +43,90 @@ func NewValidationHost() *ValidationHost { func (v *ValidationHost) Validate(msg *ValidationTask) { logger.Debugf("Validating worker", "workerID", msg.WorkerID) - logger.Debugf("submitting request for worker", "workerID", msg.WorkerID) - hasWorker := v.workerPool.containsWorker(*msg.WorkerID) - if !hasWorker { - v.workerPool.newValidationWorker(*msg.WorkerID) + validationCodeHash := msg.ValidationCode.Hash() + // basic checks + validationErr, internalErr := performBasicChecks(&msg.CandidateReceipt.Descriptor, + msg.PersistedValidationData.MaxPovSize, + msg.PoV, + validationCodeHash) + // TODO(ed): confirm how to handle internal errors + if internalErr != nil { + logger.Errorf("performing basic checks: %w", internalErr) } - v.workerPool.submitRequest(msg) + + if validationErr != nil { + valErr := &ValidationTaskResult{ + who: validationCodeHash, + Result: &ValidationResult{ + InvalidResult: validationErr, + }, + } + msg.ResultCh <- valErr + return + } + + workerID := v.poolContainsWorker(msg) + validationParams := parachainruntime.ValidationParameters{ + ParentHeadData: msg.PersistedValidationData.ParentHead, + BlockData: msg.PoV.BlockData, + RelayParentNumber: msg.PersistedValidationData.RelayParentNumber, + RelayParentStorageRoot: msg.PersistedValidationData.RelayParentStorageRoot, + } + workTask := &workerTask{ + work: validationParams, + maxPoVSize: msg.PersistedValidationData.MaxPovSize, + ResultCh: msg.ResultCh, + } + v.workerPool.submitRequest(workerID, workTask) +} + +func (v *ValidationHost) poolContainsWorker(msg *ValidationTask) parachaintypes.ValidationCodeHash { + if msg.WorkerID != nil { + return *msg.WorkerID + } + if v.workerPool.containsWorker(msg.ValidationCode.Hash()) { + return msg.ValidationCode.Hash() + } else { + v.workerPool.newValidationWorker(*msg.ValidationCode) + return msg.ValidationCode.Hash() + } +} + +// performBasicChecks Does basic checks of a candidate. Provide the encoded PoV-block. +// Returns ReasonForInvalidity and internal error if any. +func performBasicChecks(candidate *parachaintypes.CandidateDescriptor, maxPoVSize uint32, + pov parachaintypes.PoV, validationCodeHash parachaintypes.ValidationCodeHash) ( + validationError *ReasonForInvalidity, internalError error) { + povHash, err := pov.Hash() + if err != nil { + return nil, fmt.Errorf("hashing PoV: %w", err) + } + + encodedPoV, err := scale.Marshal(pov) + if err != nil { + return nil, fmt.Errorf("encoding PoV: %w", err) + } + encodedPoVSize := uint32(len(encodedPoV)) + + if encodedPoVSize > maxPoVSize { + ci := ParamsTooLarge + return &ci, nil + } + + if povHash != candidate.PovHash { + ci := PoVHashMismatch + return &ci, nil + } + + if validationCodeHash != candidate.ValidationCodeHash { + ci := CodeHashMismatch + return &ci, nil + } + + err = candidate.CheckCollatorSignature() + if err != nil { + ci := BadSignature + return &ci, nil + } + return nil, nil } diff --git a/dot/parachain/pvf/host_test.go b/dot/parachain/pvf/host_test.go index dd691ea2c9..bf30b255d9 100644 --- a/dot/parachain/pvf/host_test.go +++ b/dot/parachain/pvf/host_test.go @@ -33,7 +33,7 @@ func Test_validationHost_start(t *testing.T) { func TestValidationHost(t *testing.T) { v := NewValidationHost() v.Start() - v.workerPool.newValidationWorker(parachaintypes.ValidationCodeHash{1, 2, 3, 4}) + v.workerPool.newValidationWorker(parachaintypes.ValidationCode{1, 2, 3, 4}) resCh := make(chan *ValidationTaskResult) diff --git a/dot/parachain/pvf/worker.go b/dot/parachain/pvf/worker.go index bbe2ddb089..e366ce16f9 100644 --- a/dot/parachain/pvf/worker.go +++ b/dot/parachain/pvf/worker.go @@ -1,44 +1,104 @@ package pvf import ( - "sync" - "time" - + parachainruntime "github.com/ChainSafe/gossamer/dot/parachain/runtime" parachaintypes "github.com/ChainSafe/gossamer/dot/parachain/types" + "sync" ) type worker struct { workerID parachaintypes.ValidationCodeHash + instance *parachainruntime.Instance + queue chan *workerTask } -func newWorker(pID parachaintypes.ValidationCodeHash) *worker { - return &worker{ - workerID: pID, +type workerTask struct { + work parachainruntime.ValidationParameters + maxPoVSize uint32 + ResultCh chan<- *ValidationTaskResult +} + +func newWorker(validationCode parachaintypes.ValidationCode, queue chan *workerTask) (*worker, error) { + validationRuntime, err := parachainruntime.SetupVM(validationCode) + + if err != nil { + return nil, err } + return &worker{ + workerID: validationCode.Hash(), + instance: validationRuntime, + queue: queue, + }, nil } -func (w *worker) run(queue chan *ValidationTask, wg *sync.WaitGroup) { +func (w *worker) run(queue chan *workerTask, wg *sync.WaitGroup) { defer func() { logger.Debugf("[STOPPED] worker %x", w.workerID) wg.Done() }() for task := range queue { - executeRequest(task) + w.executeRequest(task) } } -func executeRequest(task *ValidationTask) { +func (w *worker) executeRequest(task *workerTask) { // WIP: This is a dummy implementation of the worker execution for the validation task. The logic for // validating the parachain block request should be implemented here. - request := task.PoV - logger.Debugf("[EXECUTING] worker %x, block request: %s", task.WorkerID, request) - time.Sleep(500 * time.Millisecond) - dummyResult := &ValidationResult{} + logger.Debugf("[EXECUTING] worker %x task %v", w.workerID, task.work) + + // todo do basic checks + + validationResult, err := w.instance.ValidateBlock(task.work) + + /////////////////////////////// + //if err != nil { + // return nil, fmt.Errorf("executing validate_block: %w", err) + //} + + //headDataHash, err := validationResult.HeadData.Hash() + //if err != nil { + // return nil, fmt.Errorf("hashing head data: %w", err) + //} + // + //if headDataHash != candidateReceipt.Descriptor.ParaHead { + // ci := pvf.ParaHeadHashMismatch + // return &pvf.ValidationResult{InvalidResult: &ci}, nil + //} + candidateCommitments := parachaintypes.CandidateCommitments{ + UpwardMessages: validationResult.UpwardMessages, + HorizontalMessages: validationResult.HorizontalMessages, + NewValidationCode: validationResult.NewValidationCode, + HeadData: validationResult.HeadData, + ProcessedDownwardMessages: validationResult.ProcessedDownwardMessages, + HrmpWatermark: validationResult.HrmpWatermark, + } + + // if validation produced a new set of commitments, we treat the candidate as invalid + //if candidateReceipt.CommitmentsHash != candidateCommitments.Hash() { + // ci := CommitmentsHashMismatch + // return &ValidationResult{InvalidResult: &ci}, nil + //} + pvd := parachaintypes.PersistedValidationData{ + ParentHead: task.work.ParentHeadData, + RelayParentNumber: task.work.RelayParentNumber, + RelayParentStorageRoot: task.work.RelayParentStorageRoot, + MaxPovSize: task.maxPoVSize, + } + dummyResilt := &ValidationResult{ + ValidResult: &ValidValidationResult{ + CandidateCommitments: candidateCommitments, + PersistedValidationData: pvd, + }, + } + ////////////////////////// + + logger.Debugf("[RESULT] worker %x, result: %v, error: %s", w.workerID, dummyResilt, err) + task.ResultCh <- &ValidationTaskResult{ - who: *task.WorkerID, - result: dummyResult, + who: w.workerID, + Result: dummyResilt, } - logger.Debugf("[FINISHED] worker %x", task.WorkerID) + //logger.Debugf("[FINISHED] worker %v, error: %s", validationResult, err) } diff --git a/dot/parachain/pvf/worker_pool.go b/dot/parachain/pvf/worker_pool.go index fc63b9b412..6b9d1328c5 100644 --- a/dot/parachain/pvf/worker_pool.go +++ b/dot/parachain/pvf/worker_pool.go @@ -1,14 +1,11 @@ package pvf import ( - "crypto/rand" "fmt" - "math/big" "sync" "time" parachaintypes "github.com/ChainSafe/gossamer/dot/parachain/types" - "golang.org/x/exp/maps" ) const ( @@ -19,7 +16,7 @@ type validationWorkerPool struct { mtx sync.RWMutex wg sync.WaitGroup - workers map[parachaintypes.ValidationCodeHash]*validationWorker + workers map[parachaintypes.ValidationCodeHash]*worker } type ValidationTask struct { @@ -29,12 +26,13 @@ type ValidationTask struct { PoV parachaintypes.PoV ExecutorParams parachaintypes.ExecutorParams PvfExecTimeoutKind parachaintypes.PvfExecTimeoutKind + ValidationCode *parachaintypes.ValidationCode ResultCh chan<- *ValidationTaskResult } type ValidationTaskResult struct { who parachaintypes.ValidationCodeHash - result *ValidationResult + Result *ValidationResult } // ValidationResult represents the result coming from the candidate validation subsystem. @@ -123,14 +121,14 @@ func (ci ReasonForInvalidity) Error() string { } } -type validationWorker struct { - worker *worker - queue chan *ValidationTask -} +//type validationWorker struct { +// worker *worker +// queue chan *workerTask +//} func newValidationWorkerPool() *validationWorkerPool { return &validationWorkerPool{ - workers: make(map[parachaintypes.ValidationCodeHash]*validationWorker), + workers: make(map[parachaintypes.ValidationCodeHash]*worker), } } @@ -162,61 +160,42 @@ func (v *validationWorkerPool) stop() error { } } -func (v *validationWorkerPool) newValidationWorker(who parachaintypes.ValidationCodeHash) { - - worker := newWorker(who) - workerQueue := make(chan *ValidationTask, maxRequestsAllowed) +func (v *validationWorkerPool) newValidationWorker(validationCode parachaintypes.ValidationCode) parachaintypes.ValidationCodeHash { + workerQueue := make(chan *workerTask, maxRequestsAllowed) + worker, err := newWorker(validationCode, workerQueue) + if err != nil { + // TODO(ed): handle this error + logger.Errorf("failed to create a new worker: %w", err) + } v.wg.Add(1) go worker.run(workerQueue, &v.wg) - v.workers[who] = &validationWorker{ - worker: worker, - queue: workerQueue, - } - logger.Tracef("potential worker added, total in the pool %d", len(v.workers)) + v.workers[worker.workerID] = worker + + return worker.workerID } // submitRequest given a request, the worker pool will get the peer given the peer.ID // parameter or if nil the very first available worker or // to perform the request, the response will be dispatch in the resultCh. -func (v *validationWorkerPool) submitRequest(request *ValidationTask) { - - //task := &validationTask{ - // request: request, - // resultCh: resultCh, - //} - - // if the request is bounded to a specific peer then just - // request it and sent through its queue otherwise send - // the request in the general queue where all worker are - // listening on +func (v *validationWorkerPool) submitRequest(workerID parachaintypes.ValidationCodeHash, request *workerTask) { v.mtx.RLock() defer v.mtx.RUnlock() + logger.Debugf("pool submit request workerID %x", workerID) - if request.WorkerID != nil { - syncWorker, inMap := v.workers[*request.WorkerID] - if inMap { - if syncWorker == nil { - panic("sync worker should not be nil") - } - syncWorker.queue <- request - return + //if request.WorkerID != nil { + syncWorker, inMap := v.workers[workerID] + if inMap { + if syncWorker == nil { + panic("sync worker should not be nil") } + logger.Debugf("sending request", workerID) + syncWorker.queue <- request + return } - - // if the exact peer is not specified then - // randomly select a worker and assign the - // task to it, if the amount of workers is - var selectedWorkerIdx int - workers := maps.Values(v.workers) - nBig, err := rand.Int(rand.Reader, big.NewInt(int64(len(workers)))) - if err != nil { - panic(fmt.Errorf("fail to get a random number: %w", err)) - } - selectedWorkerIdx = int(nBig.Int64()) - selectedWorker := workers[selectedWorkerIdx] - selectedWorker.queue <- request + // TODO(ed): handle this case + logger.Errorf("workerID %x not found in the pool", workerID) } func (v *validationWorkerPool) containsWorker(workerID parachaintypes.ValidationCodeHash) bool { diff --git a/dot/parachain/pvf/worker_pool_test.go b/dot/parachain/pvf/worker_pool_test.go index 88ea789209..23f5becc2c 100644 --- a/dot/parachain/pvf/worker_pool_test.go +++ b/dot/parachain/pvf/worker_pool_test.go @@ -17,7 +17,7 @@ func TestValidationWorkerPool_newValidationWorker(t *testing.T) { "add_one_worker": { setupWorkerPool: func(t *testing.T) *validationWorkerPool { pool := newValidationWorkerPool() - pool.newValidationWorker(parachaintypes.ValidationCodeHash{1, 2, 3, 4}) + pool.newValidationWorker(parachaintypes.ValidationCode{1, 2, 3, 4}) return pool }, expectedWorkers: []parachaintypes.ValidationCodeHash{ diff --git a/dot/parachain/pvf/worker_test.go b/dot/parachain/pvf/worker_test.go index 893671dc71..05f3abf089 100644 --- a/dot/parachain/pvf/worker_test.go +++ b/dot/parachain/pvf/worker_test.go @@ -1,6 +1,7 @@ package pvf import ( + "github.com/stretchr/testify/require" "sync" "testing" "time" @@ -9,12 +10,14 @@ import ( ) func TestWorker(t *testing.T) { - workerID1 := parachaintypes.ValidationCodeHash{1, 2, 3, 4} + workerID1 := parachaintypes.ValidationCode{1, 2, 3, 4} - w := newWorker(workerID1) + workerQueue := make(chan *workerTask, maxRequestsAllowed) + w, err := newWorker(workerID1, workerQueue) + require.NoError(t, err) wg := sync.WaitGroup{} - queue := make(chan *ValidationTask, 2) + queue := make(chan *workerTask, 2) wg.Add(1) go w.run(queue, &wg) @@ -22,11 +25,11 @@ func TestWorker(t *testing.T) { resultCh := make(chan *ValidationTaskResult) defer close(resultCh) - queue <- &ValidationTask{ + queue <- &workerTask{ ResultCh: resultCh, } - queue <- &ValidationTask{ + queue <- &workerTask{ ResultCh: resultCh, }