From 3f641763552a24ac973d4606a47ff4bba463dbb5 Mon Sep 17 00:00:00 2001 From: Zijian Date: Fri, 1 Mar 2024 09:41:54 -0800 Subject: [PATCH] Do not get workflow execution from database when shard is closed (#5697) --- service/history/execution/context.go | 2 +- service/history/ndc/transaction_manager.go | 2 +- service/history/shard/context.go | 11 +++ service/history/shard/context_test.go | 79 ++++++++++++++++++++++ 4 files changed, 92 insertions(+), 2 deletions(-) diff --git a/service/history/execution/context.go b/service/history/execution/context.go index 4509448d336..8038e86971e 100644 --- a/service/history/execution/context.go +++ b/service/history/execution/context.go @@ -1193,7 +1193,7 @@ func (c *contextImpl) getWorkflowExecutionWithRetry( var resp *persistence.GetWorkflowExecutionResponse op := func() error { var err error - resp, err = c.executionManager.GetWorkflowExecution(ctx, request) + resp, err = c.shard.GetWorkflowExecution(ctx, request) return err } diff --git a/service/history/ndc/transaction_manager.go b/service/history/ndc/transaction_manager.go index 9c79433b5ce..e256baf4e21 100644 --- a/service/history/ndc/transaction_manager.go +++ b/service/history/ndc/transaction_manager.go @@ -389,7 +389,7 @@ func (r *transactionManagerImpl) checkWorkflowExists( if errorDomainName != nil { return false, errorDomainName } - _, err := r.shard.GetExecutionManager().GetWorkflowExecution( + _, err := r.shard.GetWorkflowExecution( ctx, &persistence.GetWorkflowExecutionRequest{ DomainID: domainID, diff --git a/service/history/shard/context.go b/service/history/shard/context.go index 118e96260aa..037a71de136 100644 --- a/service/history/shard/context.go +++ b/service/history/shard/context.go @@ -107,6 +107,7 @@ type ( GetDomainNotificationVersion() int64 UpdateDomainNotificationVersion(domainNotificationVersion int64) error + GetWorkflowExecution(ctx context.Context, request *persistence.GetWorkflowExecutionRequest) (*persistence.GetWorkflowExecutionResponse, error) CreateWorkflowExecution(ctx context.Context, request *persistence.CreateWorkflowExecutionRequest) (*persistence.CreateWorkflowExecutionResponse, error) UpdateWorkflowExecution(ctx context.Context, request *persistence.UpdateWorkflowExecutionRequest) (*persistence.UpdateWorkflowExecutionResponse, error) ConflictResolveWorkflowExecution(ctx context.Context, request *persistence.ConflictResolveWorkflowExecutionRequest) (*persistence.ConflictResolveWorkflowExecutionResponse, error) @@ -578,6 +579,16 @@ func (s *contextImpl) UpdateTimerMaxReadLevel(cluster string) time.Time { return s.timerMaxReadLevelMap[cluster] } +func (s *contextImpl) GetWorkflowExecution( + ctx context.Context, + request *persistence.GetWorkflowExecutionRequest, +) (*persistence.GetWorkflowExecutionResponse, error) { + if s.isClosed() { + return nil, ErrShardClosed + } + return s.executionManager.GetWorkflowExecution(ctx, request) +} + func (s *contextImpl) CreateWorkflowExecution( ctx context.Context, request *persistence.CreateWorkflowExecutionRequest, diff --git a/service/history/shard/context_test.go b/service/history/shard/context_test.go index 86885b3a5e6..42d0975e873 100644 --- a/service/history/shard/context_test.go +++ b/service/history/shard/context_test.go @@ -30,6 +30,7 @@ import ( "time" "github.com/golang/mock/gomock" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" @@ -229,3 +230,81 @@ func (s *contextTestSuite) TestGetAndUpdateProcessingQueueStates() { s.Equal(updatedTransferQueueStates[0].GetAckLevel(), s.context.GetTransferClusterAckLevel(clusterName)) s.Equal(time.Unix(0, updatedTimerQueueStates[0].GetAckLevel()), s.context.GetTimerClusterAckLevel(clusterName)) } + +func TestGetWorkflowExecution(t *testing.T) { + testCases := []struct { + name string + isClosed bool + request *persistence.GetWorkflowExecutionRequest + mockSetup func(*mocks.ExecutionManager) + expectedResult *persistence.GetWorkflowExecutionResponse + expectedError error + }{ + { + name: "Success", + request: &persistence.GetWorkflowExecutionRequest{ + DomainID: "testDomain", + Execution: types.WorkflowExecution{WorkflowID: "testWorkflowID", RunID: "testRunID"}, + }, + mockSetup: func(mgr *mocks.ExecutionManager) { + mgr.On("GetWorkflowExecution", mock.Anything, mock.Anything).Return(&persistence.GetWorkflowExecutionResponse{ + State: &persistence.WorkflowMutableState{ + ExecutionInfo: &persistence.WorkflowExecutionInfo{ + DomainID: "testDomain", + WorkflowID: "testWorkflowID", + RunID: "testRunID", + }, + }, + }, nil) + }, + expectedResult: &persistence.GetWorkflowExecutionResponse{ + State: &persistence.WorkflowMutableState{ + ExecutionInfo: &persistence.WorkflowExecutionInfo{ + DomainID: "testDomain", + WorkflowID: "testWorkflowID", + RunID: "testRunID", + }, + }, + }, + expectedError: nil, + }, + { + name: "Error", + request: &persistence.GetWorkflowExecutionRequest{ + DomainID: "testDomain", + Execution: types.WorkflowExecution{WorkflowID: "testWorkflowID", RunID: "testRunID"}, + }, + mockSetup: func(mgr *mocks.ExecutionManager) { + mgr.On("GetWorkflowExecution", mock.Anything, mock.Anything).Return(nil, errors.New("some random error")) + }, + expectedResult: nil, + expectedError: errors.New("some random error"), + }, + { + name: "Shard closed", + isClosed: true, + request: &persistence.GetWorkflowExecutionRequest{ + DomainID: "testDomain", + Execution: types.WorkflowExecution{WorkflowID: "testWorkflowID", RunID: "testRunID"}, + }, + mockSetup: func(mgr *mocks.ExecutionManager) {}, + expectedResult: nil, + expectedError: ErrShardClosed, + }, + } + + for _, tc := range testCases { + mockExecutionMgr := &mocks.ExecutionManager{} + shardContext := &contextImpl{ + executionManager: mockExecutionMgr, + } + if tc.isClosed { + shardContext.closed = 1 + } + tc.mockSetup(mockExecutionMgr) + + result, err := shardContext.GetWorkflowExecution(context.Background(), tc.request) + assert.Equal(t, tc.expectedResult, result) + assert.Equal(t, tc.expectedError, err) + } +}