From 4f2835bf793932089bbf66c0d165c7ce36f3422c Mon Sep 17 00:00:00 2001 From: Roey Berman Date: Mon, 13 May 2024 09:43:09 -0700 Subject: [PATCH 1/3] Execute nexus operation from a workflow --- interceptor/interceptor.go | 5 + internal/client.go | 5 +- internal/error.go | 46 ++ internal/failure_converter.go | 20 + internal/interceptor.go | 23 + internal/interceptor_base.go | 18 + internal/internal_command_state_machine.go | 240 ++++++++- internal/internal_event_handlers.go | 127 +++++ internal/internal_logging_tags.go | 1 + internal/internal_nexus_task_poller.go | 2 +- internal/internal_nexus_worker.go | 2 +- internal/internal_task_handlers.go | 20 +- internal/internal_worker_base.go | 10 + internal/internal_workflow.go | 9 + internal/internal_workflow_testsuite.go | 8 + internal/workflow.go | 157 ++++++ temporal/error.go | 5 + test/nexus_test.go | 573 +++++++++++++++++---- workflow/nexus_example_test.go | 50 ++ workflow/workflow.go | 25 + 20 files changed, 1229 insertions(+), 117 deletions(-) create mode 100644 workflow/nexus_example_test.go diff --git a/interceptor/interceptor.go b/interceptor/interceptor.go index 73f9cadde..6061e88b1 100644 --- a/interceptor/interceptor.go +++ b/interceptor/interceptor.go @@ -131,6 +131,11 @@ type HandleQueryInput = internal.HandleQueryInput // NOTE: Experimental type UpdateInput = internal.UpdateInput +// RequestCancelNexusOperationInput is the input to WorkflowOutboundInterceptor.RequestCancelNexusOperation. +// +// NOTE: Experimental +type RequestCancelNexusOperationInput = internal.RequestCancelNexusOperationInput + // WorkflowOutboundInterceptor is an interface for all workflow calls // originating from the SDK. // diff --git a/internal/client.go b/internal/client.go index 9e02584a8..296fce9f8 100644 --- a/internal/client.go +++ b/internal/client.go @@ -31,7 +31,6 @@ import ( "sync/atomic" "time" - "go.temporal.io/api/common/v1" commonpb "go.temporal.io/api/common/v1" enumspb "go.temporal.io/api/enums/v1" "go.temporal.io/api/operatorservice/v1" @@ -649,7 +648,7 @@ type ( // request ID. Only settable by the SDK - e.g. [temporalnexus.workflowRunOperation]. requestID string // workflow completion callback. Only settable by the SDK - e.g. [temporalnexus.workflowRunOperation]. - callbacks []*common.Callback + callbacks []*commonpb.Callback } // RetryPolicy defines the retry policy. @@ -1004,6 +1003,6 @@ func SetRequestIDOnStartWorkflowOptions(opts *StartWorkflowOptions, requestID st // SetCallbacksOnStartWorkflowOptions is an internal only method for setting callbacks on StartWorkflowOptions. // Callbacks are purposefully not exposed to users for the time being. -func SetCallbacksOnStartWorkflowOptions(opts *StartWorkflowOptions, callbacks []*common.Callback) { +func SetCallbacksOnStartWorkflowOptions(opts *StartWorkflowOptions, callbacks []*commonpb.Callback) { opts.callbacks = callbacks } diff --git a/internal/error.go b/internal/error.go index 69ca659da..ba844c0ca 100644 --- a/internal/error.go +++ b/internal/error.go @@ -255,6 +255,28 @@ type ( cause error } + // NexusOperationError is an error returned when a Nexus Operation has failed. + // + // NOTE: Experimental + NexusOperationError struct { + // The raw proto failure object this error was created from. + Failure *failurepb.Failure + // Error message. + Message string + // ID of the NexusOperationScheduled event. + ScheduledEventID int64 + // Endpoint name. + Endpoint string + // Service name. + Service string + // Operation name. + Operation string + // Operation ID - may be empty if the operation completed synchronously. + OperationID string + // Chained cause - typically an ApplicationError or a CanceledError. + Cause error + } + // ChildWorkflowExecutionAlreadyStartedError is set as the cause of // ChildWorkflowExecutionError when failure is due the child workflow having // already started. @@ -800,6 +822,30 @@ func (e *ChildWorkflowExecutionError) Unwrap() error { return e.cause } +func (e *NexusOperationError) Error() string { + msg := fmt.Sprintf( + "%s (endpoint: %q, service: %q, operation: %q, operation ID: %q, scheduledEventID: %d)", + e.Message, e.Endpoint, e.Service, e.Operation, e.OperationID, e.ScheduledEventID) + if e.Cause != nil { + msg = fmt.Sprintf("%s: %v", msg, e.Cause) + } + return msg +} + +// setFailure implements the failureHolder interface for consistency with other failure based errors.. +func (e *NexusOperationError) setFailure(f *failurepb.Failure) { + e.Failure = f +} + +// failure implements the failureHolder interface for consistency with other failure based errors. +func (e *NexusOperationError) failure() *failurepb.Failure { + return e.Failure +} + +func (e *NexusOperationError) Unwrap() error { + return e.Cause +} + // Error from error interface func (*NamespaceNotFoundError) Error() string { return "namespace not found" diff --git a/internal/failure_converter.go b/internal/failure_converter.go index d16063895..c306f1cc3 100644 --- a/internal/failure_converter.go +++ b/internal/failure_converter.go @@ -160,6 +160,15 @@ func (dfc *DefaultFailureConverter) ErrorToFailure(err error) *failurepb.Failure RetryState: err.retryState, } failure.FailureInfo = &failurepb.Failure_ChildWorkflowExecutionFailureInfo{ChildWorkflowExecutionFailureInfo: failureInfo} + case *NexusOperationError: + failureInfo := &failurepb.NexusOperationFailureInfo{ + ScheduledEventId: err.ScheduledEventID, + Endpoint: err.Endpoint, + Service: err.Service, + Operation: err.Operation, + OperationId: err.OperationID, + } + failure.FailureInfo = &failurepb.Failure_NexusOperationExecutionFailureInfo{NexusOperationExecutionFailureInfo: failureInfo} default: // All unknown errors are considered to be retryable ApplicationFailureInfo. failureInfo := &failurepb.ApplicationFailureInfo{ Type: getErrType(err), @@ -254,6 +263,17 @@ func (dfc *DefaultFailureConverter) FailureToError(failure *failurepb.Failure) e childWorkflowExecutionFailureInfo.GetRetryState(), dfc.FailureToError(failure.GetCause()), ) + } else if info := failure.GetNexusOperationExecutionFailureInfo(); info != nil { + err = &NexusOperationError{ + Message: failure.Message, + Cause: dfc.FailureToError(failure.GetCause()), + Failure: originalFailure, + ScheduledEventID: info.GetScheduledEventId(), + Endpoint: info.GetEndpoint(), + Service: info.GetService(), + Operation: info.GetOperation(), + OperationID: info.GetOperationId(), + } } if err == nil { diff --git a/internal/interceptor.go b/internal/interceptor.go index 3a6210606..b942307ae 100644 --- a/internal/interceptor.go +++ b/internal/interceptor.go @@ -169,6 +169,20 @@ type HandleQueryInput struct { Args []interface{} } +// RequestCancelNexusOperationInput is the input to WorkflowOutboundInterceptor.RequestCancelNexusOperation. +// +// NOTE: Experimental +type RequestCancelNexusOperationInput struct { + // Client that was used to start the operation. + Client NexusClient + // Operation name. + Operation any + // Operation ID. May be empty if the operation is synchronous or has not started yet. + ID string + // seq number. For internal use only. + seq int64 +} + // WorkflowOutboundInterceptor is an interface for all workflow calls // originating from the SDK. See documentation in the interceptor package for // more details. @@ -283,6 +297,15 @@ type WorkflowOutboundInterceptor interface { // interceptor.WorkflowHeader will return a non-nil map for this context. NewContinueAsNewError(ctx Context, wfn interface{}, args ...interface{}) error + // ExecuteNexusOperation intercepts NexusClient.ExecuteOperation. + // + // NOTE: Experimental + ExecuteNexusOperation(ctx Context, client NexusClient, operation any, input any, options NexusOperationOptions) NexusOperationFuture + // RequestCancelNexusOperation intercepts Nexus Operation cancelation via context. + // + // NOTE: Experimental + RequestCancelNexusOperation(ctx Context, input RequestCancelNexusOperationInput) + mustEmbedWorkflowOutboundInterceptorBase() } diff --git a/internal/interceptor_base.go b/internal/interceptor_base.go index 1a28cfa02..dc284a7f4 100644 --- a/internal/interceptor_base.go +++ b/internal/interceptor_base.go @@ -389,6 +389,24 @@ func (w *WorkflowOutboundInterceptorBase) NewContinueAsNewError( return w.Next.NewContinueAsNewError(ctx, wfn, args...) } +// ExecuteNexusOperation implements +// WorkflowOutboundInterceptor.ExecuteNexusOperation. +func (w *WorkflowOutboundInterceptorBase) ExecuteNexusOperation( + ctx Context, + client NexusClient, + operation any, + input any, + options NexusOperationOptions, +) NexusOperationFuture { + return w.Next.ExecuteNexusOperation(ctx, client, operation, input, options) +} + +// RequestCancelNexusOperation implements +// WorkflowOutboundInterceptor.RequestCancelNexusOperation. +func (w *WorkflowOutboundInterceptorBase) RequestCancelNexusOperation(ctx Context, input RequestCancelNexusOperationInput) { + w.Next.RequestCancelNexusOperation(ctx, input) +} + func (*WorkflowOutboundInterceptorBase) mustEmbedWorkflowOutboundInterceptorBase() {} // ClientInterceptorBase is a default implementation of ClientInterceptor meant diff --git a/internal/internal_command_state_machine.go b/internal/internal_command_state_machine.go index c23be27a7..0a3ac6abb 100644 --- a/internal/internal_command_state_machine.go +++ b/internal/internal_command_state_machine.go @@ -132,6 +132,35 @@ type ( *naiveCommandStateMachine } + // nexusOperationStateMachine is the state machine for the NexusOperation lifecycle. + // It may never transition to the started state if the operation completes synchronously. + // Valid transitions: + // commandStateCreated -> commandStateCommandSent + // commandStateCommandSent - (NexusOperationScheduled) -> commandStateInitiated + // commandStateInitiated - (NexusOperationStarted) -> commandStateStarted + // commandStateInitiated - (NexusOperation(Completed|Failed|Canceled|TimedOut)) -> commandStateCompleted + // commandStateStarted - (NexusOperation(Completed|Failed|Canceled|TimedOut)) -> commandStateCompleted + nexusOperationStateMachine struct { + *commandStateMachineBase + // Unique sequence number for identifying this machine SDK side. + seq int64 + // Event ID of the NexusOperationScheduled event for correlating progress events with this machine. + scheduledEventID int64 + attributes *commandpb.ScheduleNexusOperationCommandAttributes + // Instead of tracking cancelation as a state, we track it as a separate dimension with the request-cancel state + // machine. + cancelation *requestCancelNexusOperationStateMachine + } + + // requestCancelNexusOperationStateMachine is the state machine for the RequestCancelNexusOperation command. + // Valid transitions: + // commandStateCreated -> commandStateCommandSent + // commandStateCommandSent - (NexusOperationCancelRequested) -> commandStateCompleted + requestCancelNexusOperationStateMachine struct { + *commandStateMachineBase + attributes *commandpb.RequestCancelNexusOperationCommandAttributes + } + versionMarker struct { changeID string searchAttrUpdated bool @@ -146,6 +175,15 @@ type ( scheduledEventIDToCancellationID map[int64]string scheduledEventIDToSignalID map[int64]string versionMarkerLookup map[int64]versionMarker + + // A mapping of scheduled event ID to a sequence. + scheduledEventIDToNexusSeq map[int64]int64 + // A list containing all nexus operation machines that have not yet been assigned a scheduled event ID. + // Every new operation state machine is added to this list on creation and deleted once the scheduled event is + // seen or the operation was deleted before sending the command. + // This mechanism is based on Core SDK + // (https://github.com/temporalio/sdk-core/blob/16c7a33dc1aec8fafb33c9ad6f77569a3dacc8ea/core/src/worker/workflow/machines/workflow_machines.rs#L837). + nexusOperationsWithoutScheduledID *list.List } // panic when command or message state machine is in illegal state @@ -176,20 +214,22 @@ const ( ) const ( - commandTypeActivity commandType = 0 - commandTypeChildWorkflow commandType = 1 - commandTypeCancellation commandType = 2 - commandTypeMarker commandType = 3 - commandTypeTimer commandType = 4 - commandTypeSignal commandType = 5 - commandTypeUpsertSearchAttributes commandType = 6 - commandTypeCancelTimer commandType = 7 - commandTypeRequestCancelActivityTask commandType = 8 - commandTypeAcceptWorkflowUpdate commandType = 9 - commandTypeCompleteWorkflowUpdate commandType = 10 - commandTypeModifyProperties commandType = 11 - commandTypeRejectWorkflowUpdate commandType = 12 - commandTypeProtocolMessage commandType = 13 + commandTypeActivity commandType = 0 + commandTypeChildWorkflow commandType = 1 + commandTypeCancellation commandType = 2 + commandTypeMarker commandType = 3 + commandTypeTimer commandType = 4 + commandTypeSignal commandType = 5 + commandTypeUpsertSearchAttributes commandType = 6 + commandTypeCancelTimer commandType = 7 + commandTypeRequestCancelActivityTask commandType = 8 + commandTypeAcceptWorkflowUpdate commandType = 9 + commandTypeCompleteWorkflowUpdate commandType = 10 + commandTypeModifyProperties commandType = 11 + commandTypeRejectWorkflowUpdate commandType = 12 + commandTypeProtocolMessage commandType = 13 + commandTypeNexusOperation commandType = 14 + commandTypeRequestCancelNexusOperation commandType = 15 ) const ( @@ -276,6 +316,10 @@ func (d commandType) String() string { return "CompleteWorkflowUpdate" case commandTypeRejectWorkflowUpdate: return "RejectWorkflowUpdate" + case commandTypeNexusOperation: + return "NexusOperation" + case commandTypeRequestCancelNexusOperation: + return "RequestCancelNexusOperation" default: return "Unknown" } @@ -318,6 +362,29 @@ func (h *commandsHelper) newCancelActivityStateMachine(attributes *commandpb.Req } } +func (h *commandsHelper) newNexusOperationStateMachine( + seq int64, + attributes *commandpb.ScheduleNexusOperationCommandAttributes, +) *nexusOperationStateMachine { + base := h.newCommandStateMachineBase(commandTypeNexusOperation, strconv.FormatInt(seq, 10)) + sm := &nexusOperationStateMachine{ + commandStateMachineBase: base, + attributes: attributes, + seq: seq, + // scheduledEventID will be assigned by the server when the corresponding event comes in. + } + h.nexusOperationsWithoutScheduledID.PushBack(sm) + return sm +} + +func (h *commandsHelper) newRequestCancelNexusOperationStateMachine(attributes *commandpb.RequestCancelNexusOperationCommandAttributes) *requestCancelNexusOperationStateMachine { + base := h.newCommandStateMachineBase(commandTypeRequestCancelNexusOperation, strconv.FormatInt(attributes.GetScheduledEventId(), 10)) + return &requestCancelNexusOperationStateMachine{ + commandStateMachineBase: base, + attributes: attributes, + } +} + func (h *commandsHelper) newTimerCommandStateMachine(attributes *commandpb.StartTimerCommandAttributes) *timerCommandStateMachine { base := h.newCommandStateMachineBase(commandTypeTimer, attributes.GetTimerId()) return &timerCommandStateMachine{ @@ -853,15 +920,87 @@ func (d *modifyPropertiesCommandStateMachine) handleCommandSent() { } } +func (sm *nexusOperationStateMachine) getCommand() *commandpb.Command { + if sm.state == commandStateCreated && sm.cancelation == nil { + // Only create the command in this state unlike other machines that also create it if canceled before sent. + return &commandpb.Command{ + CommandType: enumspb.COMMAND_TYPE_SCHEDULE_NEXUS_OPERATION, + Attributes: &commandpb.Command_ScheduleNexusOperationCommandAttributes{ + ScheduleNexusOperationCommandAttributes: sm.attributes, + }, + } + } + return nil +} + +func (sm *nexusOperationStateMachine) handleStartedEvent() { + switch sm.state { + case commandStateInitiated: + sm.moveState(commandStateStarted, eventStarted) + default: + sm.failStateTransition(eventStarted) + } +} + +func (sm *nexusOperationStateMachine) handleCompletionEvent() { + switch sm.state { + case commandStateInitiated, + commandStateStarted: + sm.moveState(commandStateCompleted, eventCompletion) + default: + sm.failStateTransition(eventStarted) + } +} + +func (sm *nexusOperationStateMachine) cancel() { + // Already canceled or already completed. + if sm.cancelation != nil || sm.state == commandStateCompleted { + return + } + + attribs := &commandpb.RequestCancelNexusOperationCommandAttributes{ + ScheduledEventId: sm.scheduledEventID, + } + cancelCmd := sm.helper.newRequestCancelNexusOperationStateMachine(attribs) + sm.cancelation = cancelCmd + sm.helper.addCommand(cancelCmd) + + // No need to actually send the cancelation, mark the state machine as completed. + if sm.state == commandStateCreated { + cancelCmd.handleCompletionEvent() + } +} + +func (d *requestCancelNexusOperationStateMachine) getCommand() *commandpb.Command { + switch d.state { + case commandStateCreated: + command := createNewCommand(enumspb.COMMAND_TYPE_REQUEST_CANCEL_NEXUS_OPERATION) + command.Attributes = &commandpb.Command_RequestCancelNexusOperationCommandAttributes{RequestCancelNexusOperationCommandAttributes: d.attributes} + return command + default: + return nil + } +} + +func (d *requestCancelNexusOperationStateMachine) handleCompletionEvent() { + if d.state != commandStateCommandSent && d.state != commandStateCreated { + d.failStateTransition(eventCompletion) + return + } + d.moveState(commandStateCompleted, eventCompletion) +} + func newCommandsHelper() *commandsHelper { return &commandsHelper{ orderedCommands: list.New(), commands: make(map[commandID]*list.Element), - scheduledEventIDToActivityID: make(map[int64]string), - scheduledEventIDToCancellationID: make(map[int64]string), - scheduledEventIDToSignalID: make(map[int64]string), - versionMarkerLookup: make(map[int64]versionMarker), + scheduledEventIDToActivityID: make(map[int64]string), + scheduledEventIDToCancellationID: make(map[int64]string), + scheduledEventIDToSignalID: make(map[int64]string), + versionMarkerLookup: make(map[int64]versionMarker), + scheduledEventIDToNexusSeq: make(map[int64]int64), + nexusOperationsWithoutScheduledID: list.New(), } } @@ -1040,6 +1179,71 @@ func (h *commandsHelper) getActivityAndScheduledEventIDs(event *historypb.Histor return activityID, scheduledEventID } +func (h *commandsHelper) scheduleNexusOperation( + seq int64, + attributes *commandpb.ScheduleNexusOperationCommandAttributes, +) *nexusOperationStateMachine { + command := h.newNexusOperationStateMachine(seq, attributes) + h.addCommand(command) + return command +} + +func (h *commandsHelper) handleNexusOperationScheduled(event *historypb.HistoryEvent) { + elem := h.nexusOperationsWithoutScheduledID.Front() + if elem == nil { + panicIllegalState(fmt.Sprintf("[TMPRL1100] unable to find nexus operation state machine for event: %v", util.HistoryEventToString(event))) + } + command := h.nexusOperationsWithoutScheduledID.Remove(elem).(*nexusOperationStateMachine) + + command.scheduledEventID = event.EventId + h.scheduledEventIDToNexusSeq[event.EventId] = command.seq + command.handleInitiatedEvent() +} + +func (h *commandsHelper) handleNexusOperationStarted(scheduledEventID int64) commandStateMachine { + seq, ok := h.scheduledEventIDToNexusSeq[scheduledEventID] + if !ok { + panicIllegalState(fmt.Sprintf("[TMPRL1100] unable to find nexus operation state machine for event ID: %v", scheduledEventID)) + } + command := h.getCommand(makeCommandID(commandTypeNexusOperation, strconv.FormatInt(seq, 10))) + command.handleStartedEvent() + return command +} + +func (h *commandsHelper) handleNexusOperationCompleted(scheduledEventID int64) commandStateMachine { + seq, ok := h.scheduledEventIDToNexusSeq[scheduledEventID] + if !ok { + panicIllegalState(fmt.Sprintf("[TMPRL1100] unable to find nexus operation state machine for event ID: %v", scheduledEventID)) + } + // We don't need this anymore, the state will not transition after completion. + delete(h.scheduledEventIDToNexusSeq, scheduledEventID) + command := h.getCommand(makeCommandID(commandTypeNexusOperation, strconv.FormatInt(seq, 10))) + command.handleCompletionEvent() + return command +} + +func (h *commandsHelper) handleNexusOperationCancelRequested(scheduledEventID int64) { + command := h.getCommand(makeCommandID(commandTypeRequestCancelNexusOperation, strconv.FormatInt(scheduledEventID, 10))) + command.handleCompletionEvent() +} + +func (h *commandsHelper) requestCancelNexusOperation(seq int64) commandStateMachine { + command := h.getCommand(makeCommandID(commandTypeNexusOperation, strconv.FormatInt(seq, 10))) + command.cancel() + // If we haven't sent the command yet, ensure that it doesn't get mapped to the wrong scheduledEventID. + if command.getState() != commandStateCanceledBeforeSent { + return command + } + for elem := h.nexusOperationsWithoutScheduledID.Front(); elem != nil; elem = elem.Next() { + sm := elem.Value.(*nexusOperationStateMachine) + if sm.seq == seq { + h.nexusOperationsWithoutScheduledID.Remove(elem) + break + } + } + return command +} + func (h *commandsHelper) recordVersionMarker(changeID string, version Version, dc converter.DataConverter, searchAttributeWasUpdated bool) commandStateMachine { markerID := fmt.Sprintf("%v_%v", versionMarkerName, changeID) diff --git a/internal/internal_event_handlers.go b/internal/internal_event_handlers.go index d60e82d2b..b382aa720 100644 --- a/internal/internal_event_handlers.go +++ b/internal/internal_event_handlers.go @@ -84,6 +84,14 @@ type ( activityType ActivityType } + scheduledNexusOperation struct { + startedCallback func(operationID string, err error) + completedCallback func(result *commonpb.Payload, err error) + endpoint string + service string + operation string + } + scheduledChildWorkflow struct { resultCallback ResultHandler startedCallback func(r WorkflowExecution, e error) @@ -613,6 +621,57 @@ func (wc *workflowEnvironmentImpl) ExecuteChildWorkflow( tagWorkflowType, params.WorkflowType.Name) } +func (wc *workflowEnvironmentImpl) ExecuteNexusOperation(params executeNexusOperationParams, callback func(*commonpb.Payload, error), startedHandler func(opID string, e error)) int64 { + seq := wc.GenerateSequence() + scheduleTaskAttr := &commandpb.ScheduleNexusOperationCommandAttributes{ + Endpoint: params.client.Endpoint(), + Service: params.client.Service(), + Operation: params.operation, + Input: params.input, + ScheduleToCloseTimeout: durationpb.New(params.options.ScheduleToCloseTimeout), + NexusHeader: params.nexusHeader, + } + + command := wc.commandsHelper.scheduleNexusOperation(seq, scheduleTaskAttr) + command.setData(&scheduledNexusOperation{ + startedCallback: startedHandler, + completedCallback: callback, + endpoint: params.client.Endpoint(), + service: params.client.Service(), + operation: params.operation, + }) + + wc.logger.Debug("ScheduleNexusOperation", + tagNexusEndpoint, params.client.Endpoint(), + tagNexusService, params.client.Service(), + tagNexusOperation, params.operation, + ) + + return command.seq +} + +func (wc *workflowEnvironmentImpl) RequestCancelNexusOperation(seq int64) { + command := wc.commandsHelper.requestCancelNexusOperation(seq) + data := command.getData().(*scheduledNexusOperation) + + // Make sure to unblock the futures. + if command.getState() == commandStateCreated || command.getState() == commandStateCommandSent { + if data.startedCallback != nil { + data.startedCallback("", ErrCanceled) + data.startedCallback = nil + } + if data.completedCallback != nil { + data.completedCallback(nil, ErrCanceled) + data.completedCallback = nil + } + } + wc.logger.Debug("RequestCancelNexusOperation", + tagNexusEndpoint, data.endpoint, + tagNexusService, data.service, + tagNexusOperation, data.operation, + ) +} + func (wc *workflowEnvironmentImpl) RegisterSignalHandler( handler func(name string, input *commonpb.Payloads, header *commonpb.Header) error, ) { @@ -1260,6 +1319,19 @@ func (weh *workflowExecutionEventHandlerImpl) ProcessEvent( case enumspb.EVENT_TYPE_WORKFLOW_EXECUTION_UPDATE_COMPLETED: // No Operation + case enumspb.EVENT_TYPE_NEXUS_OPERATION_SCHEDULED: + weh.commandsHelper.handleNexusOperationScheduled(event) + case enumspb.EVENT_TYPE_NEXUS_OPERATION_STARTED: + err = weh.handleNexusOperationStarted(event) + // all forms of completions are handled by the same method. + case enumspb.EVENT_TYPE_NEXUS_OPERATION_COMPLETED, + enumspb.EVENT_TYPE_NEXUS_OPERATION_FAILED, + enumspb.EVENT_TYPE_NEXUS_OPERATION_CANCELED, + enumspb.EVENT_TYPE_NEXUS_OPERATION_TIMED_OUT: + err = weh.handleNexusOperationCompleted(event) + case enumspb.EVENT_TYPE_NEXUS_OPERATION_CANCEL_REQUESTED: + weh.commandsHelper.handleNexusOperationCancelRequested(event.GetNexusOperationCancelRequestedEventAttributes().GetScheduledEventId()) + default: if event.WorkerMayIgnore { // Do not fail to be forward compatible with new events @@ -1820,6 +1892,61 @@ func (weh *workflowExecutionEventHandlerImpl) handleChildWorkflowExecutionTermin return nil } +func (weh *workflowExecutionEventHandlerImpl) handleNexusOperationStarted(event *historypb.HistoryEvent) error { + attributes := event.GetNexusOperationStartedEventAttributes() + command := weh.commandsHelper.handleNexusOperationStarted(attributes.ScheduledEventId) + state := command.getData().(*scheduledNexusOperation) + if state.startedCallback != nil { + state.startedCallback(attributes.OperationId, nil) + state.startedCallback = nil + } + return nil +} + +func (weh *workflowExecutionEventHandlerImpl) handleNexusOperationCompleted(event *historypb.HistoryEvent) error { + var result *commonpb.Payload + var failure *failurepb.Failure + var scheduledEventId int64 + + switch event.EventType { + case enumspb.EVENT_TYPE_NEXUS_OPERATION_COMPLETED: + attrs := event.GetNexusOperationCompletedEventAttributes() + result = attrs.GetResult() + scheduledEventId = attrs.GetScheduledEventId() + case enumspb.EVENT_TYPE_NEXUS_OPERATION_FAILED: + attrs := event.GetNexusOperationFailedEventAttributes() + failure = attrs.GetFailure() + scheduledEventId = attrs.GetScheduledEventId() + case enumspb.EVENT_TYPE_NEXUS_OPERATION_CANCELED: + attrs := event.GetNexusOperationCanceledEventAttributes() + failure = attrs.GetFailure() + scheduledEventId = attrs.GetScheduledEventId() + case enumspb.EVENT_TYPE_NEXUS_OPERATION_TIMED_OUT: + attrs := event.GetNexusOperationTimedOutEventAttributes() + failure = attrs.GetFailure() + scheduledEventId = attrs.GetScheduledEventId() + default: + // This is only called internally and should never happen. + panic(fmt.Errorf("invalid event type, not a Nexus Operation resolution: %v", event.EventType)) + } + command := weh.commandsHelper.handleNexusOperationCompleted(scheduledEventId) + state := command.getData().(*scheduledNexusOperation) + var err error + if failure != nil { + err = weh.failureConverter.FailureToError(failure) + } + // Also unblock the start future + if state.startedCallback != nil { + state.startedCallback("", err) // We didn't get a started event, the operation completed synchronously. + state.startedCallback = nil + } + if state.completedCallback != nil { + state.completedCallback(result, err) + state.completedCallback = nil + } + return nil +} + func (weh *workflowExecutionEventHandlerImpl) handleUpsertWorkflowSearchAttributes(event *historypb.HistoryEvent) { weh.updateWorkflowInfoWithSearchAttributes(event.GetUpsertWorkflowSearchAttributesEventAttributes().SearchAttributes) } diff --git a/internal/internal_logging_tags.go b/internal/internal_logging_tags.go index 977c95aba..99b7f93c0 100644 --- a/internal/internal_logging_tags.go +++ b/internal/internal_logging_tags.go @@ -49,6 +49,7 @@ const ( tagTaskStartedEventID = "TaskStartedEventID" tagPreviousStartedEventID = "PreviousStartedEventID" tagCachedPreviousStartedEventID = "CachedPreviousStartedEventID" + tagNexusEndpoint = "NexusEndpoint" tagNexusOperation = "NexusOperation" tagNexusService = "NexusService" tagPanicError = "PanicError" diff --git a/internal/internal_nexus_task_poller.go b/internal/internal_nexus_task_poller.go index 2d818a3f6..dd1e9655e 100644 --- a/internal/internal_nexus_task_poller.go +++ b/internal/internal_nexus_task_poller.go @@ -134,7 +134,7 @@ func (ntp *nexusTaskPoller) ProcessTask(task interface{}) error { if err := ntp.reportCompletion(res, failure); err != nil { traceLog(func() { - ntp.logger.Debug("reportActivityComplete failed", tagError, err) + ntp.logger.Debug("reportNexusTaskComplete failed", tagError, err) }) return err } diff --git a/internal/internal_nexus_worker.go b/internal/internal_nexus_worker.go index 1445d7429..8c9dd6a0f 100644 --- a/internal/internal_nexus_worker.go +++ b/internal/internal_nexus_worker.go @@ -27,8 +27,8 @@ func newNexusWorker(opts nexusWorkerOptions) (*nexusWorker, error) { poller := newNexusTaskPoller( newNexusTaskHandler( opts.handler, - opts.executionParameters.Namespace, opts.executionParameters.Identity, + opts.executionParameters.Namespace, opts.executionParameters.TaskQueue, opts.client, opts.executionParameters.DataConverter, diff --git a/internal/internal_task_handlers.go b/internal/internal_task_handlers.go index 70b718f84..0f58aa08d 100644 --- a/internal/internal_task_handlers.go +++ b/internal/internal_task_handlers.go @@ -336,7 +336,9 @@ func isCommandEvent(eventType enumspb.EventType) bool { enumspb.EVENT_TYPE_WORKFLOW_PROPERTIES_MODIFIED, enumspb.EVENT_TYPE_WORKFLOW_EXECUTION_UPDATE_ACCEPTED, enumspb.EVENT_TYPE_WORKFLOW_EXECUTION_UPDATE_COMPLETED, - enumspb.EVENT_TYPE_WORKFLOW_EXECUTION_UPDATE_REJECTED: + enumspb.EVENT_TYPE_WORKFLOW_EXECUTION_UPDATE_REJECTED, + enumspb.EVENT_TYPE_NEXUS_OPERATION_SCHEDULED, + enumspb.EVENT_TYPE_NEXUS_OPERATION_CANCEL_REQUESTED: return true default: return false @@ -1687,6 +1689,22 @@ func isCommandMatchEvent(d *commandpb.Command, e *historypb.HistoryEvent, obes [ return false } return true + + case enumspb.COMMAND_TYPE_SCHEDULE_NEXUS_OPERATION: + if e.GetEventType() != enumspb.EVENT_TYPE_NEXUS_OPERATION_SCHEDULED { + return false + } + eventAttributes := e.GetNexusOperationScheduledEventAttributes() + commandAttributes := d.GetScheduleNexusOperationCommandAttributes() + + if eventAttributes.GetService() != commandAttributes.GetService() || eventAttributes.GetOperation() != commandAttributes.GetOperation() { + return false + } + + return true + + case enumspb.COMMAND_TYPE_REQUEST_CANCEL_NEXUS_OPERATION: + return e.GetEventType() == enumspb.EVENT_TYPE_NEXUS_OPERATION_CANCEL_REQUESTED } return false diff --git a/internal/internal_worker_base.go b/internal/internal_worker_base.go index 60285a8b6..c841d0bed 100644 --- a/internal/internal_worker_base.go +++ b/internal/internal_worker_base.go @@ -77,6 +77,14 @@ type ( Backoff time.Duration } + executeNexusOperationParams struct { + client NexusClient + operation string + input *commonpb.Payload + options NexusOperationOptions + nexusHeader map[string]string + } + // WorkflowEnvironment Represents the environment for workflow. // Should only be used within the scope of workflow definition. WorkflowEnvironment interface { @@ -92,6 +100,8 @@ type ( RequestCancelChildWorkflow(namespace, workflowID string) RequestCancelExternalWorkflow(namespace, workflowID, runID string, callback ResultHandler) ExecuteChildWorkflow(params ExecuteWorkflowParams, callback ResultHandler, startedHandler func(r WorkflowExecution, e error)) + ExecuteNexusOperation(params executeNexusOperationParams, callback func(*commonpb.Payload, error), startedHandler func(opID string, e error)) int64 + RequestCancelNexusOperation(seq int64) GetLogger() log.Logger GetMetricsHandler() metrics.Handler // Must be called before WorkflowDefinition.Execute returns diff --git a/internal/internal_workflow.go b/internal/internal_workflow.go index 42335ca2f..09b6f8965 100644 --- a/internal/internal_workflow.go +++ b/internal/internal_workflow.go @@ -230,6 +230,11 @@ type ( executionFuture *futureImpl // for child workflow execution future } + nexusOperationFutureImpl struct { + *decodeFutureImpl // for the result + executionFuture *futureImpl // for the NexusOperationExecution + } + asyncFuture interface { Future // Used by selectorImpl @@ -467,6 +472,10 @@ func (f *childWorkflowFutureImpl) SignalChildWorkflow(ctx Context, signalName st return i.SignalChildWorkflow(ctx, childExec.ID, signalName, data) } +func (f *nexusOperationFutureImpl) GetNexusOperationExecution() Future { + return f.executionFuture +} + func newWorkflowContext( env WorkflowEnvironment, interceptors []WorkerInterceptor, diff --git a/internal/internal_workflow_testsuite.go b/internal/internal_workflow_testsuite.go index bf4398fdd..c7f944b19 100644 --- a/internal/internal_workflow_testsuite.go +++ b/internal/internal_workflow_testsuite.go @@ -2279,6 +2279,14 @@ func (env *testWorkflowEnvironmentImpl) executeChildWorkflowWithDelay(delayStart go childEnv.executeWorkflowInternal(delayStart, params.WorkflowType.Name, params.Input) } +func (wc *testWorkflowEnvironmentImpl) ExecuteNexusOperation(params executeNexusOperationParams, callback func(*commonpb.Payload, error), startedHandler func(opID string, e error)) int64 { + panic("TODO") +} + +func (wc *testWorkflowEnvironmentImpl) RequestCancelNexusOperation(seq int64) { + panic("TODO") +} + func (env *testWorkflowEnvironmentImpl) SideEffect(f func() (*commonpb.Payloads, error), callback ResultHandler) { callback(f()) } diff --git a/internal/workflow.go b/internal/workflow.go index 7d29a62d1..ef0e2b91e 100644 --- a/internal/workflow.go +++ b/internal/workflow.go @@ -2119,3 +2119,160 @@ func DeterministicKeysFunc[K comparable, V any](m map[K]V, cmp func(a K, b K) in slices.SortStableFunc(r, cmp) return r } + +// NexusOperationOptions are options for starting a Nexus Operation from a Workflow. +type NexusOperationOptions struct { + ScheduleToCloseTimeout time.Duration +} + +// NexusOperationExecution is the result of [NexusOperationFuture.GetNexusOperationExecution]. +type NexusOperationExecution struct { + OperationID string +} + +// NexusOperationFuture represents the result of a Nexus Operation. +type NexusOperationFuture interface { + Future + // GetNexusOperationExecution returns a future that is resolved when the operation reaches the STARTED state. + // For synchronous operations, this will be resolved at the same as the containing [NexusOperationFuture]. For + // asynchronous operations, this future is resolved independently. + // If the operation is unsuccessful, this future will contain the same error as the [NexusOperationFuture]. + // Use this method to extract the Operation ID of an asynchronous operation. OperationID will be empty for + // synchronous operations. + // + // NOTE: Experimental + GetNexusOperationExecution() Future +} + +// NexusClient is a client for executing Nexus Operations from a workflow. +type NexusClient interface { + // The endpoint name this client uses. + // + // NOTE: Experimental + Endpoint() string + // The service name this client uses. + // + // NOTE: Experimental + Service() string + + // ExecuteOperation executes a Nexus Operation. + // The operation argument can be a string, a [nexus.Operation] or a [nexus.OperationReference]. + // + // NOTE: Experimental + ExecuteOperation(ctx Context, operation any, input any, options NexusOperationOptions) NexusOperationFuture +} + +type nexusClient struct { + endpoint, service string +} + +// Create a [NexusClient] from an endpoint name and a service name. +// +// NOTE: Experimental +func NewNexusClient(endpoint, service string) NexusClient { + return nexusClient{endpoint, service} +} + +func (c nexusClient) Endpoint() string { + return c.endpoint +} + +func (c nexusClient) Service() string { + return c.service +} + +func (c nexusClient) ExecuteOperation(ctx Context, operation any, input any, options NexusOperationOptions) NexusOperationFuture { + assertNotInReadOnlyState(ctx) + i := getWorkflowOutboundInterceptor(ctx) + return i.ExecuteNexusOperation(ctx, c, operation, input, options) +} + +func (wc *workflowEnvironmentInterceptor) prepareNexusOperationParams(ctx Context, client NexusClient, operation any, input any, options NexusOperationOptions) (executeNexusOperationParams, error) { + dc := WithWorkflowContext(ctx, wc.env.GetDataConverter()) + + var ok bool + var operationName string + if operationName, ok = operation.(string); ok { + } else if regOp, ok := operation.(interface{ Name() string }); ok { + operationName = regOp.Name() + } else { + return executeNexusOperationParams{}, fmt.Errorf("invalid 'operation' parameter, must be an OperationReference or a string") + } + // TODO(bergundy): Validate operation types against input once there's a good way to extract the generic types from + // OperationReference in the Nexus Go SDK. + + payload, err := dc.ToPayload(input) + if err != nil { + return executeNexusOperationParams{}, err + } + + return executeNexusOperationParams{ + client: client, + operation: operationName, + input: payload, + options: options, + }, nil +} + +func (wc *workflowEnvironmentInterceptor) ExecuteNexusOperation(ctx Context, client NexusClient, operation any, input any, options NexusOperationOptions) NexusOperationFuture { + mainFuture, mainSettable := newDecodeFuture(ctx, nil /* this param is never used */) + executionFuture, executionSettable := NewFuture(ctx) + result := &nexusOperationFutureImpl{ + decodeFutureImpl: mainFuture.(*decodeFutureImpl), + executionFuture: executionFuture.(*futureImpl), + } + + // Immediately return if the context has an error without spawning the child workflow + if ctx.Err() != nil { + executionSettable.Set(nil, ctx.Err()) + mainSettable.Set(nil, ctx.Err()) + return result + } + + ctxDone, cancellable := ctx.Done().(*channelImpl) + cancellationCallback := &receiveCallback{} + params, err := wc.prepareNexusOperationParams(ctx, client, operation, input, options) + if err != nil { + executionSettable.Set(nil, err) + mainSettable.Set(nil, err) + return result + } + + var operationID string + seq := wc.env.ExecuteNexusOperation(params, func(r *commonpb.Payload, e error) { + mainSettable.Set(&commonpb.Payloads{Payloads: []*commonpb.Payload{r}}, e) + if cancellable { + // future is done, we don't need cancellation anymore + ctxDone.removeReceiveCallback(cancellationCallback) + } + }, func(opID string, e error) { + operationID = opID + executionSettable.Set(NexusOperationExecution{opID}, e) + }) + + if cancellable { + cancellationCallback.fn = func(v any, _ bool) bool { + assertNotInReadOnlyStateCancellation(ctx) + if ctx.Err() == ErrCanceled && !mainFuture.IsReady() { + // Go back to the top of the interception chain. + getWorkflowOutboundInterceptor(ctx).RequestCancelNexusOperation(ctx, RequestCancelNexusOperationInput{ + Client: client, + Operation: operation, + ID: operationID, + seq: seq, + }) + } + return false + } + _, ok, more := ctxDone.receiveAsyncImpl(cancellationCallback) + if ok || !more { + cancellationCallback.fn(nil, more) + } + } + + return result +} + +func (wc *workflowEnvironmentInterceptor) RequestCancelNexusOperation(ctx Context, input RequestCancelNexusOperationInput) { + wc.env.RequestCancelNexusOperation(input.seq) +} diff --git a/temporal/error.go b/temporal/error.go index 132b42394..3f14b41b9 100644 --- a/temporal/error.go +++ b/temporal/error.go @@ -131,6 +131,11 @@ type ( // ChildWorkflowExecutionError returned from workflow when child workflow returned an error. ChildWorkflowExecutionError = internal.ChildWorkflowExecutionError + // NexusOperationError is an error returned when a Nexus Operation has failed. + // + // NOTE: Experimental + NexusOperationError = internal.NexusOperationError + // ChildWorkflowExecutionAlreadyStartedError is set as the cause of // ChildWorkflowExecutionError when failure is due the child workflow having // already started. diff --git a/test/nexus_test.go b/test/nexus_test.go index ee7010120..5b518acb0 100644 --- a/test/nexus_test.go +++ b/test/nexus_test.go @@ -24,6 +24,7 @@ package test_test import ( "context" + "fmt" "net/http" "slices" "testing" @@ -31,18 +32,123 @@ import ( "github.com/google/uuid" "github.com/nexus-rpc/sdk-go/nexus" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.temporal.io/api/common/v1" + "go.temporal.io/api/enums/v1" + historypb "go.temporal.io/api/history/v1" nexuspb "go.temporal.io/api/nexus/v1" "go.temporal.io/api/operatorservice/v1" "go.temporal.io/sdk/client" "go.temporal.io/sdk/internal/common/metrics" ilog "go.temporal.io/sdk/internal/log" + "go.temporal.io/sdk/temporal" "go.temporal.io/sdk/temporalnexus" "go.temporal.io/sdk/worker" "go.temporal.io/sdk/workflow" ) +type testContext struct { + client client.Client + metricsHandler *metrics.CapturingHandler + testConfig Config + taskQueue, endpoint, endpointBaseURL string +} + +func newTestContext(t *testing.T, ctx context.Context) *testContext { + config := NewConfig() + require.NoError(t, WaitForTCP(time.Minute, config.ServiceAddr)) + + metricsHandler := metrics.NewCapturingHandler() + c, err := client.DialContext(ctx, client.Options{ + HostPort: config.ServiceAddr, + Namespace: config.Namespace, + Logger: ilog.NewDefaultLogger(), + ConnectionOptions: client.ConnectionOptions{TLS: config.TLS}, + MetricsHandler: metricsHandler, + }) + require.NoError(t, err) + + taskQueue := "sdk-go-nexus-test-tq-" + uuid.NewString() + endpoint := "sdk-go-nexus-test-ep-" + uuid.NewString() + res, err := c.OperatorService().CreateNexusEndpoint(ctx, &operatorservice.CreateNexusEndpointRequest{ + Spec: &nexuspb.EndpointSpec{ + Name: endpoint, + Target: &nexuspb.EndpointTarget{ + Variant: &nexuspb.EndpointTarget_Worker_{ + Worker: &nexuspb.EndpointTarget_Worker{ + Namespace: config.Namespace, + TaskQueue: taskQueue, + }, + }, + }, + }, + }) + require.NoError(t, err) + + scheme := "http" + if config.TLS != nil { + scheme = "https" + } + endpointBaseURL := scheme + "://" + config.ServiceHTTPAddr + res.Endpoint.UrlPrefix + + tc := &testContext{ + client: c, + testConfig: config, + metricsHandler: metricsHandler, + taskQueue: taskQueue, + endpoint: endpoint, + endpointBaseURL: endpointBaseURL, + } + + return tc +} + +func (tc *testContext) newNexusClient(t *testing.T, service string) *nexus.Client { + httpClient := http.Client{ + Transport: &http.Transport{ + TLSClientConfig: tc.testConfig.TLS, + }, + } + nc, err := nexus.NewClient(nexus.ClientOptions{ + BaseURL: tc.endpointBaseURL, + Service: service, + HTTPCaller: func(r *http.Request) (*http.Response, error) { + attempt := 0 + for { + attempt++ + res, err := httpClient.Do(r) + // Give the endpoint configuration some time to propagate in the frontend. + // This should not take more than a few milliseconds. + // TODO(bergundy): Remove this once the server supports cache read through for unknown endpoints. + if attempt < 10 && err == nil && res.StatusCode == http.StatusNotFound { + time.Sleep(time.Millisecond * 100) + continue + } + return res, err + } + }, + }) + require.NoError(t, err) + return nc +} + +func (tc *testContext) requireTimer(t *assert.CollectT, metric, service, operation string) { + assert.True(t, slices.ContainsFunc(tc.metricsHandler.Timers(), func(ct *metrics.CapturedTimer) bool { + return ct.Name == metric && + ct.Tags[metrics.NexusServiceTagName] == service && + ct.Tags[metrics.NexusOperationTagName] == operation + })) +} + +func (tc *testContext) requireCounter(t *assert.CollectT, metric, service, operation string) { + assert.True(t, slices.ContainsFunc(tc.metricsHandler.Counters(), func(ct *metrics.CapturedCounter) bool { + return ct.Name == metric && + ct.Tags[metrics.NexusServiceTagName] == service && + ct.Tags[metrics.NexusOperationTagName] == operation + })) +} + var syncOp = temporalnexus.NewSyncOperation("sync-op", func(ctx context.Context, c client.Client, s string, o nexus.StartOperationOptions) (string, error) { switch s { case "ok": @@ -93,23 +199,10 @@ var workflowOp = temporalnexus.NewWorkflowRunOperation( func TestNexusSyncOperation(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) defer cancel() - config := NewConfig() - require.NoError(t, WaitForTCP(time.Minute, config.ServiceAddr)) - - metricsHandler := metrics.NewCapturingHandler() - c, err := client.DialContext(ctx, client.Options{ - HostPort: config.ServiceAddr, - Namespace: config.Namespace, - Logger: ilog.NewDefaultLogger(), - ConnectionOptions: client.ConnectionOptions{TLS: config.TLS}, - MetricsHandler: metricsHandler, - }) - require.NoError(t, err) - - taskQueue := "nexus-test" + uuid.NewString() - w := worker.New(c, taskQueue, worker.Options{}) + tc := newTestContext(t, ctx) + w := worker.New(tc.client, tc.taskQueue, worker.Options{}) service := nexus.NewService("test") require.NoError(t, service.Register(syncOp, workflowOp)) w.RegisterNexusService(service) @@ -117,29 +210,10 @@ func TestNexusSyncOperation(t *testing.T) { require.NoError(t, w.Start()) t.Cleanup(w.Stop) - res, err := c.OperatorService().CreateNexusEndpoint(ctx, &operatorservice.CreateNexusEndpointRequest{ - Spec: &nexuspb.EndpointSpec{ - Name: "sdk-go-test-" + uuid.NewString(), - Target: &nexuspb.EndpointTarget{ - Variant: &nexuspb.EndpointTarget_Worker_{ - Worker: &nexuspb.EndpointTarget_Worker{ - Namespace: config.Namespace, - TaskQueue: taskQueue, - }, - }, - }, - }, - }) - require.NoError(t, err) - - nc, err := nexus.NewClient(nexus.ClientOptions{ - BaseURL: "http://" + config.ServiceHTTPAddr + res.Endpoint.UrlPrefix, - Service: service.Name, - }) - require.NoError(t, err) + nc := tc.newNexusClient(t, service.Name) t.Run("ok", func(t *testing.T) { - metricsHandler.Clear() + tc.metricsHandler.Clear() result, err := nexus.ExecuteOperation(ctx, nc, syncOp, "ok", nexus.ExecuteOperationOptions{ RequestID: "test-request-id", Header: nexus.Header{"test": "ok"}, @@ -148,23 +222,21 @@ func TestNexusSyncOperation(t *testing.T) { }) require.NoError(t, err) require.Equal(t, "ok", result) - requireTimer(t, metricsHandler, metrics.NexusTaskEndToEndLatency, service.Name, syncOp.Name()) - requireTimer(t, metricsHandler, metrics.NexusTaskScheduleToStartLatency, service.Name, syncOp.Name()) - requireTimer(t, metricsHandler, metrics.NexusTaskExecutionLatency, service.Name, syncOp.Name()) + + require.EventuallyWithT(t, func(t *assert.CollectT) { + tc.requireTimer(t, metrics.NexusTaskEndToEndLatency, service.Name, syncOp.Name()) + tc.requireTimer(t, metrics.NexusTaskScheduleToStartLatency, service.Name, syncOp.Name()) + tc.requireTimer(t, metrics.NexusTaskExecutionLatency, service.Name, syncOp.Name()) + }, time.Second*3, time.Millisecond*100) }) t.Run("fail", func(t *testing.T) { - metricsHandler.Clear() + tc.metricsHandler.Clear() _, err := nexus.ExecuteOperation(ctx, nc, syncOp, "fail", nexus.ExecuteOperationOptions{}) var unsuccessfulOperationErr *nexus.UnsuccessfulOperationError require.ErrorAs(t, err, &unsuccessfulOperationErr) require.Equal(t, nexus.OperationStateFailed, unsuccessfulOperationErr.State) require.Equal(t, "fail", unsuccessfulOperationErr.Failure.Message) - - requireTimer(t, metricsHandler, metrics.NexusTaskEndToEndLatency, service.Name, syncOp.Name()) - requireTimer(t, metricsHandler, metrics.NexusTaskScheduleToStartLatency, service.Name, syncOp.Name()) - requireTimer(t, metricsHandler, metrics.NexusTaskExecutionLatency, service.Name, syncOp.Name()) - requireCounter(t, metricsHandler, metrics.NexusTaskExecutionFailedCounter, service.Name, syncOp.Name()) }) t.Run("handlererror", func(t *testing.T) { @@ -173,6 +245,13 @@ func TestNexusSyncOperation(t *testing.T) { require.ErrorAs(t, err, &unexpectedResponseErr) require.Equal(t, http.StatusBadRequest, unexpectedResponseErr.Response.StatusCode) require.Contains(t, unexpectedResponseErr.Message, `"400 Bad Request": handlererror`) + + require.EventuallyWithT(t, func(t *assert.CollectT) { + tc.requireTimer(t, metrics.NexusTaskEndToEndLatency, service.Name, syncOp.Name()) + tc.requireTimer(t, metrics.NexusTaskScheduleToStartLatency, service.Name, syncOp.Name()) + tc.requireTimer(t, metrics.NexusTaskExecutionLatency, service.Name, syncOp.Name()) + tc.requireCounter(t, metrics.NexusTaskExecutionFailedCounter, service.Name, syncOp.Name()) + }, time.Second*3, time.Millisecond*100) }) t.Run("panic", func(t *testing.T) { @@ -189,21 +268,9 @@ func TestNexusSyncOperation(t *testing.T) { func TestNexusWorkflowRunOperation(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) defer cancel() - config := NewConfig() - require.NoError(t, WaitForTCP(time.Minute, config.ServiceAddr)) - - c, err := client.DialContext(ctx, client.Options{ - HostPort: config.ServiceAddr, - Namespace: config.Namespace, - Logger: ilog.NewDefaultLogger(), - ConnectionOptions: client.ConnectionOptions{TLS: config.TLS}, - }) - require.NoError(t, err) - - taskQueue := "nexus-test" + uuid.NewString() - - w := worker.New(c, taskQueue, worker.Options{}) + tc := newTestContext(t, ctx) + w := worker.New(tc.client, tc.taskQueue, worker.Options{}) service := nexus.NewService("test") require.NoError(t, service.Register(syncOp, workflowOp)) w.RegisterNexusService(service) @@ -211,26 +278,7 @@ func TestNexusWorkflowRunOperation(t *testing.T) { require.NoError(t, w.Start()) t.Cleanup(w.Stop) - res, err := c.OperatorService().CreateNexusEndpoint(ctx, &operatorservice.CreateNexusEndpointRequest{ - Spec: &nexuspb.EndpointSpec{ - Name: "sdk-go-test-" + uuid.NewString(), - Target: &nexuspb.EndpointTarget{ - Variant: &nexuspb.EndpointTarget_Worker_{ - Worker: &nexuspb.EndpointTarget_Worker{ - Namespace: config.Namespace, - TaskQueue: taskQueue, - }, - }, - }, - }, - }) - require.NoError(t, err) - - nc, err := nexus.NewClient(nexus.ClientOptions{ - BaseURL: "http://" + config.ServiceHTTPAddr + res.Endpoint.UrlPrefix, - Service: service.Name, - }) - require.NoError(t, err) + nc := tc.newNexusClient(t, service.Name) workflowID := "nexus-handler-workflow-" + uuid.NewString() result, err := nexus.StartOperation(ctx, nc, workflowOp, workflowID, nexus.StartOperationOptions{ @@ -241,7 +289,7 @@ func TestNexusWorkflowRunOperation(t *testing.T) { require.NotNil(t, result.Pending) handle := result.Pending require.Equal(t, workflowID, handle.ID) - desc, err := c.DescribeWorkflowExecution(ctx, workflowID, "") + desc, err := tc.client.DescribeWorkflowExecution(ctx, workflowID, "") require.NoError(t, err) require.Equal(t, 1, len(desc.Callbacks)) @@ -250,23 +298,362 @@ func TestNexusWorkflowRunOperation(t *testing.T) { require.Equal(t, "http://localhost/test", callback.Nexus.Url) require.Equal(t, map[string]string{"test": "ok"}, callback.Nexus.Header) - run := c.GetWorkflow(ctx, workflowID, "") + run := tc.client.GetWorkflow(ctx, workflowID, "") require.NoError(t, handle.Cancel(ctx, nexus.CancelOperationOptions{})) require.ErrorContains(t, run.Get(ctx, nil), "canceled") } -func requireTimer(t *testing.T, metricsHandler *metrics.CapturingHandler, metric, service, operation string) { - require.True(t, slices.ContainsFunc(metricsHandler.Timers(), func(ct *metrics.CapturedTimer) bool { - return ct.Name == metric && - ct.Tags[metrics.NexusServiceTagName] == service && - ct.Tags[metrics.NexusOperationTagName] == operation - })) +func TestSyncOperationFromWorkflow(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + defer cancel() + tc := newTestContext(t, ctx) + + op := temporalnexus.NewSyncOperation("op", func(ctx context.Context, c client.Client, outcome string, o nexus.StartOperationOptions) (string, error) { + switch outcome { + case "successful": + return outcome, nil + case "failed": + return "", &nexus.UnsuccessfulOperationError{ + State: nexus.OperationStateFailed, + Failure: nexus.Failure{Message: "failed for test"}, + } + case "canceled": + return "", &nexus.UnsuccessfulOperationError{ + State: nexus.OperationStateCanceled, + Failure: nexus.Failure{Message: "canceled for test"}, + } + default: + panic(fmt.Errorf("unexpected outcome: %s", outcome)) + } + }) + + wf := func(ctx workflow.Context, outcome string) error { + c := workflow.NewNexusClient(tc.endpoint, "test") + fut := c.ExecuteOperation(ctx, op, outcome, workflow.NexusOperationOptions{}) + var res string + + var exec workflow.NexusOperationExecution + if err := fut.GetNexusOperationExecution().Get(ctx, &exec); err != nil && outcome == "successful" { + return fmt.Errorf("expected start to succeed: %w", err) + } + if exec.OperationID != "" { + return fmt.Errorf("expected empty operation ID") + } + if err := fut.Get(ctx, &res); err != nil { + return err + } + // If the operation didn't fail the only expected result is "successful". + if res != "successful" { + return fmt.Errorf("unexpected result: %v", res) + } + return nil + } + + w := worker.New(tc.client, tc.taskQueue, worker.Options{}) + service := nexus.NewService("test") + require.NoError(t, service.Register(op)) + w.RegisterNexusService(service) + w.RegisterWorkflow(wf) + require.NoError(t, w.Start()) + t.Cleanup(w.Stop) + + t.Run("OpSuccessful", func(t *testing.T) { + run, err := tc.client.ExecuteWorkflow(ctx, client.StartWorkflowOptions{ + TaskQueue: tc.taskQueue, + // The endpoint registry may take a bit to propagate to the history service, use a shorter workflow task + // timeout to speed up the attempts. + WorkflowTaskTimeout: time.Second, + }, wf, "successful") + require.NoError(t, err) + require.NoError(t, run.Get(ctx, nil)) + }) + + t.Run("OpFailed", func(t *testing.T) { + run, err := tc.client.ExecuteWorkflow(ctx, client.StartWorkflowOptions{ + TaskQueue: tc.taskQueue, + // The endpoint registry may take a bit to propagate to the history service, use a shorter workflow task + // timeout to speed up the attempts. + WorkflowTaskTimeout: time.Second, + }, wf, "failed") + require.NoError(t, err) + var execErr *temporal.WorkflowExecutionError + err = run.Get(ctx, nil) + require.ErrorAs(t, err, &execErr) + var opErr *temporal.NexusOperationError + err = execErr.Unwrap() + require.ErrorAs(t, err, &opErr) + require.Equal(t, tc.endpoint, opErr.Endpoint) + require.Equal(t, "test", opErr.Service) + require.Equal(t, op.Name(), opErr.Operation) + require.Equal(t, "", opErr.OperationID) + require.Equal(t, "nexus operation completed unsuccessfully", opErr.Message) + require.Greater(t, opErr.ScheduledEventID, int64(0)) + err = opErr.Unwrap() + var appErr *temporal.ApplicationError + require.ErrorAs(t, err, &appErr) + require.Equal(t, "failed for test", appErr.Message()) + }) + + t.Run("OpCanceled", func(t *testing.T) { + run, err := tc.client.ExecuteWorkflow(ctx, client.StartWorkflowOptions{ + TaskQueue: tc.taskQueue, + // The endpoint registry may take a bit to propagate to the history service, use a shorter workflow task + // timeout to speed up the attempts. + WorkflowTaskTimeout: time.Second, + }, wf, "canceled") + require.NoError(t, err) + var execErr *temporal.WorkflowExecutionError + err = run.Get(ctx, nil) + require.ErrorAs(t, err, &execErr) + // The Go SDK unwraps workflow errors to check for cancelation even if the workflow was never canceled, losing + // the error chain, Nexus operation errors are treated the same as other workflow errors for consistency. + var canceledErr *temporal.CanceledError + err = execErr.Unwrap() + require.ErrorAs(t, err, &canceledErr) + }) } -func requireCounter(t *testing.T, metricsHandler *metrics.CapturingHandler, metric, service, operation string) { - require.True(t, slices.ContainsFunc(metricsHandler.Counters(), func(ct *metrics.CapturedCounter) bool { - return ct.Name == metric && - ct.Tags[metrics.NexusServiceTagName] == service && - ct.Tags[metrics.NexusOperationTagName] == operation - })) +func TestAsyncOperationFromWorkflow(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + defer cancel() + tc := newTestContext(t, ctx) + + handlerWorkflow := func(ctx workflow.Context, action string) (string, error) { + switch action { + case "succeed": + return action, nil + case "fail": + return "", fmt.Errorf("handler workflow failed in test") + case "wait-for-cancel": + return "", workflow.Await(ctx, func() bool { return false }) + default: + panic(fmt.Errorf("unexpected outcome: %s", action)) + } + } + op := temporalnexus.NewWorkflowRunOperation("op", handlerWorkflow, func(ctx context.Context, action string, soo nexus.StartOperationOptions) (client.StartWorkflowOptions, error) { + if action == "fail-to-start" { + return client.StartWorkflowOptions{}, nexus.HandlerErrorf(nexus.HandlerErrorTypeInternal, "fake internal error") + } + return client.StartWorkflowOptions{ + ID: soo.RequestID, + }, nil + }) + callerWorkflow := func(ctx workflow.Context, action string) error { + c := workflow.NewNexusClient(tc.endpoint, "test") + ctx, cancel := workflow.WithCancel(ctx) + defer cancel() + fut := c.ExecuteOperation(ctx, op, action, workflow.NexusOperationOptions{}) + var res string + ch := workflow.GetSignalChannel(ctx, "cancel-op") + workflow.Go(ctx, func(ctx workflow.Context) { + var action string + ch.Receive(ctx, &action) + switch action { + case "wait-for-started": + fut.GetNexusOperationExecution().Get(ctx, nil) + case "sleep": + workflow.Sleep(ctx, time.Millisecond) + } + cancel() + }) + var exec workflow.NexusOperationExecution + if err := fut.GetNexusOperationExecution().Get(ctx, &exec); err != nil && action != "fail-to-start" { + return fmt.Errorf("expected start to succeed: %w", err) + } + if exec.OperationID == "" && action != "fail-to-start" { + return fmt.Errorf("expected non empty operation ID") + } + if err := fut.Get(ctx, &res); err != nil { + return err + } + // If the operation didn't fail the only expected result is "successful". + if res != "succeed" { + return fmt.Errorf("unexpected result: %v", res) + } + return nil + } + + w := worker.New(tc.client, tc.taskQueue, worker.Options{}) + service := nexus.NewService("test") + require.NoError(t, service.Register(op)) + w.RegisterNexusService(service) + w.RegisterWorkflow(handlerWorkflow) + w.RegisterWorkflow(callerWorkflow) + require.NoError(t, w.Start()) + t.Cleanup(w.Stop) + + t.Run("OpSuccessful", func(t *testing.T) { + run, err := tc.client.ExecuteWorkflow(ctx, client.StartWorkflowOptions{ + TaskQueue: tc.taskQueue, + // The endpoint registry may take a bit to propagate to the history service, use a shorter workflow task + // timeout to speed up the attempts. + WorkflowTaskTimeout: time.Second, + }, callerWorkflow, "succeed") + require.NoError(t, err) + require.NoError(t, run.Get(ctx, nil)) + }) + + t.Run("OpFailed", func(t *testing.T) { + run, err := tc.client.ExecuteWorkflow(ctx, client.StartWorkflowOptions{ + TaskQueue: tc.taskQueue, + // The endpoint registry may take a bit to propagate to the history service, use a shorter workflow task + // timeout to speed up the attempts. + WorkflowTaskTimeout: time.Second, + }, callerWorkflow, "fail") + require.NoError(t, err) + var execErr *temporal.WorkflowExecutionError + err = run.Get(ctx, nil) + require.ErrorAs(t, err, &execErr) + var opErr *temporal.NexusOperationError + err = execErr.Unwrap() + require.ErrorAs(t, err, &opErr) + require.Equal(t, tc.endpoint, opErr.Endpoint) + require.Equal(t, "test", opErr.Service) + require.Equal(t, op.Name(), opErr.Operation) + require.NotEmpty(t, opErr.OperationID) + require.Equal(t, "nexus operation completed unsuccessfully", opErr.Message) + require.Greater(t, opErr.ScheduledEventID, int64(0)) + err = opErr.Unwrap() + var appErr *temporal.ApplicationError + require.ErrorAs(t, err, &appErr) + require.Equal(t, "handler workflow failed in test", appErr.Message()) + }) + + t.Run("OpCanceledBeforeSent", func(t *testing.T) { + run, err := tc.client.SignalWithStartWorkflow(ctx, uuid.NewString(), "cancel-op", "no-wait", client.StartWorkflowOptions{ + TaskQueue: tc.taskQueue, + }, callerWorkflow, "wait-for-cancel") + require.NoError(t, err) + var execErr *temporal.WorkflowExecutionError + err = run.Get(ctx, nil) + require.ErrorAs(t, err, &execErr) + // The Go SDK unwraps workflow errors to check for cancelation even if the workflow was never canceled, losing + // the error chain, Nexus operation errors are treated the same as other workflow errors for consistency. + var canceledErr *temporal.CanceledError + err = execErr.Unwrap() + require.ErrorAs(t, err, &canceledErr) + + // Verify that the operation was never scheduled. + history := tc.client.GetWorkflowHistory(ctx, run.GetID(), run.GetRunID(), false, enums.HISTORY_EVENT_FILTER_TYPE_ALL_EVENT) + for history.HasNext() { + event, err := history.Next() + require.NoError(t, err) + require.NotEqual(t, enums.EVENT_TYPE_NEXUS_OPERATION_SCHEDULED, event.EventType) + } + }) + + t.Run("OpCanceledBeforeStarted", func(t *testing.T) { + run, err := tc.client.SignalWithStartWorkflow(ctx, uuid.NewString(), "cancel-op", "sleep", client.StartWorkflowOptions{ + TaskQueue: tc.taskQueue, + }, callerWorkflow, "fail-to-start") + require.NoError(t, err) + var execErr *temporal.WorkflowExecutionError + err = run.Get(ctx, nil) + require.ErrorAs(t, err, &execErr) + // The Go SDK unwraps workflow errors to check for cancelation even if the workflow was never canceled, losing + // the error chain, Nexus operation errors are treated the same as other workflow errors for consistency. + var canceledErr *temporal.CanceledError + err = execErr.Unwrap() + require.ErrorAs(t, err, &canceledErr) + }) + + t.Run("OpCanceledAfterStarted", func(t *testing.T) { + run, err := tc.client.SignalWithStartWorkflow(ctx, uuid.NewString(), "cancel-op", "wait-for-started", client.StartWorkflowOptions{ + TaskQueue: tc.taskQueue, + }, callerWorkflow, "wait-for-cancel") + require.NoError(t, err) + var execErr *temporal.WorkflowExecutionError + err = run.Get(ctx, nil) + require.ErrorAs(t, err, &execErr) + // The Go SDK unwraps workflow errors to check for cancelation even if the workflow was never canceled, losing + // the error chain, Nexus operation errors are treated the same as other workflow errors for consistency. + var canceledErr *temporal.CanceledError + err = execErr.Unwrap() + require.ErrorAs(t, err, &canceledErr) + }) +} + +func TestReplay(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + defer cancel() + tc := newTestContext(t, ctx) + + op := temporalnexus.NewSyncOperation("op", func(ctx context.Context, c client.Client, nv nexus.NoValue, soo nexus.StartOperationOptions) (nexus.NoValue, error) { + return nil, nil + }) + + endpointForTest := tc.endpoint + serviceForTest := "test" + opForTest := op.Name() + + callerWorkflow := func(ctx workflow.Context) error { + c := workflow.NewNexusClient(endpointForTest, serviceForTest) + ctx, cancel := workflow.WithCancel(ctx) + defer cancel() + fut := c.ExecuteOperation(ctx, opForTest, nil, workflow.NexusOperationOptions{}) + if err := fut.Get(ctx, nil); err != nil { + return err + } + return nil + } + + w := worker.New(tc.client, tc.taskQueue, worker.Options{}) + service := nexus.NewService("test") + require.NoError(t, service.Register(op)) + w.RegisterNexusService(service) + w.RegisterWorkflow(callerWorkflow) + require.NoError(t, w.Start()) + t.Cleanup(w.Stop) + + run, err := tc.client.ExecuteWorkflow(ctx, client.StartWorkflowOptions{ + TaskQueue: tc.taskQueue, + // The endpoint registry may take a bit to propagate to the history service, use a shorter workflow task + // timeout to speed up the attempts. + WorkflowTaskTimeout: time.Second, + }, callerWorkflow) + require.NoError(t, err) + require.NoError(t, run.Get(ctx, nil)) + + events := make([]*historypb.HistoryEvent, 0) + hist := tc.client.GetWorkflowHistory(ctx, run.GetID(), run.GetRunID(), false, enums.HISTORY_EVENT_FILTER_TYPE_ALL_EVENT) + for hist.HasNext() { + e, err := hist.Next() + require.NoError(t, err) + events = append(events, e) + } + + t.Run("OK", func(t *testing.T) { + // endpointForTest, serviceForTest = tc.endpoint, "test" + rw := worker.NewWorkflowReplayer() + rw.RegisterWorkflow(callerWorkflow) + err = rw.ReplayWorkflowHistory(ilog.NewDefaultLogger(), &historypb.History{Events: events}) + require.NoError(t, err) + }) + + t.Run("EndpointMismatchOK", func(t *testing.T) { + endpointForTest = "endpoint-changed" // It's okay to change the endpoint as it is environment specific. + // endpointForTest, serviceForTest = tc.endpoint, "test" + rw := worker.NewWorkflowReplayer() + rw.RegisterWorkflow(callerWorkflow) + err = rw.ReplayWorkflowHistory(ilog.NewDefaultLogger(), &historypb.History{Events: events}) + require.NoError(t, err) + }) + + t.Run("ServiceMismatchNDE", func(t *testing.T) { + serviceForTest = "service-changed" + // endpointForTest, serviceForTest = tc.endpoint, "test" + rw := worker.NewWorkflowReplayer() + rw.RegisterWorkflow(callerWorkflow) + err = rw.ReplayWorkflowHistory(ilog.NewDefaultLogger(), &historypb.History{Events: events}) + require.ErrorContains(t, err, "[TMPRL1100]") + }) + + t.Run("OperationMismatchNDE", func(t *testing.T) { + serviceForTest = "test" // Restore + opForTest = "op-changed" + rw := worker.NewWorkflowReplayer() + rw.RegisterWorkflow(callerWorkflow) + err = rw.ReplayWorkflowHistory(ilog.NewDefaultLogger(), &historypb.History{Events: events}) + require.ErrorContains(t, err, "[TMPRL1100]") + }) } diff --git a/workflow/nexus_example_test.go b/workflow/nexus_example_test.go new file mode 100644 index 000000000..ee7ebd2ec --- /dev/null +++ b/workflow/nexus_example_test.go @@ -0,0 +1,50 @@ +package workflow_test + +import ( + "context" + "time" + + "github.com/nexus-rpc/sdk-go/nexus" + "go.temporal.io/sdk/client" + "go.temporal.io/sdk/temporalnexus" + "go.temporal.io/sdk/workflow" +) + +type MyInput struct{} +type MyOutput struct{} + +var myOperationRef = nexus.NewOperationReference[MyInput, MyOutput]("my-operation") + +var myOperation = temporalnexus.NewSyncOperation("my-operation", func(ctx context.Context, c client.Client, mi MyInput, soo nexus.StartOperationOptions) (MyOutput, error) { + return MyOutput{}, nil +}) + +func ExampleNexusClient() { + myWorkflow := func(ctx workflow.Context) (MyOutput, error) { + client := workflow.NewNexusClient("my-endpoint", "my-service") + // Execute an operation using an operation name. + fut := client.ExecuteOperation(ctx, "my-operation", MyInput{}, workflow.NexusOperationOptions{ + ScheduleToCloseTimeout: time.Hour, + }) + // Or using an OperationReference. + fut = client.ExecuteOperation(ctx, myOperationRef, MyInput{}, workflow.NexusOperationOptions{ + ScheduleToCloseTimeout: time.Hour, + }) + // Or using a defined operation (which is also an OperationReference). + fut = client.ExecuteOperation(ctx, myOperation, MyInput{}, workflow.NexusOperationOptions{ + ScheduleToCloseTimeout: time.Hour, + }) + + var exec workflow.NexusOperationExecution + // Optionally wait for the operation to be started. + _ = fut.GetNexusOperationExecution().Get(ctx, &exec) + // OperationID will be empty if the operation completed synchronously. + workflow.GetLogger(ctx).Info("operation started", "operationID", exec.OperationID) + + // Get the result of the operation. + var output MyOutput + return output, fut.Get(ctx, &output) + } + + _ = myWorkflow +} diff --git a/workflow/workflow.go b/workflow/workflow.go index addd568ea..7766a5d53 100644 --- a/workflow/workflow.go +++ b/workflow/workflow.go @@ -71,6 +71,26 @@ type ( ContinueAsNewErrorOptions = internal.ContinueAsNewErrorOptions UpdateHandlerOptions = internal.UpdateHandlerOptions + + // NexusClient is a client for executing Nexus Operations from a workflow. + // + // NOTE: Experimental + NexusClient = internal.NexusClient + + // NexusOperationOptions are options for starting a Nexus Operation from a Workflow. + // + // NOTE: Experimental + NexusOperationOptions = internal.NexusOperationOptions + + // NexusOperationFuture represents the result of a Nexus Operation. + // + // NOTE: Experimental + NexusOperationFuture = internal.NexusOperationFuture + + // NexusOperationExecution is the result of [NexusOperationFuture.GetNexusOperationExecution]. + // + // NOTE: Experimental + NexusOperationExecution = internal.NexusOperationExecution ) // ExecuteActivity requests activity execution in the context of a workflow. @@ -688,3 +708,8 @@ func DeterministicKeys[K constraints.Ordered, V any](m map[K]V) []K { func DeterministicKeysFunc[K comparable, V any](m map[K]V, cmp func(K, K) int) []K { return internal.DeterministicKeysFunc(m, cmp) } + +// Create a [NexusClient] from an endpoint name and a service name. +func NewNexusClient(endpoint, service string) NexusClient { + return internal.NewNexusClient(endpoint, service) +} From dcfa8cf3be4da4478e75219d3a7d444944ca75db Mon Sep 17 00:00:00 2001 From: Roey Berman Date: Thu, 16 May 2024 17:28:58 -0700 Subject: [PATCH 2/3] Address review comments --- internal/error.go | 2 ++ internal/internal_task_handlers.go | 12 ++++++++---- internal/workflow.go | 17 +++++++++++++++-- test/nexus_test.go | 2 +- workflow/workflow.go | 21 ++++++++++++++++++--- 5 files changed, 44 insertions(+), 10 deletions(-) diff --git a/internal/error.go b/internal/error.go index ba844c0ca..469cfa1ba 100644 --- a/internal/error.go +++ b/internal/error.go @@ -822,6 +822,7 @@ func (e *ChildWorkflowExecutionError) Unwrap() error { return e.cause } +// Error implements the error interface. func (e *NexusOperationError) Error() string { msg := fmt.Sprintf( "%s (endpoint: %q, service: %q, operation: %q, operation ID: %q, scheduledEventID: %d)", @@ -842,6 +843,7 @@ func (e *NexusOperationError) failure() *failurepb.Failure { return e.Failure } +// Unwrap returns the Cause associated with this error. func (e *NexusOperationError) Unwrap() error { return e.Cause } diff --git a/internal/internal_task_handlers.go b/internal/internal_task_handlers.go index 0f58aa08d..df53899cb 100644 --- a/internal/internal_task_handlers.go +++ b/internal/internal_task_handlers.go @@ -1697,14 +1697,18 @@ func isCommandMatchEvent(d *commandpb.Command, e *historypb.HistoryEvent, obes [ eventAttributes := e.GetNexusOperationScheduledEventAttributes() commandAttributes := d.GetScheduleNexusOperationCommandAttributes() - if eventAttributes.GetService() != commandAttributes.GetService() || eventAttributes.GetOperation() != commandAttributes.GetOperation() { + return eventAttributes.GetService() == commandAttributes.GetService() && + eventAttributes.GetOperation() == commandAttributes.GetOperation() + + case enumspb.COMMAND_TYPE_REQUEST_CANCEL_NEXUS_OPERATION: + if e.GetEventType() != enumspb.EVENT_TYPE_NEXUS_OPERATION_CANCEL_REQUESTED { return false } - return true + eventAttributes := e.GetNexusOperationCancelRequestedEventAttributes() + commandAttributes := d.GetRequestCancelNexusOperationCommandAttributes() - case enumspb.COMMAND_TYPE_REQUEST_CANCEL_NEXUS_OPERATION: - return e.GetEventType() == enumspb.EVENT_TYPE_NEXUS_OPERATION_CANCEL_REQUESTED + return eventAttributes.GetScheduledEventId() == commandAttributes.GetScheduledEventId() } return false diff --git a/internal/workflow.go b/internal/workflow.go index ef0e2b91e..186c450b1 100644 --- a/internal/workflow.go +++ b/internal/workflow.go @@ -2125,8 +2125,10 @@ type NexusOperationOptions struct { ScheduleToCloseTimeout time.Duration } -// NexusOperationExecution is the result of [NexusOperationFuture.GetNexusOperationExecution]. +// NexusOperationExecution is the result of NexusOperationFuture.GetNexusOperationExecution. type NexusOperationExecution struct { + // Operation ID as set by the Operation's handler. May be empty if the operation hasn't started yet or completed + // synchronously. OperationID string } @@ -2141,10 +2143,17 @@ type NexusOperationFuture interface { // synchronous operations. // // NOTE: Experimental + // + // fut := nexusClient.ExecuteOperation(ctx, op, ...) + // var exec workflow.NexusOperationExecution + // if err := fut.GetNexusOperationExecution().Get(ctx, &exec); err == nil { + // // Nexus Operation started, OperationID is optionally set. + // } GetNexusOperationExecution() Future } // NexusClient is a client for executing Nexus Operations from a workflow. +// NOTE to maintainers, this interface definition is duplicated in the workflow package to provide a better UX. type NexusClient interface { // The endpoint name this client uses. // @@ -2240,7 +2249,11 @@ func (wc *workflowEnvironmentInterceptor) ExecuteNexusOperation(ctx Context, cli var operationID string seq := wc.env.ExecuteNexusOperation(params, func(r *commonpb.Payload, e error) { - mainSettable.Set(&commonpb.Payloads{Payloads: []*commonpb.Payload{r}}, e) + var payloads *commonpb.Payloads + if r != nil { + payloads = &commonpb.Payloads{Payloads: []*commonpb.Payload{r}} + } + mainSettable.Set(payloads, e) if cancellable { // future is done, we don't need cancellation anymore ctxDone.removeReceiveCallback(cancellationCallback) diff --git a/test/nexus_test.go b/test/nexus_test.go index 5b518acb0..816b045eb 100644 --- a/test/nexus_test.go +++ b/test/nexus_test.go @@ -574,7 +574,7 @@ func TestAsyncOperationFromWorkflow(t *testing.T) { } func TestReplay(t *testing.T) { - ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) defer cancel() tc := newTestContext(t, ctx) diff --git a/workflow/workflow.go b/workflow/workflow.go index 7766a5d53..33100c3df 100644 --- a/workflow/workflow.go +++ b/workflow/workflow.go @@ -72,10 +72,25 @@ type ( UpdateHandlerOptions = internal.UpdateHandlerOptions + // NOTE to maintainers, this interface definition is duplicated in the internal package to provide a better UX. + // NexusClient is a client for executing Nexus Operations from a workflow. - // - // NOTE: Experimental - NexusClient = internal.NexusClient + NexusClient interface { + // The endpoint name this client uses. + // + // NOTE: Experimental + Endpoint() string + // The service name this client uses. + // + // NOTE: Experimental + Service() string + + // ExecuteOperation executes a Nexus Operation. + // The operation argument can be a string, a [nexus.Operation] or a [nexus.OperationReference]. + // + // NOTE: Experimental + ExecuteOperation(ctx Context, operation any, input any, options NexusOperationOptions) NexusOperationFuture + } // NexusOperationOptions are options for starting a Nexus Operation from a Workflow. // From 42984a064cc192a8912e1bd7ef108f5d02e03648 Mon Sep 17 00:00:00 2001 From: Roey Berman Date: Wed, 19 Jun 2024 16:36:50 -0700 Subject: [PATCH 3/3] Add test environment support for Nexus Operations (#1475) * Add test environment support for Nexus Operations * Change client to not allow any direct calls --- internal/internal_workflow_testsuite.go | 251 +++++++++++++++- internal/nexus_operations.go | 366 +++++++++++++++++++++++- internal/workflow_testsuite.go | 6 + temporalnexus/operation.go | 6 + test/nexus_test.go | 250 ++++++++++++++++ 5 files changed, 869 insertions(+), 10 deletions(-) diff --git a/internal/internal_workflow_testsuite.go b/internal/internal_workflow_testsuite.go index c7f944b19..b83dbecbc 100644 --- a/internal/internal_workflow_testsuite.go +++ b/internal/internal_workflow_testsuite.go @@ -29,12 +29,15 @@ import ( "errors" "fmt" "reflect" + "strconv" "strings" "sync" "time" "github.com/facebookgo/clock" "github.com/golang/mock/gomock" + "github.com/google/uuid" + "github.com/nexus-rpc/sdk-go/nexus" "github.com/robfig/cron" "github.com/stretchr/testify/mock" "google.golang.org/grpc" @@ -44,6 +47,8 @@ import ( commandpb "go.temporal.io/api/command/v1" commonpb "go.temporal.io/api/common/v1" enumspb "go.temporal.io/api/enums/v1" + failurepb "go.temporal.io/api/failure/v1" + nexuspb "go.temporal.io/api/nexus/v1" "go.temporal.io/api/serviceerror" taskqueuepb "go.temporal.io/api/taskqueue/v1" "go.temporal.io/api/workflowservice/v1" @@ -97,6 +102,18 @@ type ( err error } + testNexusOperationHandle struct { + env *testWorkflowEnvironmentImpl + seq int64 + params executeNexusOperationParams + operationID string + cancelRequested bool + started bool + done bool + onCompleted func(*commonpb.Payload, error) + onStarted func(opID string, e error) + } + testCallbackHandle struct { callback func() startWorkflowTask bool // start a new workflow task after callback() is handled. @@ -149,11 +166,12 @@ type ( testTimeout time.Duration header *commonpb.Header - counterID int64 - activities map[string]*testActivityHandle - localActivities map[string]*localActivityTask - timers map[string]*testTimerHandle - runningWorkflows map[string]*testWorkflowHandle + counterID int64 + activities map[string]*testActivityHandle + localActivities map[string]*localActivityTask + timers map[string]*testTimerHandle + runningWorkflows map[string]*testWorkflowHandle + runningNexusOperations map[int64]*testNexusOperationHandle runningCount int @@ -240,6 +258,7 @@ func newTestWorkflowEnvironmentImpl(s *WorkflowTestSuite, parentRegistry *regist activities: make(map[string]*testActivityHandle), localActivities: make(map[string]*localActivityTask), runningWorkflows: make(map[string]*testWorkflowHandle), + runningNexusOperations: make(map[int64]*testNexusOperationHandle), callbackChannel: make(chan testCallbackHandle, 1000), testTimeout: 3 * time.Second, expectedWorkflowMockCalls: make(map[string]struct{}), @@ -2121,6 +2140,10 @@ func (env *testWorkflowEnvironmentImpl) RegisterActivityWithOptions(a interface{ env.registry.RegisterActivityWithOptions(a, options) } +func (env *testWorkflowEnvironmentImpl) RegisterNexusService(s *nexus.Service) { + env.registry.RegisterNexusService(s) +} + func (env *testWorkflowEnvironmentImpl) RegisterCancelHandler(handler func()) { env.workflowCancelHandler = handler } @@ -2279,12 +2302,133 @@ func (env *testWorkflowEnvironmentImpl) executeChildWorkflowWithDelay(delayStart go childEnv.executeWorkflowInternal(delayStart, params.WorkflowType.Name, params.Input) } -func (wc *testWorkflowEnvironmentImpl) ExecuteNexusOperation(params executeNexusOperationParams, callback func(*commonpb.Payload, error), startedHandler func(opID string, e error)) int64 { - panic("TODO") +func (env *testWorkflowEnvironmentImpl) newTestNexusTaskHandler() *nexusTaskHandler { + if len(env.registry.nexusServices) == 0 { + panic(fmt.Errorf("no nexus services registered")) + } + + reg := nexus.NewServiceRegistry() + for _, service := range env.registry.nexusServices { + if err := reg.Register(service); err != nil { + panic(fmt.Errorf("failed to register nexus service '%v': %w", service, err)) + } + } + handler, err := reg.NewHandler() + if err != nil { + panic(fmt.Errorf("failed to create nexus handler: %w", err)) + } + + return newNexusTaskHandler( + handler, + env.identity, + env.workflowInfo.Namespace, + env.workflowInfo.TaskQueueName, + &testSuiteClientForNexusOperations{env: env}, + env.dataConverter, + env.logger, + env.metricsHandler, + ) +} + +func (env *testWorkflowEnvironmentImpl) ExecuteNexusOperation(params executeNexusOperationParams, callback func(*commonpb.Payload, error), startedHandler func(opID string, e error)) int64 { + seq := env.nextID() + taskHandler := env.newTestNexusTaskHandler() + handle := &testNexusOperationHandle{ + env: env, + seq: seq, + params: params, + onCompleted: callback, + onStarted: startedHandler, + } + env.runningNexusOperations[seq] = handle + + task := handle.newStartTask() + env.runningCount++ + go func() { + response, failure, err := taskHandler.Execute(task) + if err != nil { + // No retries for operations, fail the operation immediately. + failure = taskHandler.fillInFailure(task.TaskToken, nexusHandlerError(nexus.HandlerErrorTypeInternal, err.Error())) + } + if failure != nil { + err := env.failureConverter.FailureToError(nexusOperationFailure(params, "", &failurepb.Failure{ + Message: failure.GetError().GetFailure().GetMessage(), + FailureInfo: &failurepb.Failure_ApplicationFailureInfo{ + ApplicationFailureInfo: &failurepb.ApplicationFailureInfo{ + NonRetryable: true, + }, + }, + })) + env.postCallback(func() { + handle.startedCallback("", err) + handle.completedCallback(nil, err) + }, true) + return + } else { + switch v := response.GetResponse().GetStartOperation().GetVariant().(type) { + case *nexuspb.StartOperationResponse_SyncSuccess: + env.postCallback(func() { + handle.startedCallback("", nil) + handle.completedCallback(v.SyncSuccess.GetPayload(), nil) + }, true) + case *nexuspb.StartOperationResponse_AsyncSuccess: + env.postCallback(func() { + handle.startedCallback(v.AsyncSuccess.GetOperationId(), nil) + if handle.cancelRequested { + handle.cancel() + } + }, true) + case *nexuspb.StartOperationResponse_OperationError: + err := env.failureConverter.FailureToError( + nexusOperationFailure(params, "", unsuccessfulOperationErrorToTemporalFailure(v.OperationError)), + ) + env.postCallback(func() { + handle.startedCallback("", err) + handle.completedCallback(nil, err) + }, true) + default: + panic(fmt.Errorf("unknown response variant: %v", v)) + } + } + }() + return seq +} + +func (env *testWorkflowEnvironmentImpl) RequestCancelNexusOperation(seq int64) { + handle, ok := env.runningNexusOperations[seq] + if !ok { + panic(fmt.Errorf("no running operation found for sequence: %d", seq)) + } + + // Avoid duplicate cancelation. + if handle.cancelRequested { + return + } + + // Mark this cancelation request in case the operation hasn't started yet. + // Cancel will be called after start. + handle.cancelRequested = true + + // Only cancel after started, we need an operation ID. + if handle.started { + handle.cancel() + } } -func (wc *testWorkflowEnvironmentImpl) RequestCancelNexusOperation(seq int64) { - panic("TODO") +func (env *testWorkflowEnvironmentImpl) resolveNexusOperation(seq int64, result *commonpb.Payload, err error) { + env.postCallback(func() { + handle, ok := env.runningNexusOperations[seq] + if !ok { + panic(fmt.Errorf("no running operation found for sequence: %d", seq)) + } + if err != nil { + failure := env.failureConverter.ErrorToFailure(err) + err = env.failureConverter.FailureToError(nexusOperationFailure(handle.params, handle.operationID, failure.GetCause())) + handle.completedCallback(nil, err) + } else { + handle.completedCallback(result, nil) + } + }, true) } func (env *testWorkflowEnvironmentImpl) SideEffect(f func() (*commonpb.Payloads, error), callback ResultHandler) { @@ -2665,3 +2809,92 @@ func mockFnGetVersion(string, Version, Version) Version { // make sure interface is implemented var _ WorkflowEnvironment = (*testWorkflowEnvironmentImpl)(nil) + +func (h *testNexusOperationHandle) newStartTask() *workflowservice.PollNexusTaskQueueResponse { + return &workflowservice.PollNexusTaskQueueResponse{ + TaskToken: []byte{}, + Request: &nexuspb.Request{ + ScheduledTime: timestamppb.Now(), + Header: h.params.nexusHeader, + Variant: &nexuspb.Request_StartOperation{ + StartOperation: &nexuspb.StartOperationRequest{ + Service: h.params.client.Service(), + Operation: h.params.operation, + RequestId: uuid.NewString(), + // This is effectively ignored. + Callback: "http://test-env/operations", + CallbackHeader: map[string]string{ + // The test client uses this to call resolveNexusOperation. + "operation-sequence": strconv.FormatInt(h.seq, 10), + }, + Payload: h.params.input, + }, + }, + }, + } +} + +func (h *testNexusOperationHandle) newCancelTask() *workflowservice.PollNexusTaskQueueResponse { + return &workflowservice.PollNexusTaskQueueResponse{ + TaskToken: []byte{}, + Request: &nexuspb.Request{ + ScheduledTime: timestamppb.Now(), + Header: h.params.nexusHeader, + Variant: &nexuspb.Request_CancelOperation{ + CancelOperation: &nexuspb.CancelOperationRequest{ + Service: h.params.client.Service(), + Operation: h.params.operation, + OperationId: h.operationID, + }, + }, + }, + } +} + +// completedCallback is a callback registered to handle operation completion. +// Must be called in a postCallback block. +func (h *testNexusOperationHandle) completedCallback(result *commonpb.Payload, err error) { + if h.done { + // Ignore duplicate completions. + return + } + h.done = true + delete(h.env.runningNexusOperations, h.seq) + h.onCompleted(result, err) +} + +// startedCallback is a callback registered to handle operation start. +// Must be called in a postCallback block. +func (h *testNexusOperationHandle) startedCallback(opID string, e error) { + h.operationID = opID + h.started = true + h.onStarted(opID, e) + h.env.runningCount-- +} + +func (h *testNexusOperationHandle) cancel() { + if h.done { + return + } + if h.started && h.operationID == "" { + panic(fmt.Errorf("incomplete operation has no operation ID: (%s, %s, %s)", + h.params.client.Endpoint(), h.params.client.Service(), h.params.operation)) + } + h.env.runningCount++ + task := h.newCancelTask() + taskHandler := h.env.newTestNexusTaskHandler() + + go func() { + _, failure, err := taskHandler.Execute(task) + h.env.postCallback(func() { + if err != nil { + // No retries in the test env, fail the operation immediately. + h.completedCallback(nil, fmt.Errorf("operation cancelation handler failed: %w", err)) + } else if failure != nil { + // No retries in the test env, fail the operation immediately. + h.completedCallback(nil, fmt.Errorf("operation cancelation handler failed: %v", failure.GetError().GetFailure().GetMessage())) + } + h.env.runningCount-- + }, false) + }() +} diff --git a/internal/nexus_operations.go b/internal/nexus_operations.go index 021e24480..684016ebb 100644 --- a/internal/nexus_operations.go +++ b/internal/nexus_operations.go @@ -2,7 +2,17 @@ package internal import ( "context" + "fmt" + "strconv" + "github.com/nexus-rpc/sdk-go/nexus" + commonpb "go.temporal.io/api/common/v1" + "go.temporal.io/api/enums/v1" + failurepb "go.temporal.io/api/failure/v1" + nexuspb "go.temporal.io/api/nexus/v1" + "go.temporal.io/api/operatorservice/v1" + "go.temporal.io/api/workflowservice/v1" + "go.temporal.io/sdk/converter" "go.temporal.io/sdk/log" ) @@ -13,11 +23,365 @@ type NexusOperationContext struct { Log log.Logger } +type nexusOperationContextKeyType struct{} + // nexusOperationContextKey is a key for associating a [NexusOperationContext] with a [context.Context]. -var nexusOperationContextKey = struct{}{} +var nexusOperationContextKey = nexusOperationContextKeyType{} + +type isWorkflowRunOpContextKeyType struct{} + +// IsWorkflowRunOpContextKey is a key to mark that the current context is used within a workflow run operation. +// The fake test env client verifies this key is set on the context to decide whether it should execute a method or +// panic as we don't want to expose a partial client to sync operations. +var IsWorkflowRunOpContextKey = isWorkflowRunOpContextKeyType{} // NexusOperationContextFromGoContext gets the [NexusOperationContext] associated with the given [context.Context]. func NexusOperationContextFromGoContext(ctx context.Context) (nctx *NexusOperationContext, ok bool) { nctx, ok = ctx.Value(nexusOperationContextKey).(*NexusOperationContext) return } + +// nexusOperationFailure is a utility in use by the test environment. +func nexusOperationFailure(params executeNexusOperationParams, operationID string, cause *failurepb.Failure) *failurepb.Failure { + return &failurepb.Failure{ + Message: "nexus operation completed unsuccessfully", + FailureInfo: &failurepb.Failure_NexusOperationExecutionFailureInfo{ + NexusOperationExecutionFailureInfo: &failurepb.NexusOperationFailureInfo{ + Endpoint: params.client.Endpoint(), + Service: params.client.Service(), + Operation: params.operation, + OperationId: operationID, + }, + }, + Cause: cause, + } +} + +// unsuccessfulOperationErrorToTemporalFailure is a utility in use by the test environment. +// copied from the server codebase with a slight adaptation: https://github.com/temporalio/temporal/blob/7635cd7dbdc7dd3219f387e8fc66fa117f585ff6/common/nexus/failure.go#L69-L108 +func unsuccessfulOperationErrorToTemporalFailure(err *nexuspb.UnsuccessfulOperationError) *failurepb.Failure { + failure := &failurepb.Failure{ + Message: err.Failure.Message, + } + if err.OperationState == string(nexus.OperationStateCanceled) { + failure.FailureInfo = &failurepb.Failure_CanceledFailureInfo{ + CanceledFailureInfo: &failurepb.CanceledFailureInfo{ + Details: nexusFailureMetadataToPayloads(err.Failure), + }, + } + } else { + failure.FailureInfo = &failurepb.Failure_ApplicationFailureInfo{ + ApplicationFailureInfo: &failurepb.ApplicationFailureInfo{ + // Make up a type here, it's not part of the Nexus Failure spec. + Type: "NexusOperationFailure", + Details: nexusFailureMetadataToPayloads(err.Failure), + NonRetryable: true, + }, + } + } + return failure +} + +// nexusFailureMetadataToPayloads is a utility in use by the test environment. +// copied from the server codebase with a slight adaptation: https://github.com/temporalio/temporal/blob/7635cd7dbdc7dd3219f387e8fc66fa117f585ff6/common/nexus/failure.go#L69-L108 +func nexusFailureMetadataToPayloads(failure *nexuspb.Failure) *commonpb.Payloads { + if len(failure.Metadata) == 0 && len(failure.Details) == 0 { + return nil + } + metadata := make(map[string][]byte, len(failure.Metadata)) + for k, v := range failure.Metadata { + metadata[k] = []byte(v) + } + return &commonpb.Payloads{ + Payloads: []*commonpb.Payload{ + { + Metadata: metadata, + Data: failure.Details, + }, + }, + } +} + +// testSuiteClientForNexusOperations is a partial [Client] implementation for the test workflow environment used to +// support running the workflow run operation - and only this operation, all methods will panic when this client is +// passed to sync operations. +type testSuiteClientForNexusOperations struct { + env *testWorkflowEnvironmentImpl +} + +// CancelWorkflow implements Client. +func (t *testSuiteClientForNexusOperations) CancelWorkflow(ctx context.Context, workflowID string, runID string) error { + if set, ok := ctx.Value(IsWorkflowRunOpContextKey).(bool); !ok || !set { + panic("not implemented in the test environment") + } + doneCh := make(chan error) + t.env.cancelWorkflowByID(workflowID, runID, func(result *commonpb.Payloads, err error) { + doneCh <- err + }) + return <-doneCh +} + +// CheckHealth implements Client. +func (t *testSuiteClientForNexusOperations) CheckHealth(ctx context.Context, request *CheckHealthRequest) (*CheckHealthResponse, error) { + return &CheckHealthResponse{}, nil +} + +// Close implements Client. +func (t *testSuiteClientForNexusOperations) Close() { + // No op. +} + +// CompleteActivity implements Client. +func (t *testSuiteClientForNexusOperations) CompleteActivity(ctx context.Context, taskToken []byte, result interface{}, err error) error { + panic("not implemented in the test environment") +} + +// CompleteActivityByID implements Client. +func (t *testSuiteClientForNexusOperations) CompleteActivityByID(ctx context.Context, namespace string, workflowID string, runID string, activityID string, result interface{}, err error) error { + panic("not implemented in the test environment") +} + +// CountWorkflow implements Client. +func (t *testSuiteClientForNexusOperations) CountWorkflow(ctx context.Context, request *workflowservice.CountWorkflowExecutionsRequest) (*workflowservice.CountWorkflowExecutionsResponse, error) { + panic("not implemented in the test environment") +} + +// DescribeTaskQueue implements Client. +func (t *testSuiteClientForNexusOperations) DescribeTaskQueue(ctx context.Context, taskqueue string, taskqueueType enums.TaskQueueType) (*workflowservice.DescribeTaskQueueResponse, error) { + panic("not implemented in the test environment") +} + +// DescribeWorkflowExecution implements Client. +func (t *testSuiteClientForNexusOperations) DescribeWorkflowExecution(ctx context.Context, workflowID string, runID string) (*workflowservice.DescribeWorkflowExecutionResponse, error) { + panic("not implemented in the test environment") +} + +// ExecuteWorkflow implements Client. +func (t *testSuiteClientForNexusOperations) ExecuteWorkflow(ctx context.Context, options StartWorkflowOptions, workflow interface{}, args ...interface{}) (WorkflowRun, error) { + if set, ok := ctx.Value(IsWorkflowRunOpContextKey).(bool); !ok || !set { + panic("not implemented in the test environment") + } + wfType, input, err := getValidatedWorkflowFunction(workflow, args, t.env.dataConverter, t.env.GetRegistry()) + if err != nil { + return nil, fmt.Errorf("cannot validate workflow function: %w", err) + } + + run := &testEnvWorkflowRunForNexusOperations{} + doneCh := make(chan error) + + var callback *commonpb.Callback + + if len(options.callbacks) > 0 { + callback = options.callbacks[0] + } + + t.env.executeChildWorkflowWithDelay(options.StartDelay, ExecuteWorkflowParams{ + // Not propagating Header as this client does not support context propagation. + WorkflowType: wfType, + Input: input, + WorkflowOptions: WorkflowOptions{ + WaitForCancellation: true, + Namespace: t.env.workflowInfo.Namespace, + TaskQueueName: t.env.workflowInfo.TaskQueueName, + WorkflowID: options.ID, + WorkflowExecutionTimeout: options.WorkflowExecutionTimeout, + WorkflowRunTimeout: options.WorkflowRunTimeout, + WorkflowTaskTimeout: options.WorkflowTaskTimeout, + DataConverter: t.env.dataConverter, + WorkflowIDReusePolicy: options.WorkflowIDReusePolicy, + ContextPropagators: t.env.contextPropagators, + SearchAttributes: options.SearchAttributes, + TypedSearchAttributes: options.TypedSearchAttributes, + ParentClosePolicy: enums.PARENT_CLOSE_POLICY_ABANDON, + Memo: options.Memo, + CronSchedule: options.CronSchedule, + RetryPolicy: convertToPBRetryPolicy(options.RetryPolicy), + }, + }, func(result *commonpb.Payloads, wfErr error) { + ncb := callback.GetNexus() + if ncb == nil { + return + } + seqStr := ncb.GetHeader()["operation-sequence"] + if seqStr == "" { + return + } + seq, err := strconv.ParseInt(seqStr, 10, 64) + if err != nil { + panic(fmt.Errorf("unexpected operation sequence in callback header: %s: %w", seqStr, err)) + } + + if wfErr != nil { + t.env.resolveNexusOperation(seq, nil, wfErr) + } else { + var payload *commonpb.Payload + if len(result.GetPayloads()) > 0 { + payload = result.Payloads[0] + } + t.env.resolveNexusOperation(seq, payload, nil) + } + }, func(r WorkflowExecution, err error) { + run.WorkflowExecution = r + doneCh <- err + }) + err = <-doneCh + if err != nil { + return nil, err + } + return run, nil +} + +// GetSearchAttributes implements Client. +func (t *testSuiteClientForNexusOperations) GetSearchAttributes(ctx context.Context) (*workflowservice.GetSearchAttributesResponse, error) { + panic("not implemented in the test environment") +} + +// GetWorkerBuildIdCompatibility implements Client. +func (t *testSuiteClientForNexusOperations) GetWorkerBuildIdCompatibility(ctx context.Context, options *GetWorkerBuildIdCompatibilityOptions) (*WorkerBuildIDVersionSets, error) { + panic("not implemented in the test environment") +} + +// GetWorkerTaskReachability implements Client. +func (t *testSuiteClientForNexusOperations) GetWorkerTaskReachability(ctx context.Context, options *GetWorkerTaskReachabilityOptions) (*WorkerTaskReachability, error) { + panic("not implemented in the test environment") +} + +// GetWorkflow implements Client. +func (t *testSuiteClientForNexusOperations) GetWorkflow(ctx context.Context, workflowID string, runID string) WorkflowRun { + panic("not implemented in the test environment") +} + +// GetWorkflowHistory implements Client. +func (t *testSuiteClientForNexusOperations) GetWorkflowHistory(ctx context.Context, workflowID string, runID string, isLongPoll bool, filterType enums.HistoryEventFilterType) HistoryEventIterator { + panic("not implemented in the test environment") +} + +// GetWorkflowUpdateHandle implements Client. +func (t *testSuiteClientForNexusOperations) GetWorkflowUpdateHandle(GetWorkflowUpdateHandleOptions) WorkflowUpdateHandle { + panic("not implemented in the test environment") +} + +// ListArchivedWorkflow implements Client. +func (t *testSuiteClientForNexusOperations) ListArchivedWorkflow(ctx context.Context, request *workflowservice.ListArchivedWorkflowExecutionsRequest) (*workflowservice.ListArchivedWorkflowExecutionsResponse, error) { + panic("not implemented in the test environment") +} + +// ListClosedWorkflow implements Client. +func (t *testSuiteClientForNexusOperations) ListClosedWorkflow(ctx context.Context, request *workflowservice.ListClosedWorkflowExecutionsRequest) (*workflowservice.ListClosedWorkflowExecutionsResponse, error) { + panic("not implemented in the test environment") +} + +// ListOpenWorkflow implements Client. +func (t *testSuiteClientForNexusOperations) ListOpenWorkflow(ctx context.Context, request *workflowservice.ListOpenWorkflowExecutionsRequest) (*workflowservice.ListOpenWorkflowExecutionsResponse, error) { + panic("not implemented in the test environment") +} + +// ListWorkflow implements Client. +func (t *testSuiteClientForNexusOperations) ListWorkflow(ctx context.Context, request *workflowservice.ListWorkflowExecutionsRequest) (*workflowservice.ListWorkflowExecutionsResponse, error) { + panic("not implemented in the test environment") +} + +// OperatorService implements Client. +func (t *testSuiteClientForNexusOperations) OperatorService() operatorservice.OperatorServiceClient { + panic("not implemented in the test environment") +} + +// QueryWorkflow implements Client. +func (t *testSuiteClientForNexusOperations) QueryWorkflow(ctx context.Context, workflowID string, runID string, queryType string, args ...interface{}) (converter.EncodedValue, error) { + panic("not implemented in the test environment") +} + +// QueryWorkflowWithOptions implements Client. +func (t *testSuiteClientForNexusOperations) QueryWorkflowWithOptions(ctx context.Context, request *QueryWorkflowWithOptionsRequest) (*QueryWorkflowWithOptionsResponse, error) { + panic("not implemented in the test environment") +} + +// RecordActivityHeartbeat implements Client. +func (t *testSuiteClientForNexusOperations) RecordActivityHeartbeat(ctx context.Context, taskToken []byte, details ...interface{}) error { + panic("not implemented in the test environment") +} + +// RecordActivityHeartbeatByID implements Client. +func (t *testSuiteClientForNexusOperations) RecordActivityHeartbeatByID(ctx context.Context, namespace string, workflowID string, runID string, activityID string, details ...interface{}) error { + panic("not implemented in the test environment") +} + +// ResetWorkflowExecution implements Client. +func (t *testSuiteClientForNexusOperations) ResetWorkflowExecution(ctx context.Context, request *workflowservice.ResetWorkflowExecutionRequest) (*workflowservice.ResetWorkflowExecutionResponse, error) { + panic("not implemented in the test environment") +} + +// ScanWorkflow implements Client. +func (t *testSuiteClientForNexusOperations) ScanWorkflow(ctx context.Context, request *workflowservice.ScanWorkflowExecutionsRequest) (*workflowservice.ScanWorkflowExecutionsResponse, error) { + panic("not implemented in the test environment") +} + +// ScheduleClient implements Client. +func (t *testSuiteClientForNexusOperations) ScheduleClient() ScheduleClient { + panic("not implemented in the test environment") +} + +// SignalWithStartWorkflow implements Client. +func (t *testSuiteClientForNexusOperations) SignalWithStartWorkflow(ctx context.Context, workflowID string, signalName string, signalArg interface{}, options StartWorkflowOptions, workflow interface{}, workflowArgs ...interface{}) (WorkflowRun, error) { + panic("not implemented in the test environment") +} + +// SignalWorkflow implements Client. +func (t *testSuiteClientForNexusOperations) SignalWorkflow(ctx context.Context, workflowID string, runID string, signalName string, arg interface{}) error { + panic("not implemented in the test environment") +} + +// TerminateWorkflow implements Client. +func (t *testSuiteClientForNexusOperations) TerminateWorkflow(ctx context.Context, workflowID string, runID string, reason string, details ...interface{}) error { + panic("not implemented in the test environment") +} + +// UpdateWorkerBuildIdCompatibility implements Client. +func (t *testSuiteClientForNexusOperations) UpdateWorkerBuildIdCompatibility(ctx context.Context, options *UpdateWorkerBuildIdCompatibilityOptions) error { + panic("not implemented in the test environment") +} + +// UpdateWorkflow implements Client. +func (t *testSuiteClientForNexusOperations) UpdateWorkflow(ctx context.Context, workflowID string, workflowRunID string, updateName string, args ...interface{}) (WorkflowUpdateHandle, error) { + panic("not implemented in the test environment") +} + +// UpdateWorkflowWithOptions implements Client. +func (t *testSuiteClientForNexusOperations) UpdateWorkflowWithOptions(ctx context.Context, request *UpdateWorkflowWithOptionsRequest) (WorkflowUpdateHandle, error) { + panic("not implemented in the test environment") +} + +// WorkflowService implements Client. +func (t *testSuiteClientForNexusOperations) WorkflowService() workflowservice.WorkflowServiceClient { + panic("not implemented in the test environment") +} + +var _ Client = &testSuiteClientForNexusOperations{} + +// testEnvWorkflowRunForNexusOperations is a partial [WorkflowRun] implementation for the test workflow environment used +// to support basic Nexus functionality. +type testEnvWorkflowRunForNexusOperations struct { + WorkflowExecution +} + +// Get implements WorkflowRun. +func (t *testEnvWorkflowRunForNexusOperations) Get(ctx context.Context, valuePtr interface{}) error { + panic("not implemented in the test environment") +} + +// GetID implements WorkflowRun. +func (t *testEnvWorkflowRunForNexusOperations) GetID() string { + return t.ID +} + +// GetRunID implements WorkflowRun. +func (t *testEnvWorkflowRunForNexusOperations) GetRunID() string { + return t.RunID +} + +// GetWithOptions implements WorkflowRun. +func (t *testEnvWorkflowRunForNexusOperations) GetWithOptions(ctx context.Context, valuePtr interface{}, options WorkflowRunGetOptions) error { + panic("not implemented in the test environment") +} + +var _ WorkflowRun = &testEnvWorkflowRunForNexusOperations{} diff --git a/internal/workflow_testsuite.go b/internal/workflow_testsuite.go index 811290c56..6ac2dd4ac 100644 --- a/internal/workflow_testsuite.go +++ b/internal/workflow_testsuite.go @@ -32,6 +32,7 @@ import ( "testing" "time" + "github.com/nexus-rpc/sdk-go/nexus" "github.com/stretchr/testify/mock" commonpb "go.temporal.io/api/common/v1" enumspb "go.temporal.io/api/enums/v1" @@ -290,6 +291,11 @@ func (e *TestWorkflowEnvironment) RegisterActivityWithOptions(a interface{}, opt e.impl.RegisterActivityWithOptions(a, options) } +// RegisterWorkflow registers a Nexus Service with the TestWorkflowEnvironment. +func (e *TestWorkflowEnvironment) RegisterNexusService(s *nexus.Service) { + e.impl.RegisterNexusService(s) +} + // SetStartTime sets the start time of the workflow. This is optional, default start time will be the wall clock time when // workflow starts. Start time is the workflow.Now(ctx) time at the beginning of the workflow. func (e *TestWorkflowEnvironment) SetStartTime(startTime time.Time) { diff --git a/temporalnexus/operation.go b/temporalnexus/operation.go index 059e60e1b..4e0171012 100644 --- a/temporalnexus/operation.go +++ b/temporalnexus/operation.go @@ -145,6 +145,9 @@ func MustNewWorkflowRunOperationWithOptions[I, O any](options WorkflowRunOperati } func (*workflowRunOperation[I, O]) Cancel(ctx context.Context, id string, options nexus.CancelOperationOptions) error { + // Prevent the test env client from panicking when we try to use it from a workflow run operation. + ctx = context.WithValue(ctx, internal.IsWorkflowRunOpContextKey, true) + nctx, ok := internal.NexusOperationContextFromGoContext(ctx) if !ok { return nexus.HandlerErrorf(nexus.HandlerErrorTypeInternal, "internal error") @@ -157,6 +160,9 @@ func (o *workflowRunOperation[I, O]) Name() string { } func (o *workflowRunOperation[I, O]) Start(ctx context.Context, input I, options nexus.StartOperationOptions) (nexus.HandlerStartOperationResult[O], error) { + // Prevent the test env client from panicking when we try to use it from a workflow run operation. + ctx = context.WithValue(ctx, internal.IsWorkflowRunOpContextKey, true) + if o.options.Handler != nil { handle, err := o.options.Handler(ctx, input, options) if err != nil { diff --git a/test/nexus_test.go b/test/nexus_test.go index 816b045eb..3955feec3 100644 --- a/test/nexus_test.go +++ b/test/nexus_test.go @@ -24,6 +24,7 @@ package test_test import ( "context" + "errors" "fmt" "net/http" "slices" @@ -44,6 +45,7 @@ import ( ilog "go.temporal.io/sdk/internal/log" "go.temporal.io/sdk/temporal" "go.temporal.io/sdk/temporalnexus" + "go.temporal.io/sdk/testsuite" "go.temporal.io/sdk/worker" "go.temporal.io/sdk/workflow" ) @@ -657,3 +659,251 @@ func TestReplay(t *testing.T) { require.ErrorContains(t, err, "[TMPRL1100]") }) } + +func TestWorkflowTestSuite_NexusSyncOperation(t *testing.T) { + op := nexus.NewSyncOperation("op", func(ctx context.Context, outcome string, opts nexus.StartOperationOptions) (string, error) { + switch outcome { + case "ok": + return outcome, nil + case "failure": + return "", &nexus.UnsuccessfulOperationError{ + State: nexus.OperationStateFailed, + Failure: nexus.Failure{ + Message: "test operation failed", + }, + } + case "handler-error": + return "", &nexus.HandlerError{ + Type: nexus.HandlerErrorTypeBadRequest, + Failure: &nexus.Failure{ + Message: "test operation failed", + }, + } + } + panic(fmt.Errorf("invalid outcome: %q", outcome)) + }) + wf := func(ctx workflow.Context, outcome string) error { + client := workflow.NewNexusClient("endpoint", "test") + fut := client.ExecuteOperation(ctx, op, outcome, workflow.NexusOperationOptions{}) + var exec workflow.NexusOperationExecution + if err := fut.GetNexusOperationExecution().Get(ctx, &exec); err != nil { + return err + } + var res string + if err := fut.Get(ctx, &res); err != nil { + return err + } + if res != "ok" { + return fmt.Errorf("unexpected result: %v", res) + } + return nil + } + + service := nexus.NewService("test") + service.Register(op) + + t.Run("ok", func(t *testing.T) { + suite := testsuite.WorkflowTestSuite{} + env := suite.NewTestWorkflowEnvironment() + env.RegisterNexusService(service) + env.ExecuteWorkflow(wf, "ok") + require.True(t, env.IsWorkflowCompleted()) + require.NoError(t, env.GetWorkflowError()) + }) + + for _, outcome := range []string{"failure", "handler-error"} { + outcome := outcome // capture just in case. + t.Run(outcome, func(t *testing.T) { + suite := testsuite.WorkflowTestSuite{} + env := suite.NewTestWorkflowEnvironment() + env.RegisterNexusService(service) + env.ExecuteWorkflow(wf, "failure") + require.True(t, env.IsWorkflowCompleted()) + var execErr *temporal.WorkflowExecutionError + err := env.GetWorkflowError() + require.ErrorAs(t, err, &execErr) + var opErr *temporal.NexusOperationError + err = execErr.Unwrap() + require.ErrorAs(t, err, &opErr) + require.Equal(t, "endpoint", opErr.Endpoint) + require.Equal(t, "test", opErr.Service) + require.Equal(t, op.Name(), opErr.Operation) + require.Empty(t, opErr.OperationID) + require.Equal(t, "nexus operation completed unsuccessfully", opErr.Message) + err = opErr.Unwrap() + var appErr *temporal.ApplicationError + require.ErrorAs(t, err, &appErr) + require.Equal(t, "test operation failed", appErr.Message()) + }) + } +} + +func TestWorkflowTestSuite_WorkflowRunOperation(t *testing.T) { + handlerWF := func(ctx workflow.Context, outcome string) (string, error) { + if outcome == "ok" { + return "ok", nil + } + return "", fmt.Errorf("expected failure") + } + + op := temporalnexus.NewWorkflowRunOperation( + "op", + handlerWF, + func(ctx context.Context, id string, opts nexus.StartOperationOptions) (client.StartWorkflowOptions, error) { + return client.StartWorkflowOptions{ID: opts.RequestID}, nil + }) + + callerWF := func(ctx workflow.Context, outcome string) error { + client := workflow.NewNexusClient("endpoint", "test") + fut := client.ExecuteOperation(ctx, op, outcome, workflow.NexusOperationOptions{}) + var exec workflow.NexusOperationExecution + if err := fut.GetNexusOperationExecution().Get(ctx, &exec); err != nil { + return err + } + if exec.OperationID == "" { + return errors.New("got empty operation ID") + } + + var result string + if err := fut.Get(ctx, &result); err != nil { + return err + } + if result != "ok" { + return fmt.Errorf("expected result to be 'ok', got: %s", result) + } + return nil + } + + service := nexus.NewService("test") + service.Register(op) + + t.Run("ok", func(t *testing.T) { + suite := testsuite.WorkflowTestSuite{} + env := suite.NewTestWorkflowEnvironment() + env.RegisterWorkflow(handlerWF) + env.RegisterNexusService(service) + + env.ExecuteWorkflow(callerWF, "ok") + require.True(t, env.IsWorkflowCompleted()) + require.NoError(t, env.GetWorkflowError()) + }) + + t.Run("fail", func(t *testing.T) { + suite := testsuite.WorkflowTestSuite{} + env := suite.NewTestWorkflowEnvironment() + env.RegisterWorkflow(handlerWF) + env.RegisterNexusService(service) + + env.ExecuteWorkflow(callerWF, "fail") + require.True(t, env.IsWorkflowCompleted()) + + var execErr *temporal.WorkflowExecutionError + err := env.GetWorkflowError() + require.ErrorAs(t, err, &execErr) + var opErr *temporal.NexusOperationError + err = execErr.Unwrap() + require.ErrorAs(t, err, &opErr) + require.Equal(t, "endpoint", opErr.Endpoint) + require.Equal(t, "test", opErr.Service) + require.Equal(t, op.Name(), opErr.Operation) + require.Empty(t, opErr.OperationID) + require.Equal(t, "nexus operation completed unsuccessfully", opErr.Message) + err = opErr.Unwrap() + var appErr *temporal.ApplicationError + require.ErrorAs(t, err, &appErr) + require.Equal(t, "expected failure", appErr.Message()) + }) +} + +func TestWorkflowTestSuite_WorkflowRunOperation_WithCancel(t *testing.T) { + wf := func(ctx workflow.Context, cancelBeforeStarted bool) error { + childCtx, cancel := workflow.WithCancel(ctx) + defer cancel() + + client := workflow.NewNexusClient("endpoint", "test") + fut := client.ExecuteOperation(childCtx, workflowOp, "op-id", workflow.NexusOperationOptions{}) + if cancelBeforeStarted { + cancel() + } + var exec workflow.NexusOperationExecution + if err := fut.GetNexusOperationExecution().Get(ctx, &exec); err != nil { + return err + } + if exec.OperationID != "op-id" { + return fmt.Errorf("unexpected operation ID: %q", exec.OperationID) + } + + if !cancelBeforeStarted { + cancel() + } + err := fut.Get(ctx, nil) + return err + } + + service := nexus.NewService("test") + service.Register(workflowOp) + + cases := []struct { + cancelBeforeStarted bool + name string + }{ + {false, "AfterStarted"}, + {true, "BeforeStarted"}, + } + for _, tc := range cases { + tc := tc // capture just in case. + t.Run(tc.name, func(t *testing.T) { + suite := testsuite.WorkflowTestSuite{} + env := suite.NewTestWorkflowEnvironment() + env.RegisterWorkflow(waitForCancelWorkflow) + env.RegisterNexusService(service) + env.ExecuteWorkflow(wf, tc.cancelBeforeStarted) + require.True(t, env.IsWorkflowCompleted()) + // Error wrapping is different in the test environment than the server (same as for child workflows). + var execErr *temporal.WorkflowExecutionError + err := env.GetWorkflowError() + require.ErrorAs(t, err, &execErr) + var opErr *temporal.NexusOperationError + err = execErr.Unwrap() + require.ErrorAs(t, err, &opErr) + require.Equal(t, "endpoint", opErr.Endpoint) + require.Equal(t, "test", opErr.Service) + require.Equal(t, workflowOp.Name(), opErr.Operation) + require.Equal(t, "op-id", opErr.OperationID) + require.Equal(t, "nexus operation completed unsuccessfully", opErr.Message) + err = opErr.Unwrap() + var canceledError *temporal.CanceledError + require.ErrorAs(t, err, &canceledError) + }) + } +} + +func TestWorkflowTestSuite_NexusSyncOperation_ClientMethods_Panic(t *testing.T) { + var panicReason any + op := temporalnexus.NewSyncOperation("signal-op", func(ctx context.Context, c client.Client, _ nexus.NoValue, opts nexus.StartOperationOptions) (nexus.NoValue, error) { + func() { + defer func() { + panicReason = recover() + }() + c.ExecuteWorkflow(ctx, client.StartWorkflowOptions{}, "test", "", "get-secret") + }() + return nil, nil + }) + wf := func(ctx workflow.Context) error { + client := workflow.NewNexusClient("endpoint", "test") + fut := client.ExecuteOperation(ctx, op, nil, workflow.NexusOperationOptions{}) + return fut.Get(ctx, nil) + } + + service := nexus.NewService("test") + service.Register(op) + + suite := testsuite.WorkflowTestSuite{} + env := suite.NewTestWorkflowEnvironment() + env.RegisterWorkflow(waitForCancelWorkflow) + env.RegisterNexusService(service) + env.ExecuteWorkflow(wf) + require.True(t, env.IsWorkflowCompleted()) + require.NoError(t, env.GetWorkflowError()) + require.Equal(t, "not implemented in the test environment", panicReason) +}