diff --git a/dot/parachain/candidate-validation/candidate_validation.go b/dot/parachain/candidate-validation/candidate_validation.go index ecc4174c94..ea596798c0 100644 --- a/dot/parachain/candidate-validation/candidate_validation.go +++ b/dot/parachain/candidate-validation/candidate_validation.go @@ -33,7 +33,7 @@ type CandidateValidation struct { SubsystemToOverseer chan<- any OverseerToSubsystem <-chan any BlockState BlockState - pvfHost *pvf.ValidationHost + pvfHost *pvf.Host } type BlockState interface { @@ -51,9 +51,8 @@ func NewCandidateValidation(overseerChan chan<- any, blockState BlockState) *Can } // Run starts the CandidateValidation subsystem -func (cv *CandidateValidation) Run(context.Context, chan any, chan any) { +func (cv *CandidateValidation) Run(context.Context, <-chan any) { cv.wg.Add(1) - go cv.pvfHost.Start() go cv.processMessages(&cv.wg) } @@ -76,7 +75,6 @@ func (*CandidateValidation) ProcessBlockFinalizedSignal(parachaintypes.BlockFina // Stop stops the CandidateValidation subsystem func (cv *CandidateValidation) Stop() { - cv.pvfHost.Stop() close(cv.stopChan) cv.wg.Wait() } @@ -93,7 +91,6 @@ func (cv *CandidateValidation) processMessages(wg *sync.WaitGroup) { cv.validateFromChainState(msg) case ValidateFromExhaustive: - taskResult := make(chan *pvf.ValidationTaskResult) validationTask := &pvf.ValidationTask{ PersistedValidationData: msg.PersistedValidationData, ValidationCode: &msg.ValidationCode, @@ -101,22 +98,24 @@ func (cv *CandidateValidation) processMessages(wg *sync.WaitGroup) { PoV: msg.PoV, ExecutorParams: msg.ExecutorParams, PvfExecTimeoutKind: msg.PvfExecTimeoutKind, - ResultCh: taskResult, } - go cv.pvfHost.Validate(validationTask) - result := <-taskResult - if result.InternalError != nil { - logger.Errorf("failed to validate from exhaustive: %w", result.InternalError) - msg.Ch <- parachaintypes.OverseerFuncRes[pvf.ValidationResult]{ - Err: result.InternalError, - } - } else { - msg.Ch <- parachaintypes.OverseerFuncRes[pvf.ValidationResult]{ - Data: *result.Result, - } - } + taskResultChan := cv.pvfHost.Validate(validationTask) + go func() { + + result := <-taskResultChan + if result.InternalError != nil { + logger.Errorf("failed to validate from exhaustive: %w", result.InternalError) + msg.Ch <- parachaintypes.OverseerFuncRes[pvf.ValidationResult]{ + Err: result.InternalError, + } + } else { + msg.Ch <- parachaintypes.OverseerFuncRes[pvf.ValidationResult]{ + Data: *result.Result, + } + } + }() case PreCheck: // TODO: implement functionality to handle PreCheck, see issue #3921 @@ -195,7 +194,6 @@ func (cv *CandidateValidation) validateFromChainState(msg ValidateFromChainState return } - taskResult := make(chan *pvf.ValidationTaskResult) validationTask := &pvf.ValidationTask{ PersistedValidationData: *persistedValidationData, ValidationCode: validationCode, @@ -203,11 +201,10 @@ func (cv *CandidateValidation) validateFromChainState(msg ValidateFromChainState PoV: msg.Pov, ExecutorParams: msg.ExecutorParams, PvfExecTimeoutKind: parachaintypes.PvfExecTimeoutKind{}, - ResultCh: taskResult, } - go cv.pvfHost.Validate(validationTask) - result := <-taskResult + taskResultChan := cv.pvfHost.Validate(validationTask) + result := <-taskResultChan if result.InternalError != nil { logger.Errorf("failed to validate from chain state: %w", result.InternalError) msg.Ch <- parachaintypes.OverseerFuncRes[pvf.ValidationResult]{ diff --git a/dot/parachain/candidate-validation/candidate_validation_test.go b/dot/parachain/candidate-validation/candidate_validation_test.go index 0fd2a1ba05..ffdb52dd4c 100644 --- a/dot/parachain/candidate-validation/candidate_validation_test.go +++ b/dot/parachain/candidate-validation/candidate_validation_test.go @@ -113,8 +113,6 @@ func TestCandidateValidation_validateFromExhaustive(t *testing.T) { executionError := pvf.ExecutionError pvfHost := pvf.NewValidationHost() - pvfHost.Start() - defer pvfHost.Stop() ctrl := gomock.NewController(t) t.Cleanup(ctrl.Finish) @@ -293,7 +291,7 @@ func TestCandidateValidation_validateFromExhaustive(t *testing.T) { taskResult := make(chan *pvf.ValidationTaskResult) defer close(taskResult) - tt.validationTask.ResultCh = taskResult + //tt.validationTask.ResultCh = taskResult go pvfHost.Validate(tt.validationTask) diff --git a/dot/parachain/collator-protocol/message_test.go b/dot/parachain/collator-protocol/message_test.go index e00c8a8122..9dd209a0bd 100644 --- a/dot/parachain/collator-protocol/message_test.go +++ b/dot/parachain/collator-protocol/message_test.go @@ -19,7 +19,6 @@ import ( "github.com/ChainSafe/gossamer/dot/network" collatorprotocolmessages "github.com/ChainSafe/gossamer/dot/parachain/collator-protocol/messages" networkbridgemessages "github.com/ChainSafe/gossamer/dot/parachain/network-bridge/messages" - "github.com/ChainSafe/gossamer/dot/parachain/overseer" parachaintypes "github.com/ChainSafe/gossamer/dot/parachain/types" "github.com/ChainSafe/gossamer/dot/peerset" ) @@ -375,17 +374,22 @@ func TestHandleCollationMessageDeclare(t *testing.T) { c := c t.Run(c.description, func(t *testing.T) { t.Parallel() + + subsystemToOverseer := make(chan any) cpvs := CollatorProtocolValidatorSide{ - peerData: c.peerData, - currentAssignments: c.currentAssignments, + SubSystemToOverseer: subsystemToOverseer, + peerData: c.peerData, + currentAssignments: c.currentAssignments, } - mockOverseer := overseer.NewMockableOverseer(t) - mockOverseer.RegisterSubsystem(&cpvs) - cpvs.SubSystemToOverseer = mockOverseer.GetSubsystemToOverseerChannel() - - mockOverseer.Start() - defer mockOverseer.Stop() + // ensure that the expected messages are sent to the overseer + if len(c.expectedMessages) > 0 { + go func() { + for _, expectedMessage := range c.expectedMessages { + require.Equal(t, expectedMessage, <-subsystemToOverseer) + } + }() + } msg := collatorprotocolmessages.NewCollationProtocol() vdtChild := collatorprotocolmessages.NewCollatorProtocolMessage() @@ -444,7 +448,6 @@ func TestHandleCollationMessageAdvertiseCollation(t *testing.T) { }, errString: ErrRelayParentUnknown.Error(), }, - { description: "fail with unknown peer if peer is not tracked in our list of active collators", advertiseCollation: collatorprotocolmessages.AdvertiseCollation(testRelayParent), @@ -574,19 +577,21 @@ func TestHandleCollationMessageAdvertiseCollation(t *testing.T) { t.Run(c.description, func(t *testing.T) { t.Parallel() + subsystemToOverseer := make(chan any) cpvs := CollatorProtocolValidatorSide{ - net: c.net, - perRelayParent: c.perRelayParent, - peerData: c.peerData, - activeLeaves: c.activeLeaves, + SubSystemToOverseer: subsystemToOverseer, + net: c.net, + perRelayParent: c.perRelayParent, + peerData: c.peerData, + activeLeaves: c.activeLeaves, } - mockOverseer := overseer.NewMockableOverseer(t) - mockOverseer.RegisterSubsystem(&cpvs) - cpvs.SubSystemToOverseer = mockOverseer.GetSubsystemToOverseerChannel() - - mockOverseer.Start() - defer mockOverseer.Stop() + // ensure that the expected messages are sent to the overseer + if c.expectedMessage != nil { + go func() { + require.Equal(t, c.expectedMessage, <-subsystemToOverseer) + }() + } msg := collatorprotocolmessages.NewCollationProtocol() vdtChild := collatorprotocolmessages.NewCollatorProtocolMessage() @@ -604,7 +609,6 @@ func TestHandleCollationMessageAdvertiseCollation(t *testing.T) { } else { require.ErrorContains(t, err, c.errString) } - }) } } diff --git a/dot/parachain/overseer/mockable_overseer.go b/dot/parachain/overseer/mockable_overseer.go index a9a55c2ef5..85f74780af 100644 --- a/dot/parachain/overseer/mockable_overseer.go +++ b/dot/parachain/overseer/mockable_overseer.go @@ -96,8 +96,15 @@ func (m *MockableOverseer) processMessages() { } actionIndex = actionIndex + 1 + } else { + m.t.Errorf("unexpected message: %T", msg) + return } case <-m.ctx.Done(): + if actionIndex < len(m.actionsForExpectedMessages) { + m.t.Errorf("expected %d overseer actions, but got only %d", len(m.actionsForExpectedMessages), actionIndex) + } + if err := m.ctx.Err(); err != nil { m.t.Logf("ctx error: %v\n", err) } diff --git a/dot/parachain/pvf/host.go b/dot/parachain/pvf/host.go index 2acf7d7aa9..af54c033f2 100644 --- a/dot/parachain/pvf/host.go +++ b/dot/parachain/pvf/host.go @@ -2,7 +2,6 @@ package pvf import ( "fmt" - "sync" parachainruntime "github.com/ChainSafe/gossamer/dot/parachain/runtime" parachaintypes "github.com/ChainSafe/gossamer/dot/parachain/types" @@ -12,96 +11,92 @@ import ( var logger = log.NewFromGlobal(log.AddContext("pkg", "pvf"), log.SetLevel(log.Debug)) -type ValidationHost struct { - wg sync.WaitGroup +type Host struct { stopCh chan struct{} - workerPool *validationWorkerPool + workerPool *workerPool } -func (v *ValidationHost) Start() { - v.wg.Add(1) - logger.Debug("Starting validation host") - go func() { - defer v.wg.Done() - }() -} - -func (v *ValidationHost) Stop() { - close(v.stopCh) - v.wg.Wait() -} - -func NewValidationHost() *ValidationHost { - return &ValidationHost{ +//func (v *Host) Start() { +// v.wg.Add(1) +// logger.Debug("Starting validation host") +// go func() { +// defer v.wg.Done() +// }() +//} + +//func (v *Host) Stop() { +// close(v.stopCh) +// v.wg.Wait() +//} + +func NewValidationHost() *Host { + return &Host{ stopCh: make(chan struct{}), workerPool: newValidationWorkerPool(), } } -func (v *ValidationHost) Validate(msg *ValidationTask) { - logger.Debugf("Validating worker %x", msg.WorkerID) - - validationCodeHash := msg.ValidationCode.Hash() - // basic checks - validationErr, internalErr := performBasicChecks(&msg.CandidateReceipt.Descriptor, - msg.PersistedValidationData.MaxPovSize, - msg.PoV, - validationCodeHash) - - if internalErr != nil { - logger.Errorf("performing basic checks: %w", internalErr) - intErr := &ValidationTaskResult{ - who: validationCodeHash, - InternalError: internalErr, +func (v *Host) Validate(msg *ValidationTask) <-chan *ValidationTaskResult { + resultCh := make(chan *ValidationTaskResult) + go func() { + defer close(resultCh) + logger.Debugf("Start Validating worker %x", msg.WorkerID) + validationCodeHash := msg.ValidationCode.Hash() + // performBasicChecks + validationErr, internalErr := performBasicChecks(&msg.CandidateReceipt.Descriptor, + msg.PersistedValidationData.MaxPovSize, + msg.PoV, + validationCodeHash) + + if internalErr != nil { + resultCh <- &ValidationTaskResult{ + who: validationCodeHash, + InternalError: internalErr, + } } - msg.ResultCh <- intErr - return - } - - if validationErr != nil { - valErr := &ValidationTaskResult{ - who: validationCodeHash, - Result: &ValidationResult{ - InvalidResult: validationErr, - }, + if validationErr != nil { + resultCh <- &ValidationTaskResult{ + who: validationCodeHash, + Result: &ValidationResult{InvalidResult: validationErr}, + } + } + // check if worker is in pool + workerID, err := v.poolContainsWorker(msg) + if err != nil { + resultCh <- &ValidationTaskResult{ + who: validationCodeHash, + InternalError: err, + } } - msg.ResultCh <- valErr - return - } - workerID, err := v.poolContainsWorker(msg) - if err != nil { - logger.Errorf("pool contains worker: %w", err) - intErr := &ValidationTaskResult{ - who: validationCodeHash, - InternalError: err, + // submit request + validationParams := parachainruntime.ValidationParameters{ + ParentHeadData: msg.PersistedValidationData.ParentHead, + BlockData: msg.PoV.BlockData, + RelayParentNumber: msg.PersistedValidationData.RelayParentNumber, + RelayParentStorageRoot: msg.PersistedValidationData.RelayParentStorageRoot, } - msg.ResultCh <- intErr - return - } - validationParams := parachainruntime.ValidationParameters{ - ParentHeadData: msg.PersistedValidationData.ParentHead, - BlockData: msg.PoV.BlockData, - RelayParentNumber: msg.PersistedValidationData.RelayParentNumber, - RelayParentStorageRoot: msg.PersistedValidationData.RelayParentStorageRoot, - } - workTask := &workerTask{ - work: validationParams, - maxPoVSize: msg.PersistedValidationData.MaxPovSize, - candidateReceipt: msg.CandidateReceipt, - ResultCh: msg.ResultCh, - } - v.workerPool.submitRequest(*workerID, workTask) + workTask := &workerTask{ + work: validationParams, + maxPoVSize: msg.PersistedValidationData.MaxPovSize, + candidateReceipt: msg.CandidateReceipt, + } + logger.Debugf("Working Validating worker %x", workerID) + resultWorkCh := v.workerPool.submitRequest(*workerID, workTask) + + result := <-resultWorkCh + resultCh <- result + }() + return resultCh } -func (v *ValidationHost) poolContainsWorker(msg *ValidationTask) (*parachaintypes.ValidationCodeHash, error) { +func (v *Host) poolContainsWorker(msg *ValidationTask) (*parachaintypes.ValidationCodeHash, error) { if msg.WorkerID != nil { return msg.WorkerID, nil } validationCodeHash := msg.ValidationCode.Hash() if v.workerPool.containsWorker(validationCodeHash) { - return &validationCodeHash, nil } else { return v.workerPool.newValidationWorker(*msg.ValidationCode) diff --git a/dot/parachain/pvf/host_test.go b/dot/parachain/pvf/host_test.go index 9070a9a55b..7366e969ce 100644 --- a/dot/parachain/pvf/host_test.go +++ b/dot/parachain/pvf/host_test.go @@ -6,7 +6,7 @@ import ( func Test_validationHost_start(t *testing.T) { type fields struct { - workerPool *validationWorkerPool + workerPool *workerPool } tests := map[string]struct { name string @@ -19,10 +19,10 @@ func Test_validationHost_start(t *testing.T) { for tname, tt := range tests { tt := tt t.Run(tname, func(t *testing.T) { - v := &ValidationHost{ + v := &Host{ workerPool: tt.fields.workerPool, } - v.Start() + v.Validate(&ValidationTask{}) }) } } diff --git a/dot/parachain/pvf/worker.go b/dot/parachain/pvf/worker.go index 5db740cbdf..50993b0f01 100644 --- a/dot/parachain/pvf/worker.go +++ b/dot/parachain/pvf/worker.go @@ -1,23 +1,24 @@ package pvf import ( - "sync" - parachainruntime "github.com/ChainSafe/gossamer/dot/parachain/runtime" parachaintypes "github.com/ChainSafe/gossamer/dot/parachain/types" ) +// TODO(ed): figure out a better name for this that describes what it does type worker struct { workerID parachaintypes.ValidationCodeHash instance *parachainruntime.Instance - queue chan *workerTask + // TODO(ed): determine if wasProcessed is stored here or in host + isProcessed map[parachaintypes.CandidateHash]struct{} + // TODO make this a buffered channel, and determine the buffer size + workerTasksChan chan *workerTask } type workerTask struct { work parachainruntime.ValidationParameters maxPoVSize uint32 candidateReceipt *parachaintypes.CandidateReceipt - ResultCh chan<- *ValidationTaskResult } func newWorker(validationCode parachaintypes.ValidationCode, queue chan *workerTask) (*worker, error) { @@ -27,104 +28,110 @@ func newWorker(validationCode parachaintypes.ValidationCode, queue chan *workerT return nil, err } return &worker{ - workerID: validationCode.Hash(), - instance: validationRuntime, - queue: queue, + workerID: validationCode.Hash(), + instance: validationRuntime, + workerTasksChan: queue, }, nil } -func (w *worker) run(queue chan *workerTask, wg *sync.WaitGroup) { - defer func() { - logger.Debugf("[STOPPED] worker %x", w.workerID) - wg.Done() - }() - - for task := range queue { - w.executeRequest(task) - } -} - -func (w *worker) executeRequest(task *workerTask) { +func (w *worker) executeRequest(task *workerTask) chan *ValidationTaskResult { logger.Debugf("[EXECUTING] worker %x task %v", w.workerID, task.work) + resultCh := make(chan *ValidationTaskResult) + + go func() { + defer close(resultCh) + candidateHash, err := parachaintypes.GetCandidateHash(task.candidateReceipt) + if err != nil { + // TODO: handle error + logger.Errorf("getting candidate hash: %w", err) + } - validationResult, err := w.instance.ValidateBlock(task.work) - - if err != nil { - logger.Errorf("executing validate_block: %w", err) - reasonForInvalidity := ExecutionError - errorResult := &ValidationResult{ - InvalidResult: &reasonForInvalidity, + // do isProcessed check here + if _, ok := w.isProcessed[candidateHash]; ok { + // TODO: determine what the isPreccessed check should return, and if re-trying is allowed + // get a better understanding of what the isProcessed check should be checking for + logger.Debugf("candidate %x already processed", candidateHash) } - task.ResultCh <- &ValidationTaskResult{ - who: w.workerID, - Result: errorResult, + validationResult, err := w.instance.ValidateBlock(task.work) + + if err != nil { + logger.Errorf("executing validate_block: %w", err) + reasonForInvalidity := ExecutionError + errorResult := &ValidationResult{ + InvalidResult: &reasonForInvalidity, + } + resultCh <- &ValidationTaskResult{ + who: w.workerID, + Result: errorResult, + } + return } - return - } - headDataHash, err := validationResult.HeadData.Hash() - if err != nil { - logger.Errorf("hashing head data: %w", err) - reasonForInvalidity := ExecutionError - errorResult := &ValidationResult{ - InvalidResult: &reasonForInvalidity, - } - task.ResultCh <- &ValidationTaskResult{ - who: w.workerID, - Result: errorResult, + headDataHash, err := validationResult.HeadData.Hash() + if err != nil { + logger.Errorf("hashing head data: %w", err) + reasonForInvalidity := ExecutionError + errorResult := &ValidationResult{ + InvalidResult: &reasonForInvalidity, + } + resultCh <- &ValidationTaskResult{ + who: w.workerID, + Result: errorResult, + } + return } - return - } - if headDataHash != task.candidateReceipt.Descriptor.ParaHead { - reasonForInvalidity := ParaHeadHashMismatch - errorResult := &ValidationResult{ - InvalidResult: &reasonForInvalidity, + if headDataHash != task.candidateReceipt.Descriptor.ParaHead { + reasonForInvalidity := ParaHeadHashMismatch + errorResult := &ValidationResult{ + InvalidResult: &reasonForInvalidity, + } + resultCh <- &ValidationTaskResult{ + who: w.workerID, + Result: errorResult, + } + return } - task.ResultCh <- &ValidationTaskResult{ - who: w.workerID, - Result: errorResult, + candidateCommitments := parachaintypes.CandidateCommitments{ + UpwardMessages: validationResult.UpwardMessages, + HorizontalMessages: validationResult.HorizontalMessages, + NewValidationCode: validationResult.NewValidationCode, + HeadData: validationResult.HeadData, + ProcessedDownwardMessages: validationResult.ProcessedDownwardMessages, + HrmpWatermark: validationResult.HrmpWatermark, } - return - } - candidateCommitments := parachaintypes.CandidateCommitments{ - UpwardMessages: validationResult.UpwardMessages, - HorizontalMessages: validationResult.HorizontalMessages, - NewValidationCode: validationResult.NewValidationCode, - HeadData: validationResult.HeadData, - ProcessedDownwardMessages: validationResult.ProcessedDownwardMessages, - HrmpWatermark: validationResult.HrmpWatermark, - } - // if validation produced a new set of commitments, we treat the candidate as invalid - if task.candidateReceipt.CommitmentsHash != candidateCommitments.Hash() { - reasonForInvalidity := CommitmentsHashMismatch - errorResult := &ValidationResult{ - InvalidResult: &reasonForInvalidity, + // if validation produced a new set of commitments, we treat the candidate as invalid + if task.candidateReceipt.CommitmentsHash != candidateCommitments.Hash() { + reasonForInvalidity := CommitmentsHashMismatch + errorResult := &ValidationResult{ + InvalidResult: &reasonForInvalidity, + } + resultCh <- &ValidationTaskResult{ + who: w.workerID, + Result: errorResult, + } + return } - task.ResultCh <- &ValidationTaskResult{ - who: w.workerID, - Result: errorResult, + pvd := parachaintypes.PersistedValidationData{ + ParentHead: task.work.ParentHeadData, + RelayParentNumber: task.work.RelayParentNumber, + RelayParentStorageRoot: task.work.RelayParentStorageRoot, + MaxPovSize: task.maxPoVSize, + } + validResult := &ValidationResult{ + ValidResult: &ValidValidationResult{ + CandidateCommitments: candidateCommitments, + PersistedValidationData: pvd, + }, } - return - } - pvd := parachaintypes.PersistedValidationData{ - ParentHead: task.work.ParentHeadData, - RelayParentNumber: task.work.RelayParentNumber, - RelayParentStorageRoot: task.work.RelayParentStorageRoot, - MaxPovSize: task.maxPoVSize, - } - validResult := &ValidationResult{ - ValidResult: &ValidValidationResult{ - CandidateCommitments: candidateCommitments, - PersistedValidationData: pvd, - }, - } - logger.Debugf("[RESULT] worker %x, result: %v, error: %s", w.workerID, validResult, err) + logger.Debugf("[RESULT] worker %x, result: %v, error: %s", w.workerID, validResult, err) - task.ResultCh <- &ValidationTaskResult{ - who: w.workerID, - Result: validResult, - } + resultCh <- &ValidationTaskResult{ + who: w.workerID, + Result: validResult, + } + }() + return resultCh } diff --git a/dot/parachain/pvf/worker_pool.go b/dot/parachain/pvf/worker_pool.go index 85d0f3b528..0e089dcaad 100644 --- a/dot/parachain/pvf/worker_pool.go +++ b/dot/parachain/pvf/worker_pool.go @@ -3,7 +3,6 @@ package pvf import ( "fmt" "sync" - "time" parachaintypes "github.com/ChainSafe/gossamer/dot/parachain/types" ) @@ -12,10 +11,10 @@ const ( maxRequestsAllowed uint = 60 ) -type validationWorkerPool struct { +type workerPool struct { mtx sync.RWMutex - wg sync.WaitGroup + // todo, make sure other functions work with paraID workers map[parachaintypes.ValidationCodeHash]*worker } @@ -27,7 +26,6 @@ type ValidationTask struct { ExecutorParams parachaintypes.ExecutorParams PvfExecTimeoutKind parachaintypes.PvfExecTimeoutKind ValidationCode *parachaintypes.ValidationCode - ResultCh chan<- *ValidationTaskResult } type ValidationTaskResult struct { @@ -122,66 +120,30 @@ func (ci ReasonForInvalidity) Error() string { } } -//type validationWorker struct { -// worker *worker -// queue chan *workerTask -//} - -func newValidationWorkerPool() *validationWorkerPool { - return &validationWorkerPool{ +func newValidationWorkerPool() *workerPool { + return &workerPool{ workers: make(map[parachaintypes.ValidationCodeHash]*worker), } } -// stop will shutdown all the available workers goroutines -func (v *validationWorkerPool) stop() error { - v.mtx.RLock() - defer v.mtx.RUnlock() - - for _, sw := range v.workers { - close(sw.queue) - } - - allWorkersDoneCh := make(chan struct{}) - go func() { - defer close(allWorkersDoneCh) - v.wg.Wait() - }() - - timeoutTimer := time.NewTimer(30 * time.Second) - select { - case <-timeoutTimer.C: - return fmt.Errorf("timeout reached while finishing workers") - case <-allWorkersDoneCh: - if !timeoutTimer.Stop() { - <-timeoutTimer.C - } - - return nil - } -} - -func (v *validationWorkerPool) newValidationWorker(validationCode parachaintypes.ValidationCode) (*parachaintypes. +func (v *workerPool) newValidationWorker(validationCode parachaintypes.ValidationCode) (*parachaintypes. ValidationCodeHash, error) { workerQueue := make(chan *workerTask, maxRequestsAllowed) worker, err := newWorker(validationCode, workerQueue) if err != nil { - logger.Errorf("failed to create a new worker: %w", err) - return nil, err + return nil, fmt.Errorf("failed to create a new worker: %w", err) } - v.wg.Add(1) - go worker.run(workerQueue, &v.wg) v.workers[worker.workerID] = worker return &worker.workerID, nil } -// submitRequest given a request, the worker pool will get the peer given the peer.ID -// parameter or if nil the very first available worker or -// to perform the request, the response will be dispatch in the resultCh. -func (v *validationWorkerPool) submitRequest(workerID parachaintypes.ValidationCodeHash, request *workerTask) { +// submitRequest given a request, the worker pool will get the worker for a given workerID +// a channel in returned that the response will be dispatch on +func (v *workerPool) submitRequest(workerID parachaintypes.ValidationCodeHash, + request *workerTask) chan *ValidationTaskResult { v.mtx.RLock() defer v.mtx.RUnlock() logger.Debugf("pool submit request workerID %x", workerID) @@ -192,18 +154,12 @@ func (v *validationWorkerPool) submitRequest(workerID parachaintypes.ValidationC panic("sync worker should not be nil") } logger.Debugf("sending request", workerID) - syncWorker.queue <- request - return - } - - logger.Errorf("workerID %x not found in the pool", workerID) - request.ResultCh <- &ValidationTaskResult{ - who: workerID, - InternalError: fmt.Errorf("workerID %x not found in the pool", workerID), + return syncWorker.executeRequest(request) } + return nil } -func (v *validationWorkerPool) containsWorker(workerID parachaintypes.ValidationCodeHash) bool { +func (v *workerPool) containsWorker(workerID parachaintypes.ValidationCodeHash) bool { v.mtx.RLock() defer v.mtx.RUnlock() diff --git a/dot/parachain/pvf/worker_pool_test.go b/dot/parachain/pvf/worker_pool_test.go index a692ae76ff..e29fdb9d5f 100644 --- a/dot/parachain/pvf/worker_pool_test.go +++ b/dot/parachain/pvf/worker_pool_test.go @@ -24,11 +24,11 @@ func TestValidationWorkerPool_newValidationWorker(t *testing.T) { testValidationCode := createTestValidationCode(t) cases := map[string]struct { - setupWorkerPool func(t *testing.T) *validationWorkerPool + setupWorkerPool func(t *testing.T) *workerPool expectedWorkers []parachaintypes.ValidationCodeHash }{ "add_one_invalid_worker": { - setupWorkerPool: func(t *testing.T) *validationWorkerPool { + setupWorkerPool: func(t *testing.T) *workerPool { pool := newValidationWorkerPool() _, err := pool.newValidationWorker(parachaintypes.ValidationCode{1, 2, 3, 4}) require.Error(t, err) @@ -37,7 +37,7 @@ func TestValidationWorkerPool_newValidationWorker(t *testing.T) { expectedWorkers: []parachaintypes.ValidationCodeHash{}, }, "add_one_valid_worker": { - setupWorkerPool: func(t *testing.T) *validationWorkerPool { + setupWorkerPool: func(t *testing.T) *workerPool { pool := newValidationWorkerPool() _, err := pool.newValidationWorker(testValidationCode) require.NoError(t, err) @@ -55,7 +55,6 @@ func TestValidationWorkerPool_newValidationWorker(t *testing.T) { t.Parallel() workerPool := tt.setupWorkerPool(t) - defer workerPool.stop() require.ElementsMatch(t, maps.Keys(workerPool.workers), diff --git a/go.mod b/go.mod index d9d9ebbf39..f585b2c749 100644 --- a/go.mod +++ b/go.mod @@ -16,6 +16,7 @@ require ( github.com/fatih/color v1.17.0 github.com/gammazero/deque v0.2.1 github.com/go-playground/validator/v10 v10.21.0 + github.com/golang/mock v1.6.0 github.com/google/go-cmp v0.6.0 github.com/google/uuid v1.6.0 github.com/gorilla/mux v1.8.1 diff --git a/go.sum b/go.sum index 03ec8f0ac3..990d74d772 100644 --- a/go.sum +++ b/go.sum @@ -206,6 +206,8 @@ github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da/go.mod h1:cIg4er github.com/golang/lint v0.0.0-20180702182130-06c8688daad7/go.mod h1:tluoj9z5200jBnyusfRPU2LqT6J+DAorxEvtC7LHB+E= github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= github.com/golang/mock v1.2.0/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= +github.com/golang/mock v1.6.0 h1:ErTB+efbowRARo13NNdxyJji2egdxLGQhRaY+DUumQc= +github.com/golang/mock v1.6.0/go.mod h1:p6yTPP+5HYm5mzsMV8JkE6ZKdX+/wYM6Hr+LicevLPs= github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= @@ -797,6 +799,7 @@ golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtn golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= +golang.org/x/tools v0.1.1/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= golang.org/x/tools v0.1.5/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d h1:vU5i/LfpvrRCpgM/VPfJLg5KjxD3E+hfT1SH+d9zLwg= golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d/go.mod h1:aiJjzUbINMkxbQROHiO6hDPo2LHcIPhhQsa9DLh0yGk=