Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a method filter for interceptors #459

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 43 additions & 0 deletions filter/client_interceptors.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
package filter

import (
"context"

"google.golang.org/grpc"
)

// UnaryClientMethods returns an interceptor that applies the provided interceptor only to outgoing unary calls to the specified methods.
// The allowlist parameter specifies whether the provided list of methods is to be treated as an allowlist (true) or a denylist (false).
// If it is an allowlist the interceptor will be applied only to the methods in the list; if it is a denylist the interceptor will be applied only to methods not in the list.
// The methods must be specified using the full name (e.g. "/package.service/method").
func UnaryClientMethods(interceptor grpc.UnaryClientInterceptor, allowlist bool, methods ...string) grpc.UnaryClientInterceptor {
if interceptor == nil {
panic("nil interceptor")
}
m := newMatchlist(methods, allowlist)

return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
if m.match(method) {
return interceptor(ctx, method, req, reply, cc, invoker, opts...)
}
return invoker(ctx, method, req, reply, cc, opts...)
}
}

// StreamClientMethods returns an interceptor that applies the provided interceptor only to outgoing unary calls to the specified methods.
// The allowlist parameter specifies whether the provided list of methods is to be treated as an allowlist (true) or a denylist (false).
// If it is an allowlist the interceptor will be applied only to the methods in the list; if it is a denylist the interceptor will be applied only to methods not in the list.
// The methods must be specified using the full name (e.g. "/package.service/method").
func StreamClientMethods(interceptor grpc.StreamClientInterceptor, allowlist bool, methods ...string) grpc.StreamClientInterceptor {
if interceptor == nil {
panic("nil interceptor")
}
m := newMatchlist(methods, allowlist)

return func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) {
if m.match(method) {
return interceptor(ctx, desc, cc, method, streamer, opts...)
}
return streamer(ctx, desc, cc, method, opts...)
}
}
87 changes: 87 additions & 0 deletions filter/client_interceptors_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
package filter_test

import (
"context"
"testing"

"github.com/grpc-ecosystem/go-grpc-middleware/filter"
grpc_testing "github.com/grpc-ecosystem/go-grpc-middleware/testing"
pb_testproto "github.com/grpc-ecosystem/go-grpc-middleware/testing/testproto"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
"google.golang.org/grpc"
)

type noopUnaryClientInterceptor struct {
called bool
}

func (i *noopUnaryClientInterceptor) intercept(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
i.called = true
return invoker(ctx, method, req, reply, cc, opts...)
}

type noopStreamClientInterceptor struct {
called bool
}

func (i *noopStreamClientInterceptor) intercept(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) {
i.called = true
return streamer(ctx, desc, cc, method, opts...)
}

func TestClientMethods(t *testing.T) {
service := &someService{
TestPingService: grpc_testing.TestPingService{T: t},
}
si := &noopStreamClientInterceptor{}
ui := &noopUnaryClientInterceptor{}
suite.Run(t, &ClientFilterSuite{
srv: service,
si: si,
ui: ui,
InterceptorTestSuite: &grpc_testing.InterceptorTestSuite{
TestService: service,
ClientOpts: []grpc.DialOption{
grpc.WithUnaryInterceptor(filter.UnaryClientMethods(ui.intercept, true, "/mwitkow.testproto.TestService/Ping")),
grpc.WithStreamInterceptor(filter.StreamClientMethods(si.intercept, true, "/mwitkow.testproto.TestService/PingStream")),
},
},
})
}

type ClientFilterSuite struct {
*grpc_testing.InterceptorTestSuite
srv *someService
si *noopStreamClientInterceptor
ui *noopUnaryClientInterceptor
}

func (s *ClientFilterSuite) SetupTest() {
s.srv.pingCalled = false
s.srv.pingEmptyCalled = false
s.srv.pingStreamCalled = false
s.si.called = false
s.ui.called = false
}

func (s *ClientFilterSuite) TestUnary_CallAllowedUnaryMethod() {
res, err := s.Client.Ping(s.SimpleCtx(), &pb_testproto.PingRequest{Value: "hello"})
require.NoError(s.T(), err)
require.Equal(s.T(), res.Value, "hello")
require.True(s.T(), s.srv.pingCalled)
require.False(s.T(), s.srv.pingEmptyCalled)
require.False(s.T(), s.srv.pingStreamCalled)
require.True(s.T(), s.ui.called) // allowed
require.False(s.T(), s.si.called)
}

func (s *ClientFilterSuite) TestUnary_CallDisallowedUnaryMethod() {
_, err := s.Client.PingEmpty(s.SimpleCtx(), &pb_testproto.Empty{})
require.NoError(s.T(), err)
require.False(s.T(), s.srv.pingCalled)
require.True(s.T(), s.srv.pingEmptyCalled)
require.False(s.T(), s.srv.pingStreamCalled)
require.False(s.T(), s.ui.called) // disallowed
require.False(s.T(), s.si.called)
}
19 changes: 19 additions & 0 deletions filter/matchlist.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
package filter

type matchlist struct {
m map[string]struct{}
p bool
}

func newMatchlist(s []string, matchPresence bool) *matchlist {
m := make(map[string]struct{}, len(s))
for _, e := range s {
m[e] = struct{}{}
}
return &matchlist{m, matchPresence}
}

func (m *matchlist) match(s string) bool {
_, found := m.m[s]
return found == m.p
}
55 changes: 55 additions & 0 deletions filter/matchlist_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
package filter

import (
"fmt"
"strconv"
"strings"
"testing"
)

func TestMatchlist(t *testing.T) {
cases := map[string]struct {
list []string
presence bool
match string
res bool
}{
"positive match": {[]string{"a", "b"}, true, "a", true},
"positive match 2": {[]string{"a", "b"}, true, "b", true},
"positive no match": {[]string{"a", "b"}, true, "c", false},
"positive no match case insensitive": {[]string{"a", "b"}, true, "A", false},
"negative match": {[]string{"a", "b"}, false, "a", false},
"negative match 2": {[]string{"a", "b"}, false, "b", false},
"negative no match": {[]string{"a", "b"}, false, "c", true},
"negative no match case insensitive": {[]string{"a", "b"}, false, "A", true},

"positive empty list": {[]string{}, true, "a", false},
"negative empty list": {[]string{}, false, "a", true},
}
for n, c := range cases {
t.Run(n, func(t *testing.T) {
t.Log(c.list, c.match, c.presence, c.res)
m := newMatchlist(c.list, c.presence)
r := m.match(c.match)
if r != c.res {
t.Error("wrong result")
}
})
}
}

func BenchmarkMatchlist(b *testing.B) {
for _, i := range []int{0, 1, 2, 3, 4, 5, 6, 8, 10, 15, 20, 25, 30, 40, 50, 75, 100, 300, 1000} {
var s []string
for j := 0; j < i; j++ {
s = append(s, fmt.Sprintf("%30d", j))
}
m := newMatchlist(s, true)
c := strings.Repeat(" ", 30)
b.Run(strconv.Itoa(i), func(b *testing.B) {
for j := 0; j < b.N; j++ {
_ = m.match(c)
}
})
}
}
65 changes: 65 additions & 0 deletions filter/server_interceptors.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
package filter

import (
"context"

"google.golang.org/grpc"
)

// UnaryServerMethods returns an interceptor that applies the provided interceptor only to incoming unary calls to the specified methods.
// The allowlist parameter specifies whether the provided list of methods is to be treated as an allowlist (true) or a denylist (false).
// If it is an allowlist the interceptor will be applied only to the methods in the list; if it is a denylist the interceptor will be applied only to methods not in the list.
// The methods must be specified using the full name (e.g. "/package.service/method").
func UnaryServerMethods(interceptor grpc.UnaryServerInterceptor, allowlist bool, methods ...string) grpc.UnaryServerInterceptor {
if interceptor == nil {
panic("nil interceptor")
}
m := newMatchlist(methods, allowlist)

return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
if m.match(info.FullMethod) {
return interceptor(ctx, req, info, handler)
}
return handler(ctx, req)
}
}

/*
func UnaryServerMethodsInterceptor(interceptor grpc.UnaryServerInterceptor, allowlist bool, methods ...string) grpc.UnaryServerInterceptor {
m := newMatchlist(methods, allowlist)

return UnaryServerConditionInterceptor(interceptor, func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo) bool {
return m.match(info.FullMethod)
})
}

func UnaryServerConditionInterceptor(interceptor grpc.UnaryServerInterceptor, condition func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo) bool) grpc.UnaryServerInterceptor {
if interceptor == nil {
panic("nil interceptor")
}
return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
if condition(ctx, req, info) {
return interceptor(ctx, req, info, handler)
}
return handler(ctx, req)
}
}
*/
Comment on lines +27 to +47
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is an alternate implementation style in case we want to support plugging in arbitrary conditions


// StreamServerMethods returns an interceptor that applies the provided interceptor only to incoming stream calls to the specified methods.
// The allowlist parameter specifies whether the provided list of methods is to be treated as an allowlist (true) or a denylist (false).
// If it is an allowlist the interceptor will be applied only to the methods in the list; if it is a denylist the interceptor will be applied only to methods not in the list.
// The methods must be specified using the full name (e.g. "/package.service/method").
func StreamServerMethods(interceptor grpc.StreamServerInterceptor, allowlist bool, methods ...string) grpc.StreamServerInterceptor {
if interceptor == nil {
panic("nil interceptor")
}
m := newMatchlist(methods, allowlist)

return func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
if m.match(info.FullMethod) {
return interceptor(srv, ss, info, handler)
}
return handler(srv, ss)
}
}
115 changes: 115 additions & 0 deletions filter/server_interceptors_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
package filter_test

import (
"context"
"testing"

"github.com/grpc-ecosystem/go-grpc-middleware/filter"
grpc_testing "github.com/grpc-ecosystem/go-grpc-middleware/testing"
pb_testproto "github.com/grpc-ecosystem/go-grpc-middleware/testing/testproto"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
"google.golang.org/grpc"
)

type someService struct {
grpc_testing.TestPingService
pingCalled bool
pingEmptyCalled bool
pingStreamCalled bool
}

func (s *someService) Ping(ctx context.Context, ping *pb_testproto.PingRequest) (*pb_testproto.PingResponse, error) {
s.pingCalled = true
return s.TestPingService.Ping(ctx, ping)
}

func (s *someService) PingEmpty(ctx context.Context, empty *pb_testproto.Empty) (*pb_testproto.PingResponse, error) {
s.pingEmptyCalled = true
return s.TestPingService.PingEmpty(ctx, empty)
}

func (s *someService) PingStream(stream pb_testproto.TestService_PingStreamServer) error {
s.pingStreamCalled = true
return s.TestPingService.PingStream(stream)
}

type noopUnaryServerInterceptor struct {
called bool
}

func (i *noopUnaryServerInterceptor) intercept(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
i.called = true
return handler(ctx, req)
}

type noopStreamServerInterceptor struct {
called bool
}

func (i *noopStreamServerInterceptor) intercept(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
i.called = true
return handler(srv, ss)
}

func TestServerMethods(t *testing.T) {
service := &someService{
TestPingService: grpc_testing.TestPingService{T: t},
}
si := &noopStreamServerInterceptor{}
ui := &noopUnaryServerInterceptor{}
suite.Run(t, &FilterSuite{
srv: service,
si: si,
ui: ui,
InterceptorTestSuite: &grpc_testing.InterceptorTestSuite{
TestService: service,
/*
ClientOpts: []grpc.DialOption{
grpc.WithStreamInterceptor(filter.StreamClientMethod()),
grpc.WithUnaryInterceptor(filter.UnaryClientMethod()),
},
*/
ServerOpts: []grpc.ServerOption{
grpc.UnaryInterceptor(filter.UnaryServerMethods(ui.intercept, true, "/mwitkow.testproto.TestService/Ping")),
grpc.StreamInterceptor(filter.StreamServerMethods(si.intercept, true, "/mwitkow.testproto.TestService/PingStream")),
},
},
})
}

type FilterSuite struct {
*grpc_testing.InterceptorTestSuite
srv *someService
si *noopStreamServerInterceptor
ui *noopUnaryServerInterceptor
}

func (s *FilterSuite) SetupTest() {
s.srv.pingCalled = false
s.srv.pingEmptyCalled = false
s.srv.pingStreamCalled = false
s.si.called = false
s.ui.called = false
}

func (s *FilterSuite) TestUnary_CallAllowedUnaryMethod() {
res, err := s.Client.Ping(s.SimpleCtx(), &pb_testproto.PingRequest{Value: "hello"})
require.NoError(s.T(), err)
require.Equal(s.T(), res.Value, "hello")
require.True(s.T(), s.srv.pingCalled)
require.False(s.T(), s.srv.pingEmptyCalled)
require.False(s.T(), s.srv.pingStreamCalled)
require.True(s.T(), s.ui.called) // allowed
require.False(s.T(), s.si.called)
}

func (s *FilterSuite) TestUnary_CallDisallowedUnaryMethod() {
_, err := s.Client.PingEmpty(s.SimpleCtx(), &pb_testproto.Empty{})
require.NoError(s.T(), err)
require.False(s.T(), s.srv.pingCalled)
require.True(s.T(), s.srv.pingEmptyCalled)
require.False(s.T(), s.srv.pingStreamCalled)
require.False(s.T(), s.ui.called) // disallowed
require.False(s.T(), s.si.called)
}