diff --git a/pkg/proxy/codec.go b/pkg/proxy/codec.go index b255404..35bd07c 100644 --- a/pkg/proxy/codec.go +++ b/pkg/proxy/codec.go @@ -6,8 +6,9 @@ import ( "google.golang.org/protobuf/proto" ) -// RawBytesCodec sets the received bytes as is to the target, -// which must always be a pointer to a byte slice. +// RawBytesCodec sets the received bytes as-is to the target, +// whether it is a byte slice or a proto.Message. +// For proto.Message, it uses proto.Marshal and proto.Unmarshal. type RawBytesCodec struct{} // Marshal returns the received byte slice as is. diff --git a/pkg/proxy/codec_test.go b/pkg/proxy/codec_test.go new file mode 100644 index 0000000..c012ea3 --- /dev/null +++ b/pkg/proxy/codec_test.go @@ -0,0 +1,43 @@ +package proxy + +import ( + "testing" + + "github.com/stretchr/testify/require" + "google.golang.org/genproto/googleapis/rpc/errdetails" + "google.golang.org/protobuf/proto" +) + +func TestRawBytesCodec_Marshal(t *testing.T) { + t.Run("proto message", func(t *testing.T) { + bts, err := RawBytesCodec{}.Marshal(&errdetails.RequestInfo{RequestId: "1"}) + require.NoError(t, err) + expected, err := proto.Marshal(&errdetails.RequestInfo{RequestId: "1"}) + require.NoError(t, err) + require.Equal(t, expected, bts) + }) + + t.Run("byte slice", func(t *testing.T) { + bts, err := RawBytesCodec{}.Marshal(&[]byte{1, 2, 3}) + require.NoError(t, err) + require.Equal(t, []byte{1, 2, 3}, bts) + }) +} + +func TestRawBytesCodec_Unmarshal(t *testing.T) { + t.Run("proto message", func(t *testing.T) { + bts, err := proto.Marshal(&errdetails.RequestInfo{RequestId: "1"}) + require.NoError(t, err) + + msg := &errdetails.RequestInfo{} + require.NoError(t, RawBytesCodec{}.Unmarshal(bts, msg)) + require.Truef(t, proto.Equal(&errdetails.RequestInfo{RequestId: "1"}, msg), "got: %v", msg) + }) + + t.Run("byte slice", func(t *testing.T) { + var bts []byte + err := RawBytesCodec{}.Unmarshal([]byte{1, 2, 3}, &bts) + require.NoError(t, err) + require.Equal(t, []byte{1, 2, 3}, bts) + }) +} diff --git a/pkg/proxy/middleware/middleware.go b/pkg/proxy/middleware/middleware.go index 546157f..9c56719 100644 --- a/pkg/proxy/middleware/middleware.go +++ b/pkg/proxy/middleware/middleware.go @@ -28,8 +28,6 @@ func Chain(base grpc.StreamHandler, mws ...Middleware) grpc.StreamHandler { func AppInfo(app, author, version string) Middleware { return func(next grpc.StreamHandler) grpc.StreamHandler { return func(srv any, stream grpc.ServerStream) error { - ctx := stream.Context() - md := metadata.Pairs( "app", app, "author", author, @@ -37,7 +35,7 @@ func AppInfo(app, author, version string) Middleware { ) if err := stream.SetHeader(md); err != nil { - slog.WarnContext(ctx, "failed to send app info", slogx.Error(err)) + slog.WarnContext(stream.Context(), "failed to send app info", slogx.Error(err)) } return next(srv, stream) diff --git a/pkg/proxy/middleware/middleware_test.go b/pkg/proxy/middleware/middleware_test.go new file mode 100644 index 0000000..247ad3f --- /dev/null +++ b/pkg/proxy/middleware/middleware_test.go @@ -0,0 +1,77 @@ +package middleware + +import ( + "bytes" + "context" + "log/slog" + "testing" + + "github.com/Semior001/groxy/pkg/proxy/mocks" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/metadata" + "google.golang.org/grpc/status" +) + +func TestAppInfo(t *testing.T) { + mw := AppInfo("app", "author", "version") + var header metadata.MD + ss := &mocks.ServerStreamMock{ + SetHeaderFunc: func(md metadata.MD) error { + header = md + return nil + }, + } + + err := mw(func(_ any, _ grpc.ServerStream) error { return nil })(nil, ss) + require.NoError(t, err) + + assert.Equal(t, metadata.Pairs( + "app", "app", + "author", "author", + "version", "version", + ), header) +} + +func TestRecoverer(t *testing.T) { + bts := bytes.NewBuffer(nil) + slog.SetDefault(slog.New(slog.NewTextHandler(bts, &slog.HandlerOptions{}))) + mw := Recoverer(func(_ any, _ grpc.ServerStream) error { panic("test") }) + var err error + require.NotPanics(t, func() { + err = mw(nil, &mocks.ServerStreamMock{ + ContextFunc: func() context.Context { return context.Background() }, + }) + }) + st, ok := status.FromError(err) + require.True(t, ok) + assert.Equal(t, codes.ResourceExhausted, st.Code()) + assert.Equal(t, "{groxy} panic", st.Message()) +} + +func TestChain(t *testing.T) { + var calls []string + mw1 := func(next grpc.StreamHandler) grpc.StreamHandler { + return func(srv any, stream grpc.ServerStream) error { + calls = append(calls, "mw1") + return next(srv, stream) + } + } + mw2 := func(next grpc.StreamHandler) grpc.StreamHandler { + return func(srv any, stream grpc.ServerStream) error { + calls = append(calls, "mw2") + return next(srv, stream) + } + } + mw3 := func(next grpc.StreamHandler) grpc.StreamHandler { + return func(srv any, stream grpc.ServerStream) error { + calls = append(calls, "mw3") + return next(srv, stream) + } + } + h := Chain(func(_ any, _ grpc.ServerStream) error { return nil }, mw1, mw2, mw3) + require.NoError(t, h(nil, nil)) + assert.Equal(t, []string{"mw1", "mw2", "mw3"}, calls) +} diff --git a/pkg/proxy/mocks/mocks.go b/pkg/proxy/mocks/mocks.go new file mode 100644 index 0000000..61484c1 --- /dev/null +++ b/pkg/proxy/mocks/mocks.go @@ -0,0 +1,347 @@ +// Code generated by moq; DO NOT EDIT. +// github.com/matryer/moq + +package mocks + +import ( + "context" + "github.com/Semior001/groxy/pkg/discovery" + "google.golang.org/grpc/metadata" + "sync" +) + +// MatcherMock is a mock implementation of proxy.Matcher. +// +// func TestSomethingThatUsesMatcher(t *testing.T) { +// +// // make and configure a mocked proxy.Matcher +// mockedMatcher := &MatcherMock{ +// MatchMetadataFunc: func(s string, mD metadata.MD) discovery.Matches { +// panic("mock out the MatchMetadata method") +// }, +// } +// +// // use mockedMatcher in code that requires proxy.Matcher +// // and then make assertions. +// +// } +type MatcherMock struct { + // MatchMetadataFunc mocks the MatchMetadata method. + MatchMetadataFunc func(s string, mD metadata.MD) discovery.Matches + + // calls tracks calls to the methods. + calls struct { + // MatchMetadata holds details about calls to the MatchMetadata method. + MatchMetadata []struct { + // S is the s argument value. + S string + // MD is the mD argument value. + MD metadata.MD + } + } + lockMatchMetadata sync.RWMutex +} + +// MatchMetadata calls MatchMetadataFunc. +func (mock *MatcherMock) MatchMetadata(s string, mD metadata.MD) discovery.Matches { + if mock.MatchMetadataFunc == nil { + panic("MatcherMock.MatchMetadataFunc: method is nil but Matcher.MatchMetadata was just called") + } + callInfo := struct { + S string + MD metadata.MD + }{ + S: s, + MD: mD, + } + mock.lockMatchMetadata.Lock() + mock.calls.MatchMetadata = append(mock.calls.MatchMetadata, callInfo) + mock.lockMatchMetadata.Unlock() + return mock.MatchMetadataFunc(s, mD) +} + +// MatchMetadataCalls gets all the calls that were made to MatchMetadata. +// Check the length with: +// len(mockedMatcher.MatchMetadataCalls()) +func (mock *MatcherMock) MatchMetadataCalls() []struct { + S string + MD metadata.MD +} { + var calls []struct { + S string + MD metadata.MD + } + mock.lockMatchMetadata.RLock() + calls = mock.calls.MatchMetadata + mock.lockMatchMetadata.RUnlock() + return calls +} + +// ServerStreamMock is a mock implementation of proxy.ServerStream. +// +// func TestSomethingThatUsesServerStream(t *testing.T) { +// +// // make and configure a mocked proxy.ServerStream +// mockedServerStream := &ServerStreamMock{ +// ContextFunc: func() context.Context { +// panic("mock out the Context method") +// }, +// RecvMsgFunc: func(m any) error { +// panic("mock out the RecvMsg method") +// }, +// SendHeaderFunc: func(mD metadata.MD) error { +// panic("mock out the SendHeader method") +// }, +// SendMsgFunc: func(m any) error { +// panic("mock out the SendMsg method") +// }, +// SetHeaderFunc: func(mD metadata.MD) error { +// panic("mock out the SetHeader method") +// }, +// SetTrailerFunc: func(mD metadata.MD) { +// panic("mock out the SetTrailer method") +// }, +// } +// +// // use mockedServerStream in code that requires proxy.ServerStream +// // and then make assertions. +// +// } +type ServerStreamMock struct { + // ContextFunc mocks the Context method. + ContextFunc func() context.Context + + // RecvMsgFunc mocks the RecvMsg method. + RecvMsgFunc func(m any) error + + // SendHeaderFunc mocks the SendHeader method. + SendHeaderFunc func(mD metadata.MD) error + + // SendMsgFunc mocks the SendMsg method. + SendMsgFunc func(m any) error + + // SetHeaderFunc mocks the SetHeader method. + SetHeaderFunc func(mD metadata.MD) error + + // SetTrailerFunc mocks the SetTrailer method. + SetTrailerFunc func(mD metadata.MD) + + // calls tracks calls to the methods. + calls struct { + // Context holds details about calls to the Context method. + Context []struct { + } + // RecvMsg holds details about calls to the RecvMsg method. + RecvMsg []struct { + // M is the m argument value. + M any + } + // SendHeader holds details about calls to the SendHeader method. + SendHeader []struct { + // MD is the mD argument value. + MD metadata.MD + } + // SendMsg holds details about calls to the SendMsg method. + SendMsg []struct { + // M is the m argument value. + M any + } + // SetHeader holds details about calls to the SetHeader method. + SetHeader []struct { + // MD is the mD argument value. + MD metadata.MD + } + // SetTrailer holds details about calls to the SetTrailer method. + SetTrailer []struct { + // MD is the mD argument value. + MD metadata.MD + } + } + lockContext sync.RWMutex + lockRecvMsg sync.RWMutex + lockSendHeader sync.RWMutex + lockSendMsg sync.RWMutex + lockSetHeader sync.RWMutex + lockSetTrailer sync.RWMutex +} + +// Context calls ContextFunc. +func (mock *ServerStreamMock) Context() context.Context { + if mock.ContextFunc == nil { + panic("ServerStreamMock.ContextFunc: method is nil but ServerStream.Context was just called") + } + callInfo := struct { + }{} + mock.lockContext.Lock() + mock.calls.Context = append(mock.calls.Context, callInfo) + mock.lockContext.Unlock() + return mock.ContextFunc() +} + +// ContextCalls gets all the calls that were made to Context. +// Check the length with: +// len(mockedServerStream.ContextCalls()) +func (mock *ServerStreamMock) ContextCalls() []struct { +} { + var calls []struct { + } + mock.lockContext.RLock() + calls = mock.calls.Context + mock.lockContext.RUnlock() + return calls +} + +// RecvMsg calls RecvMsgFunc. +func (mock *ServerStreamMock) RecvMsg(m any) error { + if mock.RecvMsgFunc == nil { + panic("ServerStreamMock.RecvMsgFunc: method is nil but ServerStream.RecvMsg was just called") + } + callInfo := struct { + M any + }{ + M: m, + } + mock.lockRecvMsg.Lock() + mock.calls.RecvMsg = append(mock.calls.RecvMsg, callInfo) + mock.lockRecvMsg.Unlock() + return mock.RecvMsgFunc(m) +} + +// RecvMsgCalls gets all the calls that were made to RecvMsg. +// Check the length with: +// len(mockedServerStream.RecvMsgCalls()) +func (mock *ServerStreamMock) RecvMsgCalls() []struct { + M any +} { + var calls []struct { + M any + } + mock.lockRecvMsg.RLock() + calls = mock.calls.RecvMsg + mock.lockRecvMsg.RUnlock() + return calls +} + +// SendHeader calls SendHeaderFunc. +func (mock *ServerStreamMock) SendHeader(mD metadata.MD) error { + if mock.SendHeaderFunc == nil { + panic("ServerStreamMock.SendHeaderFunc: method is nil but ServerStream.SendHeader was just called") + } + callInfo := struct { + MD metadata.MD + }{ + MD: mD, + } + mock.lockSendHeader.Lock() + mock.calls.SendHeader = append(mock.calls.SendHeader, callInfo) + mock.lockSendHeader.Unlock() + return mock.SendHeaderFunc(mD) +} + +// SendHeaderCalls gets all the calls that were made to SendHeader. +// Check the length with: +// len(mockedServerStream.SendHeaderCalls()) +func (mock *ServerStreamMock) SendHeaderCalls() []struct { + MD metadata.MD +} { + var calls []struct { + MD metadata.MD + } + mock.lockSendHeader.RLock() + calls = mock.calls.SendHeader + mock.lockSendHeader.RUnlock() + return calls +} + +// SendMsg calls SendMsgFunc. +func (mock *ServerStreamMock) SendMsg(m any) error { + if mock.SendMsgFunc == nil { + panic("ServerStreamMock.SendMsgFunc: method is nil but ServerStream.SendMsg was just called") + } + callInfo := struct { + M any + }{ + M: m, + } + mock.lockSendMsg.Lock() + mock.calls.SendMsg = append(mock.calls.SendMsg, callInfo) + mock.lockSendMsg.Unlock() + return mock.SendMsgFunc(m) +} + +// SendMsgCalls gets all the calls that were made to SendMsg. +// Check the length with: +// len(mockedServerStream.SendMsgCalls()) +func (mock *ServerStreamMock) SendMsgCalls() []struct { + M any +} { + var calls []struct { + M any + } + mock.lockSendMsg.RLock() + calls = mock.calls.SendMsg + mock.lockSendMsg.RUnlock() + return calls +} + +// SetHeader calls SetHeaderFunc. +func (mock *ServerStreamMock) SetHeader(mD metadata.MD) error { + if mock.SetHeaderFunc == nil { + panic("ServerStreamMock.SetHeaderFunc: method is nil but ServerStream.SetHeader was just called") + } + callInfo := struct { + MD metadata.MD + }{ + MD: mD, + } + mock.lockSetHeader.Lock() + mock.calls.SetHeader = append(mock.calls.SetHeader, callInfo) + mock.lockSetHeader.Unlock() + return mock.SetHeaderFunc(mD) +} + +// SetHeaderCalls gets all the calls that were made to SetHeader. +// Check the length with: +// len(mockedServerStream.SetHeaderCalls()) +func (mock *ServerStreamMock) SetHeaderCalls() []struct { + MD metadata.MD +} { + var calls []struct { + MD metadata.MD + } + mock.lockSetHeader.RLock() + calls = mock.calls.SetHeader + mock.lockSetHeader.RUnlock() + return calls +} + +// SetTrailer calls SetTrailerFunc. +func (mock *ServerStreamMock) SetTrailer(mD metadata.MD) { + if mock.SetTrailerFunc == nil { + panic("ServerStreamMock.SetTrailerFunc: method is nil but ServerStream.SetTrailer was just called") + } + callInfo := struct { + MD metadata.MD + }{ + MD: mD, + } + mock.lockSetTrailer.Lock() + mock.calls.SetTrailer = append(mock.calls.SetTrailer, callInfo) + mock.lockSetTrailer.Unlock() + mock.SetTrailerFunc(mD) +} + +// SetTrailerCalls gets all the calls that were made to SetTrailer. +// Check the length with: +// len(mockedServerStream.SetTrailerCalls()) +func (mock *ServerStreamMock) SetTrailerCalls() []struct { + MD metadata.MD +} { + var calls []struct { + MD metadata.MD + } + mock.lockSetTrailer.RLock() + calls = mock.calls.SetTrailer + mock.lockSetTrailer.RUnlock() + return calls +} diff --git a/pkg/proxy/proxy.go b/pkg/proxy/proxy.go index 3db11af..5b8f20c 100644 --- a/pkg/proxy/proxy.go +++ b/pkg/proxy/proxy.go @@ -17,13 +17,24 @@ import ( "google.golang.org/grpc/status" ) +//go:generate moq -out mocks/mocks.go --skip-ensure -pkg mocks . Matcher ServerStream + +// ServerStream is a gRPC server stream. +type ServerStream grpc.ServerStream + +// Matcher matches the request URI and incoming metadata to the +// registered rules. +type Matcher interface { + MatchMetadata(string, metadata.MD) discovery.Matches +} + // Server is a gRPC server. type Server struct { version string serverOpts []grpc.ServerOption defaultResponder func(stream grpc.ServerStream, firstRecv []byte) error - matcher *discovery.Service + matcher Matcher debug bool l net.Listener @@ -31,7 +42,7 @@ type Server struct { } // NewServer creates a new server. -func NewServer(m *discovery.Service, opts ...Option) *Server { +func NewServer(m Matcher, opts ...Option) *Server { s := &Server{ matcher: m, defaultResponder: func(_ grpc.ServerStream, _ []byte) error { diff --git a/pkg/proxy/proxy_test.go b/pkg/proxy/proxy_test.go new file mode 100644 index 0000000..3fab1f5 --- /dev/null +++ b/pkg/proxy/proxy_test.go @@ -0,0 +1,9 @@ +package proxy + +import ( + "testing" +) + +func TestServer_handle(t *testing.T) { + +}