Skip to content

Commit

Permalink
feat(cloudserver): turn DeadlineExceeded into appropriate gRPC status
Browse files Browse the repository at this point in the history
As internal context timeouts can be directly propagated up through the
cloudrunner-go middlewares, adding a conversion into a wrapped errorwith
a gRPC status code improves the reliability and monitoring of the system
by ensuring that the status code becomes DeadlineExceeded rather than
Internal.
  • Loading branch information
quoral committed Nov 16, 2023
1 parent 3baebd7 commit 17365f7
Show file tree
Hide file tree
Showing 4 changed files with 227 additions and 2 deletions.
25 changes: 23 additions & 2 deletions cloudserver/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@ package cloudserver

import (
"context"
"errors"
"fmt"
"runtime"

"go.einride.tech/cloudrunner/clouderror"
"go.einride.tech/cloudrunner/cloudrequestlog"
Expand Down Expand Up @@ -42,7 +44,16 @@ func (i *Middleware) GRPCUnaryServerInterceptor(
}
ctx, cancel := context.WithTimeout(ctx, i.Config.Timeout)
defer cancel()
return handler(ctx, req)
resp, err = handler(ctx, req)
if errors.Is(err, context.DeadlineExceeded) {
// below call is an inline version of cloudrunner.Wrap in order to avoid circular imports
return nil, clouderror.WrapCaller(
err,
status.New(codes.DeadlineExceeded, "context deadline exceeded"),
clouderror.NewCaller(runtime.Caller(1)),
)
}
return resp, err
}

// GRPCStreamServerInterceptor implements grpc.StreamServerInterceptor.
Expand All @@ -69,5 +80,15 @@ func (i *Middleware) GRPCStreamServerInterceptor(
ctx, cancel := context.WithTimeout(ss.Context(), i.Config.Timeout)
defer cancel()

return handler(srv, cloudstream.NewContextualServerStream(ctx, ss))
if err := handler(srv, cloudstream.NewContextualServerStream(ctx, ss)); err != nil {
if errors.Is(err, context.DeadlineExceeded) {
return clouderror.WrapCaller(
err,
status.New(codes.DeadlineExceeded, "context deadline exceeded"),
clouderror.NewCaller(runtime.Caller(1)),
)
}
return err
}
return nil
}
155 changes: 155 additions & 0 deletions cloudserver/middleware_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
package cloudserver_test

import (
"context"
"log"
"net"
"testing"
"time"

testproto "github.com/grpc-ecosystem/go-grpc-middleware/testing/testproto"
"go.einride.tech/cloudrunner/cloudserver"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/status"
"google.golang.org/grpc/test/bufconn"
"gotest.tools/v3/assert"
)

const bufSize = 1024 * 1024

type Server struct {
panicOnRequest bool
deadlineExceeeded bool
}

// Ping implements mwitkow_testproto.TestServiceServer.
func (s *Server) Ping(context.Context, *testproto.PingRequest) (*testproto.PingResponse, error) {
if s.panicOnRequest {
panic("boom!")
}
if s.deadlineExceeeded {
return nil, context.DeadlineExceeded
}
return &testproto.PingResponse{}, nil
}

// PingEmpty implements mwitkow_testproto.TestServiceServer.
func (*Server) PingEmpty(context.Context, *testproto.Empty) (*testproto.PingResponse, error) {
panic("unimplemented")
}

// PingError implements mwitkow_testproto.TestServiceServer.
func (*Server) PingError(context.Context, *testproto.PingRequest) (*testproto.Empty, error) {
panic("unimplemented")
}

// PingList implements mwitkow_testproto.TestServiceServer.
func (*Server) PingList(*testproto.PingRequest, testproto.TestService_PingListServer) error {
panic("unimplemented")
}

// PingStream implements mwitkow_testproto.TestServiceServer.
func (s *Server) PingStream(out testproto.TestService_PingStreamServer) error {
if s.panicOnRequest {
panic("boom!")
}
if s.deadlineExceeeded {
return context.DeadlineExceeded
}
return out.Send(&testproto.PingResponse{})
}

var _ testproto.TestServiceServer = &Server{}

func bufDialer(lis *bufconn.Listener) func(context.Context, string) (net.Conn, error) {
return func(context.Context, string) (net.Conn, error) { return lis.Dial() }
}

func TestGRPCUnary_ContextTimeoutWithDeadlineExceededErr(t *testing.T) {
ctx := context.Background()
server, client := grpcSetup(ctx, t)
server.deadlineExceeeded = true

_, err := client.Ping(ctx, &testproto.PingRequest{})
assert.ErrorIs(t, err, status.Error(codes.DeadlineExceeded, "context deadline exceeded"))
}

func TestGRPCUnary_RescuePanicsWithStatusInternalError(t *testing.T) {
ctx := context.Background()
server, client := grpcSetup(ctx, t)
server.panicOnRequest = true

_, err := client.Ping(ctx, &testproto.PingRequest{})
assert.ErrorIs(t, err, status.Error(codes.Internal, "internal error"))
}

func TestGRPCStream_ContextTimeoutWithDeadlineExceededErr(t *testing.T) {
ctx := context.Background()
server, client := grpcSetup(ctx, t)
server.deadlineExceeeded = true

stream, err := client.PingStream(ctx)
assert.NilError(t, err) // while it looks strange, this is setting up the stream
_, err = stream.Recv()
assert.ErrorIs(t, err, status.Error(codes.DeadlineExceeded, "context deadline exceeded"))
}

func TestGRPCStream_RescuePanicsWithStatusInternalError(t *testing.T) {
ctx := context.Background()
server, client := grpcSetup(ctx, t)
server.panicOnRequest = true

stream, err := client.PingStream(ctx)
assert.NilError(t, err) // while it looks strange, this is setting up the stream

_, err = stream.Recv()
assert.ErrorIs(t, err, status.Error(codes.Internal, "internal error"))
}

func TestGRPCUnary_NoRequestError(t *testing.T) {
ctx := context.Background()
_, client := grpcSetup(ctx, t)

_, err := client.Ping(ctx, &testproto.PingRequest{})
assert.NilError(t, err)
}

func TestGRPCStream_NoRequestError(t *testing.T) {
ctx := context.Background()
_, client := grpcSetup(ctx, t)

stream, err := client.PingStream(ctx)
assert.NilError(t, err) // while it looks strange, this is setting up the stream

_, err = stream.Recv()
assert.NilError(t, err)
_, err = stream.Recv()
assert.Error(t, err, "EOF")
}

func grpcSetup(ctx context.Context, t *testing.T) (*Server, testproto.TestServiceClient) {
lis := bufconn.Listen(bufSize)
middleware := cloudserver.Middleware{Config: cloudserver.Config{Timeout: time.Second * 5}}
server := grpc.NewServer(
grpc.ChainUnaryInterceptor(middleware.GRPCUnaryServerInterceptor),
grpc.ChainStreamInterceptor(middleware.GRPCStreamServerInterceptor),
)
testServer := &Server{}
testproto.RegisterTestServiceServer(server, testServer)
go func() {
if err := server.Serve(lis); err != nil {
log.Fatalf("Server exited with error: %v", err)
}
}()
conn, err := grpc.DialContext(
ctx,
"bufnet",
grpc.WithContextDialer(bufDialer(lis)),
grpc.WithTransportCredentials(insecure.NewCredentials()),
)
assert.NilError(t, err)
client := testproto.NewTestServiceClient(conn)
return testServer, client
}
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ require (
github.com/GoogleCloudPlatform/opentelemetry-operations-go/exporter/trace v1.20.0
github.com/GoogleCloudPlatform/opentelemetry-operations-go/propagator v0.44.0
github.com/google/go-cmp v0.6.0
github.com/grpc-ecosystem/go-grpc-middleware v1.4.0
github.com/soheilhy/cmux v0.1.5
go.einride.tech/protobuf-sensitive v0.5.0
go.opencensus.io v0.24.0
Expand Down
Loading

0 comments on commit 17365f7

Please sign in to comment.