Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds a stream interceptor to keep state between the send and receive calls #437

Merged
merged 1 commit into from
May 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions runner/data.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,16 @@ type StreamMessageProviderFunc func(*CallData) (*dynamic.Message, error)
// Clients can return ErrEndStream to end the call early
type StreamRecvMsgInterceptFunc func(*dynamic.Message, error) error

// StreamInterceptorProviderFunc is an interface for a function invoked to generate a stream interceptor
type StreamInterceptorProviderFunc func(*CallData) StreamInterceptor

// StreamInterceptor is an interface for sending and receiving stream messages.
// The interceptor can keep shared state for the send and receive calls.
type StreamInterceptor interface {
Recv(*dynamic.Message, error) error
Send(*CallData) (*dynamic.Message, error)
}

type dataProvider struct {
binary bool
data []byte
Expand Down
22 changes: 16 additions & 6 deletions runner/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -129,12 +129,13 @@ type RunConfig struct {
disableTemplateData bool

// misc
name string
cpus int
tags []byte
skipFirst int
countErrors bool
recvMsgFunc StreamRecvMsgInterceptFunc
name string
cpus int
tags []byte
skipFirst int
countErrors bool
recvMsgFunc StreamRecvMsgInterceptFunc
streamInterceptorProviderFunc StreamInterceptorProviderFunc
}

// Option controls some aspect of run
Expand Down Expand Up @@ -1034,6 +1035,15 @@ func WithStreamRecvMsgIntercept(fn StreamRecvMsgInterceptFunc) Option {
}
}

// WithStreamInterceptor specifies the stream interceptor provider function
func WithStreamInterceptorProviderFunc(interceptor StreamInterceptorProviderFunc) Option {
return func(o *RunConfig) error {
o.streamInterceptorProviderFunc = interceptor

return nil
}
}

// WithDataProvider provides custom data provider
//
// WithDataProvider(func(*CallData) ([]*dynamic.Message, error) {
Expand Down
23 changes: 12 additions & 11 deletions runner/requester.go
Original file line number Diff line number Diff line change
Expand Up @@ -389,17 +389,18 @@ func (b *Requester) runWorkers(wt load.WorkerTicker, p load.Pacer) error {
}

w := Worker{
ticks: ticks,
active: true,
stub: b.stubs[n],
mtd: b.mtd,
config: b.config,
stopCh: make(chan bool),
workerID: wID,
dataProvider: b.dataProvider,
metadataProvider: b.metadataProvider,
streamRecv: b.config.recvMsgFunc,
msgProvider: b.config.dataStreamFunc,
ticks: ticks,
active: true,
stub: b.stubs[n],
mtd: b.mtd,
config: b.config,
stopCh: make(chan bool),
workerID: wID,
dataProvider: b.dataProvider,
metadataProvider: b.metadataProvider,
streamRecv: b.config.recvMsgFunc,
msgProvider: b.config.dataStreamFunc,
streamInterceptorProviderFunc: b.config.streamInterceptorProviderFunc,
}

wc++ // increment worker id
Expand Down
45 changes: 40 additions & 5 deletions runner/worker.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,8 @@ type Worker struct {
metadataProvider MetadataProviderFunc
msgProvider StreamMessageProviderFunc

streamRecv StreamRecvMsgInterceptFunc
streamRecv StreamRecvMsgInterceptFunc
streamInterceptorProviderFunc StreamInterceptorProviderFunc
}

func (w *Worker) runWorker() error {
Expand Down Expand Up @@ -83,6 +84,13 @@ func (w *Worker) makeRequest(tv TickValue) error {

ctd := newCallData(w.mtd, w.workerID, reqNum, !w.config.disableTemplateFuncs, !w.config.disableTemplateData, w.config.funcs)

var streamInterceptor StreamInterceptor
if w.mtd.IsClientStreaming() || w.mtd.IsServerStreaming() {
if w.streamInterceptorProviderFunc != nil {
streamInterceptor = w.streamInterceptorProviderFunc(ctd)
}
}

reqMD, err := w.metadataProvider(ctd)
if err != nil {
return err
Expand Down Expand Up @@ -115,6 +123,8 @@ func (w *Worker) makeRequest(tv TickValue) error {
var msgProvider StreamMessageProviderFunc
if w.msgProvider != nil {
msgProvider = w.msgProvider
} else if streamInterceptor != nil {
msgProvider = streamInterceptor.Send
} else if w.mtd.IsClientStreaming() {
if w.config.streamDynamicMessages {
mp, err := newDynamicMessageProvider(w.mtd, w.config.data, w.config.streamCallCount, !w.config.disableTemplateFuncs, !w.config.disableTemplateData)
Expand Down Expand Up @@ -155,11 +165,11 @@ func (w *Worker) makeRequest(tv TickValue) error {

// RPC errors are handled via stats handler
if w.mtd.IsClientStreaming() && w.mtd.IsServerStreaming() {
_ = w.makeBidiRequest(&ctx, ctd, msgProvider)
_ = w.makeBidiRequest(&ctx, ctd, msgProvider, streamInterceptor)
} else if w.mtd.IsClientStreaming() {
_ = w.makeClientStreamingRequest(&ctx, ctd, msgProvider)
} else if w.mtd.IsServerStreaming() {
_ = w.makeServerStreamingRequest(&ctx, inputs[0])
_ = w.makeServerStreamingRequest(&ctx, inputs[0], streamInterceptor)
} else {
_ = w.makeUnaryRequest(&ctx, reqMD, inputs[0])
}
Expand Down Expand Up @@ -314,7 +324,7 @@ func (w *Worker) makeClientStreamingRequest(ctx *context.Context,
return nil
}

func (w *Worker) makeServerStreamingRequest(ctx *context.Context, input *dynamic.Message) error {
func (w *Worker) makeServerStreamingRequest(ctx *context.Context, input *dynamic.Message, streamInterceptor StreamInterceptor) error {
var callOptions = []grpc.CallOption{}
if w.config.enableCompression {
callOptions = append(callOptions, grpc.UseCompressor(gzip.Name))
Expand Down Expand Up @@ -388,6 +398,18 @@ func (w *Worker) makeServerStreamingRequest(ctx *context.Context, input *dynamic
}
}

if streamInterceptor != nil {
if converted, ok := res.(*dynamic.Message); ok {
err = streamInterceptor.Recv(converted, err)
if errors.Is(err, ErrEndStream) && !interceptCanceled {
interceptCanceled = true
err = nil

callCancel()
}
}
}

if err != nil {
if err == io.EOF {
err = nil
Expand Down Expand Up @@ -415,7 +437,7 @@ func (w *Worker) makeServerStreamingRequest(ctx *context.Context, input *dynamic
}

func (w *Worker) makeBidiRequest(ctx *context.Context,
ctd *CallData, messageProvider StreamMessageProviderFunc) error {
ctd *CallData, messageProvider StreamMessageProviderFunc, streamInterceptor StreamInterceptor) error {

var callOptions = []grpc.CallOption{}

Expand Down Expand Up @@ -494,6 +516,19 @@ func (w *Worker) makeBidiRequest(ctx *context.Context,
}
}

if streamInterceptor != nil {
if converted, ok := res.(*dynamic.Message); ok {
iErr := streamInterceptor.Recv(converted, recvErr)
if errors.Is(iErr, ErrEndStream) && !interceptCanceled {
interceptCanceled = true
if len(cancel) == 0 {
cancel <- struct{}{}
}
recvErr = nil
}
}
}

if recvErr != nil {
close(recvDone)
break
Expand Down
Loading