Skip to content

Commit a0a28d0

Browse files
committed
generate rpcs with context as a first argument
1 parent 0492cd2 commit a0a28d0

File tree

6 files changed

+87
-25
lines changed

6 files changed

+87
-25
lines changed

protobuf/protobuf.go

+2
Original file line numberDiff line numberDiff line change
@@ -331,6 +331,8 @@ type RPC struct {
331331
Recv string
332332
// Method is the name of the Go method or function.
333333
Method string
334+
// HasCtx reports whether the Go function accepts context.
335+
HasCtx bool
334336
// HasError reports whether the Go function returns an error.
335337
HasError bool
336338
// IsVariadic reports whether the Go function is variadic or not.

protobuf/transform.go

+21-1
Original file line numberDiff line numberDiff line change
@@ -110,15 +110,17 @@ func (t *Transformer) transformFunc(pkg *Package, f *scanner.Func, names nameSet
110110
receiverName = n.Name
111111
}
112112

113+
input, hasCtx := removeFirstCtx(f.Input)
113114
output, hasError := removeLastError(f.Output)
114115
rpc := &RPC{
115116
Docs: f.Doc,
116117
Name: name,
117118
Recv: receiverName,
118119
Method: f.Name,
120+
HasCtx: hasCtx,
119121
HasError: hasError,
120122
IsVariadic: f.IsVariadic,
121-
Input: t.transformInputTypes(pkg, f.Input, names, name),
123+
Input: t.transformInputTypes(pkg, input, names, name),
122124
Output: t.transformOutputTypes(pkg, output, names, name),
123125
}
124126
if rpc.Input == nil || rpc.Output == nil {
@@ -383,6 +385,17 @@ func (t *Transformer) findMapping(name string) *ProtoType {
383385
return typ
384386
}
385387

388+
func removeFirstCtx(types []scanner.Type) ([]scanner.Type, bool) {
389+
if len(types) > 0 {
390+
first := types[0]
391+
if isCtx(first) {
392+
return types[1:], true
393+
}
394+
}
395+
396+
return types, false
397+
}
398+
386399
func removeLastError(types []scanner.Type) ([]scanner.Type, bool) {
387400
if len(types) > 0 {
388401
ln := len(types)
@@ -400,6 +413,13 @@ func isNamed(typ scanner.Type) bool {
400413
return ok
401414
}
402415

416+
func isCtx(typ scanner.Type) bool {
417+
if ctx, ok := typ.(*scanner.Named); ok {
418+
return ctx.Path == "context" && ctx.Name == "Context"
419+
}
420+
return false
421+
}
422+
403423
func isError(typ scanner.Type) bool {
404424
if err, ok := typ.(*scanner.Named); ok {
405425
return err.Path == "" && err.Name == "error"

resolver/resolver.go

+4-3
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,10 @@ type Resolver struct {
2424
func New() *Resolver {
2525
return &Resolver{
2626
customTypes: map[string]struct{}{
27-
"time.Time": {},
28-
"time.Duration": {},
29-
"error": {},
27+
"time.Time": {},
28+
"time.Duration": {},
29+
"context.Context": {},
30+
"error": {},
3031
},
3132
}
3233
}

rpc/context.go

+8-4
Original file line numberDiff line numberDiff line change
@@ -48,15 +48,19 @@ func (c *context) findSignature(rpc *protobuf.RPC) *types.Signature {
4848

4949
func (c *context) argumentType(rpc *protobuf.RPC) string {
5050
signature := c.findSignature(rpc)
51-
obj := firstTypeName(signature.Params())
51+
skip := 0
52+
if rpc.HasCtx {
53+
skip++
54+
}
55+
obj := firstTypeName(skip, signature.Params())
5256
c.addImport(obj.Pkg().Path())
5357

5458
return c.objectNameInContext(obj)
5559
}
5660

5761
func (c *context) returnType(rpc *protobuf.RPC) string {
5862
signature := c.findSignature(rpc)
59-
obj := firstTypeName(signature.Results())
63+
obj := firstTypeName(0, signature.Results())
6064
c.addImport(obj.Pkg().Path())
6165

6266
return c.objectNameInContext(obj)
@@ -72,8 +76,8 @@ func (c *context) objectNameInContext(obj types.Object) string {
7276
}
7377
}
7478

75-
func firstTypeName(tuple *types.Tuple) types.Object {
76-
t := tuple.At(0).Type()
79+
func firstTypeName(skip int, tuple *types.Tuple) types.Object {
80+
t := tuple.At(skip).Type()
7781
if inner, ok := t.(*types.Pointer); ok {
7882
t = inner.Elem()
7983
}

rpc/rpc.go

+15-2
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,7 @@ func (g *Generator) genMethodType(ctx *context, rpc *protobuf.RPC) *ast.FuncType
158158

159159
return &ast.FuncType{
160160
Params: fields(
161-
field("ctx", ast.NewIdent("context.Context")),
161+
field("ctx", ast.NewIdent("xcontext.Context")),
162162
field("in", ptr(ast.NewIdent(in))),
163163
),
164164
Results: fields(
@@ -174,6 +174,9 @@ func (g *Generator) genMethodCall(ctx *context, rpc *protobuf.RPC) ast.Expr {
174174
call.Fun = ast.NewIdent(fmt.Sprintf("s.%s.%s", rpc.Recv, rpc.Method))
175175
}
176176

177+
if rpc.HasCtx {
178+
call.Args = append(call.Args, ast.NewIdent("ctx"))
179+
}
177180
if rpc.IsVariadic {
178181
call.Ellipsis = token.Pos(1)
179182
}
@@ -329,7 +332,7 @@ func (g *Generator) buildFile(ctx *context, decls []ast.Decl) *ast.File {
329332
Name: ast.NewIdent(ctx.pkg.Name()),
330333
}
331334

332-
var specs = []ast.Spec{newImport("golang.org/x/net/context")}
335+
var specs = []ast.Spec{newNamedImport("xcontext", "golang.org/x/net/context")}
333336
for _, i := range ctx.imports {
334337
specs = append(specs, newImport(i))
335338
}
@@ -380,6 +383,16 @@ func newImport(path string) *ast.ImportSpec {
380383
}
381384
}
382385

386+
func newNamedImport(name, path string) *ast.ImportSpec {
387+
return &ast.ImportSpec{
388+
Name: &ast.Ident{Name: name},
389+
Path: &ast.BasicLit{
390+
Kind: token.STRING,
391+
Value: fmt.Sprintf(`"%s"`, removeGoPath(path)),
392+
},
393+
}
394+
}
395+
383396
func field(name string, typ ast.Expr) *ast.Field {
384397
return &ast.Field{
385398
Names: []*ast.Ident{ast.NewIdent(name)},

rpc/rpc_test.go

+37-15
Original file line numberDiff line numberDiff line change
@@ -47,61 +47,67 @@ func (s *RPCSuite) TestDeclConstructor() {
4747
s.Equal(expectedConstructor, output)
4848
}
4949

50-
const expectedFuncNotGenerated = `func (s *FooServer) DoFoo(ctx context.Context, in *Foo) (result *Bar, err error) {
50+
const expectedFuncNotGenerated = `func (s *FooServer) DoFoo(ctx xcontext.Context, in *Foo) (result *Bar, err error) {
5151
result = new(Bar)
5252
result = DoFoo(in)
5353
return
5454
}`
5555

56-
const expectedFuncNotGeneratedAndNotNullable = `func (s *FooServer) DoFoo(ctx context.Context, in *Foo) (result *Bar, err error) {
56+
const expectedFuncNotGeneratedCtx = `func (s *FooServer) DoFooCtx(ctx xcontext.Context, in *Foo) (result *Bar, err error) {
57+
result = new(Bar)
58+
result = DoFooCtx(ctx, in)
59+
return
60+
}`
61+
62+
const expectedFuncNotGeneratedAndNotNullable = `func (s *FooServer) DoFoo(ctx xcontext.Context, in *Foo) (result *Bar, err error) {
5763
result = new(Bar)
5864
aux := DoFoo(in)
5965
result = &aux
6066
return
6167
}`
6268

63-
const expectedFuncNotGeneratedAndNotNullableIn = `func (s *FooServer) DoFoo(ctx context.Context, in *Foo) (result *Bar, err error) {
69+
const expectedFuncNotGeneratedAndNotNullableIn = `func (s *FooServer) DoFoo(ctx xcontext.Context, in *Foo) (result *Bar, err error) {
6470
result = new(Bar)
6571
result = DoFoo(*in)
6672
return
6773
}`
6874

69-
const expectedFuncGenerated = `func (s *FooServer) DoFoo(ctx context.Context, in *FooRequest) (result *FooResponse, err error) {
75+
const expectedFuncGenerated = `func (s *FooServer) DoFoo(ctx xcontext.Context, in *FooRequest) (result *FooResponse, err error) {
7076
result = new(FooResponse)
7177
result.Result1, result.Result2, result.Result3 = DoFoo(in.Arg1, in.Arg2, in.Arg3)
7278
return
7379
}`
7480

75-
const expectedFuncGeneratedVariadic = `func (s *FooServer) DoFoo(ctx context.Context, in *FooRequest) (result *FooResponse, err error) {
81+
const expectedFuncGeneratedVariadic = `func (s *FooServer) DoFoo(ctx xcontext.Context, in *FooRequest) (result *FooResponse, err error) {
7682
result = new(FooResponse)
7783
result.Result1, result.Result2, result.Result3 = DoFoo(in.Arg1, in.Arg2, in.Arg3...)
7884
return
7985
}`
8086

81-
const expectedFuncGeneratedWithError = `func (s *FooServer) DoFoo(ctx context.Context, in *FooRequest) (result *FooResponse, err error) {
87+
const expectedFuncGeneratedWithError = `func (s *FooServer) DoFoo(ctx xcontext.Context, in *FooRequest) (result *FooResponse, err error) {
8288
result = new(FooResponse)
8389
result.Result1, result.Result2, result.Result3, err = DoFoo(in.Arg1, in.Arg2, in.Arg3)
8490
return
8591
}`
8692

87-
const expectedMethod = `func (s *FooServer) Fooer_DoFoo(ctx context.Context, in *FooRequest) (result *FooResponse, err error) {
93+
const expectedMethod = `func (s *FooServer) Fooer_DoFoo(ctx xcontext.Context, in *FooRequest) (result *FooResponse, err error) {
8894
result = new(FooResponse)
8995
result.Result1, result.Result2, result.Result3, err = s.Fooer.DoFoo(in.Arg1, in.Arg2, in.Arg3)
9096
return
9197
}`
9298

93-
const expectedMethodExternalInput = `func (s *FooServer) T_Foo(ctx context.Context, in *ast.BlockStmt) (result *T_FooResponse, err error) {
99+
const expectedMethodExternalInput = `func (s *FooServer) T_Foo(ctx xcontext.Context, in *ast.BlockStmt) (result *T_FooResponse, err error) {
94100
result = new(T_FooResponse)
95101
_ = s.T.Foo(in)
96102
return
97103
}`
98104

99-
const expectedFuncEmptyInAndOut = `func (s *FooServer) Empty(ctx context.Context, in *Empty) (result *Empty, err error) {
105+
const expectedFuncEmptyInAndOut = `func (s *FooServer) Empty(ctx xcontext.Context, in *Empty) (result *Empty, err error) {
100106
Empty()
101107
return
102108
}`
103109

104-
const expectedFuncEmptyInAndOutWithError = `func (s *FooServer) Empty(ctx context.Context, in *Empty) (result *Empty, err error) {
110+
const expectedFuncEmptyInAndOutWithError = `func (s *FooServer) Empty(ctx xcontext.Context, in *Empty) (result *Empty, err error) {
105111
err = Empty()
106112
return
107113
}`
@@ -122,6 +128,17 @@ func (s *RPCSuite) TestDeclMethod() {
122128
},
123129
expectedFuncNotGenerated,
124130
},
131+
{
132+
"func not generated with ctx",
133+
&protobuf.RPC{
134+
Name: "DoFooCtx",
135+
Method: "DoFooCtx",
136+
HasCtx: true,
137+
Input: nullable(protobuf.NewNamed("", "Foo")),
138+
Output: nullable(protobuf.NewNamed("", "Bar")),
139+
},
140+
expectedFuncNotGeneratedCtx,
141+
},
125142
{
126143
"func output not generated and not nullable",
127144
&protobuf.RPC{
@@ -295,7 +312,7 @@ func (s *RPCSuite) TestDeclMethod() {
295312
const expectedGeneratedFile = `package subpkg
296313
297314
import (
298-
"golang.org/x/net/context"
315+
xcontext "golang.org/x/net/context"
299316
)
300317
301318
type subpkgServiceServer struct {
@@ -304,22 +321,22 @@ type subpkgServiceServer struct {
304321
func NewSubpkgServiceServer() *subpkgServiceServer {
305322
return &subpkgServiceServer{}
306323
}
307-
func (s *subpkgServiceServer) Generated(ctx context.Context, in *GeneratedRequest) (result *GeneratedResponse, err error) {
324+
func (s *subpkgServiceServer) Generated(ctx xcontext.Context, in *GeneratedRequest) (result *GeneratedResponse, err error) {
308325
result = new(GeneratedResponse)
309326
result.Result1, err = Generated(in.Arg1)
310327
return
311328
}
312-
func (s *subpkgServiceServer) MyContainer_Name(ctx context.Context, in *MyContainer_NameRequest) (result *MyContainer_NameResponse, err error) {
329+
func (s *subpkgServiceServer) MyContainer_Name(ctx xcontext.Context, in *MyContainer_NameRequest) (result *MyContainer_NameResponse, err error) {
313330
result = new(MyContainer_NameResponse)
314331
result.Result1 = s.MyContainer.Name()
315332
return
316333
}
317-
func (s *subpkgServiceServer) Point_GeneratedMethod(ctx context.Context, in *Point_GeneratedMethodRequest) (result *Point, err error) {
334+
func (s *subpkgServiceServer) Point_GeneratedMethod(ctx xcontext.Context, in *Point_GeneratedMethodRequest) (result *Point, err error) {
318335
result = new(Point)
319336
result = s.Point.GeneratedMethod(in.Arg1)
320337
return
321338
}
322-
func (s *subpkgServiceServer) Point_GeneratedMethodOnPointer(ctx context.Context, in *Point_GeneratedMethodOnPointerRequest) (result *Point, err error) {
339+
func (s *subpkgServiceServer) Point_GeneratedMethodOnPointer(ctx xcontext.Context, in *Point_GeneratedMethodOnPointerRequest) (result *Point, err error) {
323340
result = new(Point)
324341
result = s.Point.GeneratedMethodOnPointer(in.Arg1)
325342
return
@@ -362,6 +379,7 @@ func TestConstructorName(t *testing.T) {
362379
const testPkg = `package fake
363380
364381
import "go/ast"
382+
import "context"
365383
366384
type Foo struct{}
367385
type Bar struct {}
@@ -370,6 +388,10 @@ func DoFoo(in *Foo) *Bar {
370388
return nil
371389
}
372390
391+
func DoFooCtx(ctx context.Context, in *Foo) *Bar {
392+
return nil
393+
}
394+
373395
func MoreFoo(a int) *ast.BlockStmt {
374396
return nil
375397
}

0 commit comments

Comments
 (0)