Skip to content

Commit

Permalink
Add test environment support for Nexus Operations
Browse files Browse the repository at this point in the history
  • Loading branch information
bergundy committed May 16, 2024
1 parent 16d6fd4 commit 6cadae7
Show file tree
Hide file tree
Showing 4 changed files with 863 additions and 10 deletions.
251 changes: 242 additions & 9 deletions internal/internal_workflow_testsuite.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,15 @@ import (
"errors"
"fmt"
"reflect"
"strconv"
"strings"
"sync"
"time"

"github.com/facebookgo/clock"
"github.com/golang/mock/gomock"
"github.com/google/uuid"
"github.com/nexus-rpc/sdk-go/nexus"
"github.com/robfig/cron"
"github.com/stretchr/testify/mock"
"google.golang.org/grpc"
Expand All @@ -44,6 +47,8 @@ import (
commandpb "go.temporal.io/api/command/v1"
commonpb "go.temporal.io/api/common/v1"
enumspb "go.temporal.io/api/enums/v1"
failurepb "go.temporal.io/api/failure/v1"
nexuspb "go.temporal.io/api/nexus/v1"
"go.temporal.io/api/serviceerror"
taskqueuepb "go.temporal.io/api/taskqueue/v1"
"go.temporal.io/api/workflowservice/v1"
Expand Down Expand Up @@ -97,6 +102,18 @@ type (
err error
}

testNexusOperationHandle struct {
env *testWorkflowEnvironmentImpl
seq int64
params executeNexusOperationParams
operationID string
cancelRequested bool
started bool
done bool
onCompleted func(*commonpb.Payload, error)
onStarted func(opID string, e error)
}

testCallbackHandle struct {
callback func()
startWorkflowTask bool // start a new workflow task after callback() is handled.
Expand Down Expand Up @@ -149,11 +166,12 @@ type (
testTimeout time.Duration
header *commonpb.Header

counterID int64
activities map[string]*testActivityHandle
localActivities map[string]*localActivityTask
timers map[string]*testTimerHandle
runningWorkflows map[string]*testWorkflowHandle
counterID int64
activities map[string]*testActivityHandle
localActivities map[string]*localActivityTask
timers map[string]*testTimerHandle
runningWorkflows map[string]*testWorkflowHandle
runningNexusOperations map[int64]*testNexusOperationHandle

runningCount int

Expand Down Expand Up @@ -240,6 +258,7 @@ func newTestWorkflowEnvironmentImpl(s *WorkflowTestSuite, parentRegistry *regist
activities: make(map[string]*testActivityHandle),
localActivities: make(map[string]*localActivityTask),
runningWorkflows: make(map[string]*testWorkflowHandle),
runningNexusOperations: make(map[int64]*testNexusOperationHandle),
callbackChannel: make(chan testCallbackHandle, 1000),
testTimeout: 3 * time.Second,
expectedWorkflowMockCalls: make(map[string]struct{}),
Expand Down Expand Up @@ -2121,6 +2140,10 @@ func (env *testWorkflowEnvironmentImpl) RegisterActivityWithOptions(a interface{
env.registry.RegisterActivityWithOptions(a, options)
}

func (env *testWorkflowEnvironmentImpl) RegisterNexusService(s *nexus.Service) {
env.registry.RegisterNexusService(s)
}

func (env *testWorkflowEnvironmentImpl) RegisterCancelHandler(handler func()) {
env.workflowCancelHandler = handler
}
Expand Down Expand Up @@ -2279,12 +2302,133 @@ func (env *testWorkflowEnvironmentImpl) executeChildWorkflowWithDelay(delayStart
go childEnv.executeWorkflowInternal(delayStart, params.WorkflowType.Name, params.Input)
}

func (wc *testWorkflowEnvironmentImpl) ExecuteNexusOperation(params executeNexusOperationParams, callback func(*commonpb.Payload, error), startedHandler func(opID string, e error)) int64 {
panic("TODO")
func (env *testWorkflowEnvironmentImpl) newTestNexusTaskHandler() *nexusTaskHandler {
if len(env.registry.nexusServices) == 0 {
panic(fmt.Errorf("no nexus services registered"))
}

reg := nexus.NewServiceRegistry()
for _, service := range env.registry.nexusServices {
if err := reg.Register(service); err != nil {
panic(fmt.Errorf("failed to register nexus service '%v': %w", service, err))
}
}
handler, err := reg.NewHandler()
if err != nil {
panic(fmt.Errorf("failed to create nexus handler: %w", err))
}

return newNexusTaskHandler(
handler,
env.identity,
env.workflowInfo.Namespace,
env.workflowInfo.TaskQueueName,
&testSuiteClientForNexusOperations{env},
env.dataConverter,
env.logger,
env.metricsHandler,
)
}

func (env *testWorkflowEnvironmentImpl) ExecuteNexusOperation(params executeNexusOperationParams, callback func(*commonpb.Payload, error), startedHandler func(opID string, e error)) int64 {
seq := env.nextID()
taskHandler := env.newTestNexusTaskHandler()
handle := &testNexusOperationHandle{
env: env,
seq: seq,
params: params,
onCompleted: callback,
onStarted: startedHandler,
}
env.runningNexusOperations[seq] = handle

task := handle.newStartTask()
env.runningCount++
go func() {
response, failure, err := taskHandler.Execute(task)
if err != nil {
// No retries for operations, fail the operation immediately.
failure = taskHandler.fillInFailure(task.TaskToken, nexusHandlerError(nexus.HandlerErrorTypeInternal, err.Error()))
}
if failure != nil {
err := env.failureConverter.FailureToError(nexusOperationFailure(params, "", &failurepb.Failure{
Message: failure.GetError().GetFailure().GetMessage(),
FailureInfo: &failurepb.Failure_ApplicationFailureInfo{
ApplicationFailureInfo: &failurepb.ApplicationFailureInfo{
NonRetryable: true,
},
},
}))
env.postCallback(func() {
handle.startedCallback("", err)
handle.completedCallback(nil, err)
}, true)
return
} else {
switch v := response.GetResponse().GetStartOperation().GetVariant().(type) {
case *nexuspb.StartOperationResponse_SyncSuccess:
env.postCallback(func() {
handle.startedCallback("", nil)
handle.completedCallback(v.SyncSuccess.GetPayload(), nil)
}, true)
case *nexuspb.StartOperationResponse_AsyncSuccess:
env.postCallback(func() {
handle.startedCallback(v.AsyncSuccess.GetOperationId(), nil)
if handle.cancelRequested {
handle.cancel()
}
}, true)
case *nexuspb.StartOperationResponse_OperationError:
err := env.failureConverter.FailureToError(
nexusOperationFailure(params, "", unsuccessfulOperationErrorToTemporalFailure(v.OperationError)),
)
env.postCallback(func() {
handle.startedCallback("", err)
handle.completedCallback(nil, err)
}, true)
default:
panic(fmt.Errorf("unknown response variant: %v", v))
}
}
}()
return seq
}

func (env *testWorkflowEnvironmentImpl) RequestCancelNexusOperation(seq int64) {
handle, ok := env.runningNexusOperations[seq]
if !ok {
panic(fmt.Errorf("no running operation found for sequence: %d", seq))
}

// Avoid duplicate cancelation.
if handle.cancelRequested {
return
}

// Mark this cancelation request in case the operation hasn't started yet.
// Cancel will be called after start.
handle.cancelRequested = true

// Only cancel after started, we need an operation ID.
if handle.started {
handle.cancel()
}
}

func (wc *testWorkflowEnvironmentImpl) RequestCancelNexusOperation(seq int64) {
panic("TODO")
func (env *testWorkflowEnvironmentImpl) resolveNexusOperation(seq int64, result *commonpb.Payload, err error) {
env.postCallback(func() {
handle, ok := env.runningNexusOperations[seq]
if !ok {
panic(fmt.Errorf("no running operation found for sequence: %d", seq))
}
if err != nil {
failure := env.failureConverter.ErrorToFailure(err)
err = env.failureConverter.FailureToError(nexusOperationFailure(handle.params, handle.operationID, failure.GetCause()))
handle.completedCallback(nil, err)
} else {
handle.completedCallback(result, nil)
}
}, true)
}

func (env *testWorkflowEnvironmentImpl) SideEffect(f func() (*commonpb.Payloads, error), callback ResultHandler) {
Expand Down Expand Up @@ -2665,3 +2809,92 @@ func mockFnGetVersion(string, Version, Version) Version {

// make sure interface is implemented
var _ WorkflowEnvironment = (*testWorkflowEnvironmentImpl)(nil)

func (h *testNexusOperationHandle) newStartTask() *workflowservice.PollNexusTaskQueueResponse {
return &workflowservice.PollNexusTaskQueueResponse{
TaskToken: []byte{},
Request: &nexuspb.Request{
ScheduledTime: timestamppb.Now(),
Header: h.params.nexusHeader,
Variant: &nexuspb.Request_StartOperation{
StartOperation: &nexuspb.StartOperationRequest{
Service: h.params.client.Service(),
Operation: h.params.operation,
RequestId: uuid.NewString(),
// This is effectively ignored.
Callback: "http://test-env/operations",
CallbackHeader: map[string]string{
// The test client uses this to call resolveNexusOperation.
"operation-sequence": strconv.FormatInt(h.seq, 10),
},
Payload: h.params.input,
},
},
},
}
}

func (h *testNexusOperationHandle) newCancelTask() *workflowservice.PollNexusTaskQueueResponse {
return &workflowservice.PollNexusTaskQueueResponse{
TaskToken: []byte{},
Request: &nexuspb.Request{
ScheduledTime: timestamppb.Now(),
Header: h.params.nexusHeader,
Variant: &nexuspb.Request_CancelOperation{
CancelOperation: &nexuspb.CancelOperationRequest{
Service: h.params.client.Service(),
Operation: h.params.operation,
OperationId: h.operationID,
},
},
},
}
}

// completedCallback is a callback registered to handle operation completion.
// Must be called in a postCallback block.
func (h *testNexusOperationHandle) completedCallback(result *commonpb.Payload, err error) {
if h.done {
// Ignore duplicate completions.
return
}
h.done = true
delete(h.env.runningNexusOperations, h.seq)
h.onCompleted(result, err)
}

// startedCallback is a callback registered to handle operation start.
// Must be called in a postCallback block.
func (h *testNexusOperationHandle) startedCallback(opID string, e error) {
h.operationID = opID
h.started = true
h.onStarted(opID, e)
h.env.runningCount--
}

func (h *testNexusOperationHandle) cancel() {
if h.done {
return
}
if h.started && h.operationID == "" {
panic(fmt.Errorf("incomplete operation has no operation ID: (%s, %s, %s)",
h.params.client.Endpoint(), h.params.client.Service(), h.params.operation))
}
h.env.runningCount++
task := h.newCancelTask()
taskHandler := h.env.newTestNexusTaskHandler()

go func() {
_, failure, err := taskHandler.Execute(task)
h.env.postCallback(func() {
if err != nil {
// No retries in the test env, fail the operation immediately.
h.completedCallback(nil, fmt.Errorf("operation cancelation handler failed: %w", err))
} else if failure != nil {
// No retries in the test env, fail the operation immediately.
h.completedCallback(nil, fmt.Errorf("operation cancelation handler failed: %v", failure.GetError().GetFailure().GetMessage()))
}
h.env.runningCount--
}, false)
}()
}
Loading

0 comments on commit 6cadae7

Please sign in to comment.