diff --git a/client/client.go b/client/client.go index ee34c7ebab..f7bf0c6a5b 100644 --- a/client/client.go +++ b/client/client.go @@ -361,7 +361,13 @@ func (kc *kClient) Call(ctx context.Context, method string, request, response in recycleRI = true } } else { - ri, recycleRI, err = kc.opt.RetryContainer.WithRetryIfNeeded(ctx, callOptRetry, kc.rpcCallWithRetry(ri, method, request, response), ri, request) + var lastRI rpcinfo.RPCInfo + lastRI, recycleRI, err = kc.opt.RetryContainer.WithRetryIfNeeded(ctx, callOptRetry, kc.rpcCallWithRetry(ri, method, request, response), ri, request) + if ri != lastRI { + // reset ri of ctx to lastRI + ctx = rpcinfo.NewCtxWithRPCInfo(ctx, lastRI) + } + ri = lastRI } // do fallback if with setup diff --git a/client/service_inline.go b/client/service_inline.go index a5967bcb60..1b557437ab 100644 --- a/client/service_inline.go +++ b/client/service_inline.go @@ -140,10 +140,16 @@ func (kc *serviceInlineClient) Call(ctx context.Context, method string, request, var reportErr error defer func() { if panicInfo := recover(); panicInfo != nil { - reportErr = rpcinfo.ClientPanicToErr(ctx, panicInfo, ri, false) + reportErr = rpcinfo.ClientPanicToErr(ctx, panicInfo, ri, true) } kc.opt.TracerCtl.DoFinish(ctx, ri, reportErr) - rpcinfo.PutRPCInfo(ri) + // If the user start a new goroutine and return before endpoint finished, it may cause panic. + // For example,, if the user writes a timeout Middleware and times out, rpcinfo will be recycled, + // but in fact, rpcinfo is still being used when it is executed inside + // So if endpoint returns err, client won't recycle rpcinfo. + if reportErr == nil { + rpcinfo.PutRPCInfo(ri) + } callOpts.Recycle() }() reportErr = kc.eps(ctx, request, response) diff --git a/go.mod b/go.mod index edf0406982..2d93b1005b 100644 --- a/go.mod +++ b/go.mod @@ -13,8 +13,8 @@ require ( github.com/cloudwego/fastpb v0.0.4 github.com/cloudwego/frugal v0.1.8 github.com/cloudwego/localsession v0.0.2 - github.com/cloudwego/netpoll v0.5.0 - github.com/cloudwego/thriftgo v0.3.0 + github.com/cloudwego/netpoll v0.5.1 + github.com/cloudwego/thriftgo v0.3.2-0.20230828085742-edaddf2c17af github.com/golang/mock v1.6.0 github.com/google/pprof v0.0.0-20220608213341-c488b8fa1db3 github.com/jhump/protoreflect v1.8.2 diff --git a/go.sum b/go.sum index bcd5830029..be95ef4dac 100644 --- a/go.sum +++ b/go.sum @@ -62,18 +62,14 @@ github.com/cloudwego/localsession v0.0.2/go.mod h1:kiJxmvAcy4PLgKtEnPS5AXed3xCiX github.com/cloudwego/netpoll v0.2.4/go.mod h1:1T2WVuQ+MQw6h6DpE45MohSvDTKdy2DlzCx2KsnPI4E= github.com/cloudwego/netpoll v0.3.1/go.mod h1:1T2WVuQ+MQw6h6DpE45MohSvDTKdy2DlzCx2KsnPI4E= github.com/cloudwego/netpoll v0.4.0/go.mod h1:xVefXptcyheopwNDZjDPcfU6kIjZXZ4nY550k1yH9eQ= -github.com/cloudwego/netpoll v0.4.2-0.20230913081710-1a27688e2033 h1:/VYzCYH+Brp8CW1u475U+mPS7lHv5ulKx0vFJbp3YZ0= -github.com/cloudwego/netpoll v0.4.2-0.20230913081710-1a27688e2033/go.mod h1:xVefXptcyheopwNDZjDPcfU6kIjZXZ4nY550k1yH9eQ= -github.com/cloudwego/netpoll v0.4.2-0.20230918061532-5719b5310f34 h1:AbZPQaXr7MzOiUf1OZauww5rjmBpeLlyhM+hD7UsCn8= -github.com/cloudwego/netpoll v0.4.2-0.20230918061532-5719b5310f34/go.mod h1:xVefXptcyheopwNDZjDPcfU6kIjZXZ4nY550k1yH9eQ= -github.com/cloudwego/netpoll v0.5.0 h1:oRrOp58cPCvK2QbMozZNDESvrxQaEHW2dCimmwH1lcU= -github.com/cloudwego/netpoll v0.5.0/go.mod h1:xVefXptcyheopwNDZjDPcfU6kIjZXZ4nY550k1yH9eQ= +github.com/cloudwego/netpoll v0.5.1 h1:zDUF7xF0C97I10fGlQFJ4jg65khZZMUvSu/TWX44Ohc= +github.com/cloudwego/netpoll v0.5.1/go.mod h1:xVefXptcyheopwNDZjDPcfU6kIjZXZ4nY550k1yH9eQ= github.com/cloudwego/thriftgo v0.1.2/go.mod h1:LzeafuLSiHA9JTiWC8TIMIq64iadeObgRUhmVG1OC/w= github.com/cloudwego/thriftgo v0.2.4/go.mod h1:8i9AF5uDdWHGqzUhXDlubCjx4MEfKvWXGQlMWyR0tM4= github.com/cloudwego/thriftgo v0.2.7/go.mod h1:8i9AF5uDdWHGqzUhXDlubCjx4MEfKvWXGQlMWyR0tM4= github.com/cloudwego/thriftgo v0.2.11/go.mod h1:dAyXHEmKXo0LfMCrblVEY3mUZsdeuA5+i0vF5f09j7E= -github.com/cloudwego/thriftgo v0.3.0 h1:BBb9hVcqmu9p4iKUP/PSIaDB21Vfutgd7k2zgK37Q9Q= -github.com/cloudwego/thriftgo v0.3.0/go.mod h1:AvH0iEjvKHu3cdxG7JvhSAaffkS4h2f4/ZxpJbm48W4= +github.com/cloudwego/thriftgo v0.3.2-0.20230828085742-edaddf2c17af h1:xsNmlAdSnh6zuovEON4Ab0iT+fTfQUWqZ50tk+6OGW8= +github.com/cloudwego/thriftgo v0.3.2-0.20230828085742-edaddf2c17af/go.mod h1:AvH0iEjvKHu3cdxG7JvhSAaffkS4h2f4/ZxpJbm48W4= github.com/cncf/udpa/go v0.0.0-20201120205902-5459f2c99403/go.mod h1:WmhPx2Nbnhtbo57+VJT5O0JRkEi1Wbu0z5j0R8u5Hbk= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= diff --git a/pkg/remote/compression.go b/pkg/remote/compression.go index 58499a21f3..78d71f32eb 100644 --- a/pkg/remote/compression.go +++ b/pkg/remote/compression.go @@ -31,26 +31,30 @@ const ( ) func SetRecvCompressor(ri rpcinfo.RPCInfo, compressorName string) { - if ri == nil { + if ri == nil || compressorName == "" { return } - rpcinfo.AsMutableEndpointInfo(ri.From()).SetTag("recv-compressor", compressorName) + if v, ok := ri.Invocation().(rpcinfo.InvocationSetter); ok { + v.SetExtra("recv-compressor", compressorName) + } } func SetSendCompressor(ri rpcinfo.RPCInfo, compressorName string) { - if ri == nil { + if ri == nil || compressorName == "" { return } - rpcinfo.AsMutableEndpointInfo(ri.From()).SetTag("send-compressor", compressorName) + if v, ok := ri.Invocation().(rpcinfo.InvocationSetter); ok { + v.SetExtra("send-compressor", compressorName) + } } func GetSendCompressor(ri rpcinfo.RPCInfo) string { if ri == nil { return "" } - v, exist := ri.From().Tag("send-compressor") - if exist { - return v + v := ri.Invocation().Extra("send-compressor") + if name, ok := v.(string); ok { + return name } return "" } @@ -59,9 +63,9 @@ func GetRecvCompressor(ri rpcinfo.RPCInfo) string { if ri == nil { return "" } - v, exist := ri.From().Tag("recv-compressor") - if exist { - return v + v := ri.Invocation().Extra("recv-compressor") + if name, ok := v.(string); ok { + return name } return "" } diff --git a/pkg/rpcinfo/interface.go b/pkg/rpcinfo/interface.go index c78d4a9b39..6c882ef208 100644 --- a/pkg/rpcinfo/interface.go +++ b/pkg/rpcinfo/interface.go @@ -83,6 +83,7 @@ type Invocation interface { MethodName() string SeqID() int32 BizStatusErr() kerrors.BizStatusErrorIface + Extra(key string) interface{} } // RPCInfo is the core abstraction of information about an RPC in Kitex. diff --git a/pkg/rpcinfo/invocation.go b/pkg/rpcinfo/invocation.go index eeb60cd9bb..cf6dc0165f 100644 --- a/pkg/rpcinfo/invocation.go +++ b/pkg/rpcinfo/invocation.go @@ -41,15 +41,16 @@ type InvocationSetter interface { SetMethodName(name string) SetSeqID(seqID int32) SetBizStatusErr(err kerrors.BizStatusErrorIface) + SetExtra(key string, value interface{}) Reset() } - type invocation struct { packageName string serviceName string methodName string seqID int32 bizErr kerrors.BizStatusErrorIface + extra map[string]interface{} } // NewInvocation creates a new Invocation with the given service, method and optional package. @@ -130,6 +131,20 @@ func (i *invocation) SetBizStatusErr(err kerrors.BizStatusErrorIface) { i.bizErr = err } +func (i *invocation) SetExtra(key string, value interface{}) { + if i.extra == nil { + i.extra = map[string]interface{}{} + } + i.extra[key] = value +} + +func (i *invocation) Extra(key string) interface{} { + if i.extra == nil { + return nil + } + return i.extra[key] +} + // Reset implements the InvocationSetter interface. func (i *invocation) Reset() { i.zero() @@ -147,4 +162,7 @@ func (i *invocation) zero() { i.serviceName = "" i.methodName = "" i.bizErr = nil + for key := range i.extra { + delete(i.extra, key) + } } diff --git a/pkg/serviceinfo/serviceinfo.go b/pkg/serviceinfo/serviceinfo.go index f89fd6a374..4876fbe3c0 100644 --- a/pkg/serviceinfo/serviceinfo.go +++ b/pkg/serviceinfo/serviceinfo.go @@ -27,6 +27,7 @@ type PayloadCodec int const ( Thrift PayloadCodec = iota Protobuf + Hessian2 ) const ( diff --git a/tool/cmd/kitex/args/args.go b/tool/cmd/kitex/args/args.go index 583c4e6e02..34364b55bd 100644 --- a/tool/cmd/kitex/args/args.go +++ b/tool/cmd/kitex/args/args.go @@ -114,6 +114,8 @@ func (a *Arguments) buildFlags(version string) *flag.FlagSet { "Specify a code gen path.") f.BoolVar(&a.DeepCopyAPI, "deep-copy-api", false, "Generate codes with injecting deep copy method.") + f.StringVar(&a.Protocol, "protocol", "", + "Specify a protocol for codec") a.RecordCmd = os.Args a.Version = version a.ThriftOptions = append(a.ThriftOptions, @@ -298,12 +300,16 @@ func (a *Arguments) BuildCmd(out io.Writer) *exec.Cmd { Stdout: io.MultiWriter(out, os.Stdout), Stderr: io.MultiWriter(out, os.Stderr), } + if a.IDLType == "thrift" { os.Setenv(EnvPluginMode, thriftgo.PluginName) cmd.Args = append(cmd.Args, "thriftgo") for _, inc := range a.Includes { cmd.Args = append(cmd.Args, "-i", inc) } + if strings.EqualFold(a.Protocol, "hessian2") { + a.ThriftOptions = append(a.ThriftOptions, "template=slim") + } a.ThriftOptions = append(a.ThriftOptions, "package_prefix="+a.PackagePrefix) gas := "go:" + strings.Join(a.ThriftOptions, ",") if a.Verbose { diff --git a/tool/internal_pkg/generator/generator.go b/tool/internal_pkg/generator/generator.go index 56bac9a5ea..00e8feda33 100644 --- a/tool/internal_pkg/generator/generator.go +++ b/tool/internal_pkg/generator/generator.go @@ -27,6 +27,7 @@ import ( "github.com/cloudwego/kitex/tool/internal_pkg/log" "github.com/cloudwego/kitex/tool/internal_pkg/tpl" "github.com/cloudwego/kitex/tool/internal_pkg/util" + "github.com/cloudwego/kitex/transport" ) // Constants . @@ -132,6 +133,7 @@ type Config struct { GenPath string DeepCopyAPI bool + Protocol string } // Pack packs the Config into a slice of "key=val" strings. @@ -456,6 +458,9 @@ func (g *generator) updatePackageInfo(pkg *PackageInfo) { pkg.ExternalKitexGen = g.Use pkg.FrugalPretouch = g.FrugalPretouch pkg.Module = g.ModuleName + if strings.EqualFold(g.Protocol, transport.HESSIAN2.String()) { + pkg.Protocol = transport.HESSIAN2 + } if pkg.Dependencies == nil { pkg.Dependencies = make(map[string]string) } @@ -481,10 +486,10 @@ func (g *generator) setImports(name string, pkg *PackageInfo) { } fallthrough case HandlerFileName: - if len(pkg.AllMethods()) > 0 { - pkg.AddImports("context") - } for _, m := range pkg.ServiceInfo.AllMethods() { + if !m.ServerStreaming && !m.ClientStreaming { + pkg.AddImports("context") + } for _, a := range m.Args { for _, dep := range a.Deps { pkg.AddImport(dep.PkgRefName, dep.ImportPath) diff --git a/tool/internal_pkg/generator/generator_test.go b/tool/internal_pkg/generator/generator_test.go index 5d9bfb13b4..f935ab9809 100644 --- a/tool/internal_pkg/generator/generator_test.go +++ b/tool/internal_pkg/generator/generator_test.go @@ -53,6 +53,7 @@ func TestConfig_Pack(t *testing.T) { RecordCmd string ThriftPluginTimeLimit time.Duration TemplateDir string + Protocol string } tests := []struct { name string @@ -63,7 +64,7 @@ func TestConfig_Pack(t *testing.T) { { name: "some", fields: fields{Features: []feature{feature(999)}, ThriftPluginTimeLimit: 30 * time.Second}, - wantRes: []string{"Verbose=false", "GenerateMain=false", "GenerateInvoker=false", "Version=", "NoFastAPI=false", "ModuleName=", "ServiceName=", "Use=", "IDLType=", "Includes=", "ThriftOptions=", "ProtobufOptions=", "IDL=", "OutputPath=", "PackagePrefix=", "CombineService=false", "CopyIDL=false", "ProtobufPlugins=", "Features=999", "FrugalPretouch=false", "ThriftPluginTimeLimit=30s", "ExtensionFile=", "Record=false", "RecordCmd=", "TemplateDir=", "GenPath=", "DeepCopyAPI=false"}, + wantRes: []string{"Verbose=false", "GenerateMain=false", "GenerateInvoker=false", "Version=", "NoFastAPI=false", "ModuleName=", "ServiceName=", "Use=", "IDLType=", "Includes=", "ThriftOptions=", "ProtobufOptions=", "IDL=", "OutputPath=", "PackagePrefix=", "CombineService=false", "CopyIDL=false", "ProtobufPlugins=", "Features=999", "FrugalPretouch=false", "ThriftPluginTimeLimit=30s", "ExtensionFile=", "Record=false", "RecordCmd=", "TemplateDir=", "GenPath=", "DeepCopyAPI=false", "Protocol="}, }, } for _, tt := range tests { @@ -91,6 +92,7 @@ func TestConfig_Pack(t *testing.T) { FrugalPretouch: tt.fields.FrugalPretouch, ThriftPluginTimeLimit: tt.fields.ThriftPluginTimeLimit, TemplateDir: tt.fields.TemplateDir, + Protocol: tt.fields.Protocol, } if gotRes := c.Pack(); !reflect.DeepEqual(gotRes, tt.wantRes) { t.Errorf("Config.Pack() = \n%v\nwant\n%v", gotRes, tt.wantRes) @@ -121,6 +123,7 @@ func TestConfig_Unpack(t *testing.T) { Features []feature FrugalPretouch bool TemplateDir string + Protocol string } type args struct { args []string @@ -162,6 +165,7 @@ func TestConfig_Unpack(t *testing.T) { Features: tt.fields.Features, FrugalPretouch: tt.fields.FrugalPretouch, TemplateDir: tt.fields.TemplateDir, + Protocol: tt.fields.Protocol, } if err := c.Unpack(tt.args.args); (err != nil) != tt.wantErr { t.Errorf("Config.Unpack() error = %v, wantErr %v", err, tt.wantErr) diff --git a/tool/internal_pkg/generator/type.go b/tool/internal_pkg/generator/type.go index c890c71139..e5bd09bb05 100644 --- a/tool/internal_pkg/generator/type.go +++ b/tool/internal_pkg/generator/type.go @@ -21,6 +21,7 @@ import ( "text/template" "github.com/cloudwego/kitex/tool/internal_pkg/util" + "github.com/cloudwego/kitex/transport" ) // File . @@ -45,6 +46,7 @@ type PackageInfo struct { Features []feature FrugalPretouch bool Module string + Protocol transport.Protocol } // AddImport . @@ -97,6 +99,7 @@ type ServiceInfo struct { CombineServices []*ServiceInfo HasStreaming bool ServiceFilePath string + Protocol string } // AllMethods returns all methods that the service have. diff --git a/tool/internal_pkg/pluginmode/thriftgo/convertor.go b/tool/internal_pkg/pluginmode/thriftgo/convertor.go index 50f80add62..dc082b52af 100644 --- a/tool/internal_pkg/pluginmode/thriftgo/convertor.go +++ b/tool/internal_pkg/pluginmode/thriftgo/convertor.go @@ -25,6 +25,7 @@ import ( "strings" "github.com/cloudwego/kitex/tool/internal_pkg/util" + "github.com/cloudwego/kitex/transport" "github.com/cloudwego/thriftgo/generator/backend" "github.com/cloudwego/thriftgo/generator/golang" @@ -188,10 +189,9 @@ func (c *converter) copyTypeWithRef(t *parser.Type, ref string) (res *parser.Typ default: if strings.Contains(t.Name, ".") { return &parser.Type{ - Name: t.Name, - KeyType: t.KeyType, - ValueType: t.ValueType, - Annotations: t.Annotations, + Name: t.Name, + KeyType: t.KeyType, + ValueType: t.ValueType, } } return &parser.Type{ @@ -353,8 +353,10 @@ func (c *converter) convertTypes(req *plugin.Request) error { } // combine service if ast == req.AST && c.Config.CombineService && len(ast.Services) > 0 { - var svcs []*generator.ServiceInfo - var methods []*generator.MethodInfo + var ( + svcs []*generator.ServiceInfo + methods []*generator.MethodInfo + ) for _, s := range all[ast.Filename] { svcs = append(svcs, s) methods = append(methods, s.AllMethods()...) @@ -376,6 +378,11 @@ func (c *converter) convertTypes(req *plugin.Request) error { Methods: methods, ServiceFilePath: ast.Filename, } + + if c.IsHessian2() { + si.Protocol = transport.HESSIAN2.String() + } + si.ServiceTypeName = func() string { return si.ServiceName } all[ast.Filename] = append(all[ast.Filename], si) c.svc2ast[si] = ast @@ -404,6 +411,10 @@ func (c *converter) makeService(pkg generator.PkgInfo, svc *golang.Service) (*ge } si.Methods = append(si.Methods, mi) } + + if c.IsHessian2() { + si.Protocol = transport.HESSIAN2.String() + } return si, nil } @@ -484,3 +495,7 @@ func (c *converter) getCombineServiceName(name string, svcs []*generator.Service } return name } + +func (c *converter) IsHessian2() bool { + return strings.EqualFold(c.Config.Protocol, transport.HESSIAN2.String()) +} diff --git a/tool/internal_pkg/pluginmode/thriftgo/file_tpl.go b/tool/internal_pkg/pluginmode/thriftgo/file_tpl.go index 7810566fd9..cbd7985cd8 100644 --- a/tool/internal_pkg/pluginmode/thriftgo/file_tpl.go +++ b/tool/internal_pkg/pluginmode/thriftgo/file_tpl.go @@ -27,6 +27,7 @@ import ( "reflect" "github.com/apache/thrift/lib/go/thrift" + {{if GenerateDeepCopyAPIs -}} kutils "github.com/cloudwego/kitex/pkg/utils" {{- end}} @@ -92,6 +93,7 @@ const patchArgsAndResult = ` {{$argType := .ArgType}} {{$resType := .ResType}} + {{- if GenerateArgsResultTypes}} {{template "StructLike" $argType}} {{- end}}{{/* if GenerateArgsResultTypes */}} @@ -106,6 +108,7 @@ func (p *{{$argType.GoName}}) GetFirstArgument() interface{} { func (p *{{$resType.GoName}}) GetResult() interface{} { return {{if .Void}}nil{{else}}p.Success{{end}} } + {{- end}}{{/* if not .Oneway */}} {{- end}}{{/* range Functions */}} {{- end}}{{/* range .Scope.Service */}} @@ -116,4 +119,5 @@ var basicTemplates = []string{ patchArgsAndResult, file, body, + registerHessian, } diff --git a/tool/internal_pkg/pluginmode/thriftgo/patcher.go b/tool/internal_pkg/pluginmode/thriftgo/patcher.go index 38ab6fc58e..221ec4e508 100644 --- a/tool/internal_pkg/pluginmode/thriftgo/patcher.go +++ b/tool/internal_pkg/pluginmode/thriftgo/patcher.go @@ -60,6 +60,7 @@ type patcher struct { record bool recordCmd []string deepCopyAPI bool + protocol string fileTpl *template.Template } @@ -106,6 +107,9 @@ func (p *patcher) buildTemplates() (err error) { // p.XXX return strings.ToLower(s[2:3]) + s[3:] } + m["IsHessian"] = func() bool { + return p.IsHessian2() + } tpl := template.New("kitex").Funcs(m) allTemplates := basicTemplates @@ -123,6 +127,8 @@ func (p *patcher) buildTemplates() (err error) { fieldDeepCopySet, fieldDeepCopyBaseType, structLikeCodec, + structLikeProtocol, + javaClassName, processor) } else { allTemplates = append(allTemplates, structLikeCodec, @@ -134,6 +140,8 @@ func (p *patcher) buildTemplates() (err error) { structLikeLength, structLikeFastWriteField, structLikeFieldLength, + structLikeProtocol, + javaClassName, fieldFastRead, fieldFastReadStructLike, fieldFastReadBaseType, @@ -184,6 +192,14 @@ func (p *patcher) buildTemplates() (err error) { if err != nil { return fmt.Errorf("failed to parse extra templates: %w: %q", err, ext) } + + if p.IsHessian2() { + tpl, err = tpl.Parse(registerHessian) + if err != nil { + return fmt.Errorf("failed to parse hessian2 templates: %w: %q", err, registerHessian) + } + } + p.fileTpl = tpl return nil } @@ -195,7 +211,6 @@ func (p *patcher) patch(req *plugin.Request) (patches []*plugin.Generated, err e protection := make(map[string]*plugin.Generated) for ast := range req.AST.DepthFirstSearch() { - // scope, err := golang.BuildScope(p.utils, ast) scope, _, err := golang.BuildRefScope(p.utils, ast) if err != nil { @@ -228,6 +243,17 @@ func (p *patcher) patch(req *plugin.Request) (patches []*plugin.Generated, err e continue } + if p.IsHessian2() { + register := util.JoinPath(path, fmt.Sprintf("hessian2-register-%s", base)) + patch, err := p.patchHessian(path, scope, pkgName, base) + if err != nil { + return nil, fmt.Errorf("patch hessian fail for %q: %w", ast.Filename, err) + } + + patches = append(patches, patch) + protection[register] = patch + } + data := &struct { Scope *golang.Scope PkgName string @@ -275,10 +301,37 @@ func (p *patcher) patch(req *plugin.Request) (patches []*plugin.Generated, err e Name: &bashPath, }) } + } return } +func (p *patcher) patchHessian(path string, scope *golang.Scope, pkgName, base string) (patch *plugin.Generated, err error) { + buf := strings.Builder{} + resigterIDLName := fmt.Sprintf("hessian2-register-%s", base) + register := util.JoinPath(path, resigterIDLName) + data := &struct { + Scope *golang.Scope + PkgName string + Imports map[string]string + GoName string + IDLName string + }{Scope: scope, PkgName: pkgName, IDLName: util.UpperFirst(strings.Replace(base, ".go", "", -1))} + data.Imports, err = scope.ResolveImports() + if err != nil { + return nil, err + } + + if err = p.fileTpl.ExecuteTemplate(&buf, "register", data); err != nil { + return nil, err + } + patch = &plugin.Generated{ + Content: buf.String(), + Name: ®ister, + } + return patch, nil +} + func getBashPath() string { if runtime.GOOS == "windows" { return "kitex-all.bat" @@ -370,6 +423,10 @@ func (p *patcher) isBinaryOrStringType(t *parser.Type) bool { return t.Category.IsBinary() || t.Category.IsString() } +func (p *patcher) IsHessian2() bool { + return strings.EqualFold(p.protocol, "hessian2") +} + var typeIDToGoType = map[string]string{ "Bool": "bool", "Byte": "int8", diff --git a/tool/internal_pkg/pluginmode/thriftgo/plugin.go b/tool/internal_pkg/pluginmode/thriftgo/plugin.go index 77822bedcb..1d0cefb5f1 100644 --- a/tool/internal_pkg/pluginmode/thriftgo/plugin.go +++ b/tool/internal_pkg/pluginmode/thriftgo/plugin.go @@ -114,7 +114,6 @@ func Run() int { } return conv.fail(err) } - p := &patcher{ noFastAPI: conv.Config.NoFastAPI, utils: conv.Utils, @@ -124,6 +123,7 @@ func Run() int { record: conv.Config.Record, recordCmd: conv.Config.RecordCmd, deepCopyAPI: conv.Config.DeepCopyAPI, + protocol: conv.Config.Protocol, } patches, err := p.patch(req) if err != nil { diff --git a/tool/internal_pkg/pluginmode/thriftgo/register_tpl.go b/tool/internal_pkg/pluginmode/thriftgo/register_tpl.go new file mode 100644 index 0000000000..493d88b6c3 --- /dev/null +++ b/tool/internal_pkg/pluginmode/thriftgo/register_tpl.go @@ -0,0 +1,182 @@ +// Copyright 2021 CloudWeGo Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package thriftgo + +const registerHessian = ` +{{- define "register"}} +package {{ .PkgName}} + +import ( + "fmt" + + "github.com/pkg/errors" + "github.com/kitex-contrib/codec-dubbo/pkg/hessian2" + codec "github.com/kitex-contrib/codec-dubbo/pkg/iface" +) + +var objects{{ .IDLName}} = []interface{}{ +{{- range .Scope.StructLikes}} + &{{ .GoName}}{}, +{{- end}} +} + +func init() { + hessian2.Register(objects{{ .IDLName}} ) +} + +{{range .Scope.Services }} +func Get{{.GoName}}IDLAnnotations() map[string][]string { + return map[string][]string { + {{- range .Annotations}} + "{{.Key}}": { {{range .Values}}"{{.}}", {{- end}}}, + {{- end}} + } +} +{{- end}} + +{{- range .Scope.StructLikes}} +{{template "StructLikeProtocol" .}} +{{- end}} + +{{- range .Scope.Services}} +{{- range .Functions}} + +{{$argType := .ArgType}} +{{$resType := .ResType}} + +func (p *{{$argType.GoName}}) Encode(e codec.Encoder) error { + var err error +{{- range $argType.Fields}} +{{- $FieldName := .GoName}} + err = e.Encode(p.{{$FieldName}}) + if err != nil { + return err + } +{{end}}{{/* range .Fields */}} + return nil +} + +func (p *{{$argType.GoName}}) Decode(d codec.Decoder) error { + var ( + err error + v interface{} + ) +{{- range $argType.Fields}} +{{- $Type := .Type }} +{{- $FieldName := .GoName}} +{{- $FieldTypeName := .GoTypeName}} + v, err = d.Decode() + if err != nil { + return err + } + err = hessian2.ReflectResponse(v, &p.{{$FieldName}}) + if err != nil { + return errors.Wrap(err, fmt.Sprintf("invalid data type: %T", v)) + } +{{end}}{{/* range .Fields */}} + return nil +} {{/* encode decode */}} + +func (p *{{$resType.GoName}}) Encode(e codec.Encoder) error { + var err error +{{- range $resType.Fields}} +{{- $FieldName := .GoName}} + err = e.Encode(p.{{$FieldName}}) + if err != nil { + return err + } +{{end}}{{/* range .Fields */}} + return nil +} + +func (p *{{$resType.GoName}}) Decode(d codec.Decoder) error { + var ( + err error + v interface{} + ) +{{- range $resType.Fields}} +{{- $Type := .Type }} +{{- $FieldName := .GoName}} +{{- $FieldTypeName := .GoTypeName}} + v, err = d.Decode() + if err != nil { + return err + } + err = hessian2.ReflectResponse(v, &p.{{$FieldName}}) + if err != nil { + return errors.Wrap(err, fmt.Sprintf("invalid data type: %T", v)) + } +{{end}}{{/* range .Fields */}} + return nil +} {{/* encode decode */}} + +{{- end}}{{/* range Functions */}} +{{- end}}{{/* range .Scope.Service */}} + +{{- end}}{{/* define RegisterHessian*/}} +` + +const structLikeProtocol = ` +{{define "StructLikeProtocol"}} +{{- $TypeName := .GoName}} +func (p *{{$TypeName}}) Encode(e codec.Encoder) error { + var err error +{{- range .Fields}} +{{- $FieldName := .GoName}} + err = e.Encode(p.{{$FieldName}}) + if err != nil { + return err + } +{{end}}{{/* range .Fields */}} + return nil +} + +func (p *{{$TypeName}}) Decode(d codec.Decoder) error { + var ( + err error + v interface{} + ) +{{- range .Fields}} +{{- $Type := .Type }} +{{- $FieldName := .GoName}} +{{- $FieldTypeName := .GoTypeName}} + v, err = d.Decode() + if err != nil { + return err + } + err = hessian2.ReflectResponse(v, &p.{{$FieldName}}) + if err != nil { + return errors.Wrap(err, fmt.Sprintf("invalid data type: %T", v)) + } +{{end}}{{/* range .Fields */}} + return nil +} {{/* encode decode */}} + +{{template "JavaClassName" .}} +{{- end}}{{/* define "StructLikeProtocol" */}} +` + +const javaClassName = ` +{{define "JavaClassName"}} +{{- $TypeName := .GoName}} +{{- $anno := .Annotations }} +{{- $value := $anno.ILocValueByKey "JavaClassName" 0}} +{{- if ne "" $value}} +func (p *{{$TypeName}}) JavaClassName() string { + return "{{$value}}" +} +{{- end}}{{/* end if */}} +{{- end}}{{/* end JavaClassName */}} +` diff --git a/tool/internal_pkg/tpl/service.go b/tool/internal_pkg/tpl/service.go index c4d5d54ac1..258f9914b6 100644 --- a/tool/internal_pkg/tpl/service.go +++ b/tool/internal_pkg/tpl/service.go @@ -71,7 +71,9 @@ func NewServiceInfo() *kitex.ServiceInfo { ServiceName: serviceName, HandlerType: handlerType, Methods: methods, + {{- if ne "Hessian2" .ServiceInfo.Protocol}} PayloadCodec: kitex.{{.Codec | UpperFirst}}, + {{- end}} KiteXGenVersion: "{{.Version}}", Extra: extra, } diff --git a/transport/keys.go b/transport/keys.go index c4c4b8afd5..f63080eaed 100644 --- a/transport/keys.go +++ b/transport/keys.go @@ -28,6 +28,7 @@ const ( Framed HTTP GRPC + HESSIAN2 TTHeaderFramed = TTHeader | Framed ) @@ -50,6 +51,8 @@ func (tp Protocol) String() string { return "TTHeaderFramed" case GRPC: return "GRPC" + case HESSIAN2: + return "Hessian2" } return Unknown } diff --git a/version.go b/version.go index ab7b95bd4c..90570b4b28 100644 --- a/version.go +++ b/version.go @@ -19,5 +19,5 @@ package kitex // Name and Version info of this framework, used for statistics and debug const ( Name = "Kitex" - Version = "v0.7.2" + Version = "v0.7.3" )