Skip to content

Commit

Permalink
optimize(exit): server graceful shutdown logic to avoid EOF when idle…
Browse files Browse the repository at this point in the history
… connections receive new requests after being closed (#1681)
  • Loading branch information
jayantxie authored Jan 22, 2025
1 parent 1de9e03 commit 4f3c3fc
Show file tree
Hide file tree
Showing 6 changed files with 117 additions and 78 deletions.
38 changes: 27 additions & 11 deletions pkg/remote/trans/default_server_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"fmt"
"net"
"runtime/debug"
"sync/atomic"

"github.com/cloudwego/kitex/pkg/endpoint"
"github.com/cloudwego/kitex/pkg/kerrors"
Expand Down Expand Up @@ -49,13 +50,14 @@ func NewDefaultSvrTransHandler(opt *remote.ServerOption, ext Extension) (remote.
}

type svrTransHandler struct {
opt *remote.ServerOption
svcSearcher remote.ServiceSearcher
targetSvcInfo *serviceinfo.ServiceInfo
inkHdlFunc endpoint.Endpoint
codec remote.Codec
transPipe *remote.TransPipeline
ext Extension
opt *remote.ServerOption
svcSearcher remote.ServiceSearcher
targetSvcInfo *serviceinfo.ServiceInfo
inkHdlFunc endpoint.Endpoint
codec remote.Codec
transPipe *remote.TransPipeline
ext Extension
inGracefulShutdown uint32
}

// Write implements the remote.ServerTransHandler interface.
Expand Down Expand Up @@ -123,13 +125,22 @@ func (t *svrTransHandler) Read(ctx context.Context, conn net.Conn, recvMsg remot
}

func (t *svrTransHandler) newCtxWithRPCInfo(ctx context.Context, conn net.Conn) (context.Context, rpcinfo.RPCInfo) {
var ri rpcinfo.RPCInfo
if rpcinfo.PoolEnabled() { // reuse per-connection rpcinfo
return ctx, rpcinfo.GetRPCInfo(ctx)
ri = rpcinfo.GetRPCInfo(ctx)
// delayed reinitialize for faster response
} else {
// new rpcinfo if reuse is disabled
ri = t.opt.InitOrResetRPCInfoFunc(nil, conn.RemoteAddr())
ctx = rpcinfo.NewCtxWithRPCInfo(ctx, ri)
}
// new rpcinfo if reuse is disabled
ri := t.opt.InitOrResetRPCInfoFunc(nil, conn.RemoteAddr())
return rpcinfo.NewCtxWithRPCInfo(ctx, ri), ri
if atomic.LoadUint32(&t.inGracefulShutdown) == 1 {
// If server is in graceful shutdown status, mark connection reset flag to all responses to let client close the connections.
if ei := rpcinfo.AsTaggable(ri.To()); ei != nil {
ei.SetTag(rpcinfo.ConnResetTag, "1")
}
}
return ctx, ri
}

// OnRead implements the remote.ServerTransHandler interface.
Expand Down Expand Up @@ -348,6 +359,11 @@ func (t *svrTransHandler) finishProfiler(ctx context.Context) {
t.opt.Profiler.Untag(ctx)
}

func (t *svrTransHandler) GracefulShutdown(ctx context.Context) error {
atomic.StoreUint32(&t.inGracefulShutdown, 1)
return nil
}

func getRemoteInfo(ri rpcinfo.RPCInfo, conn net.Conn) (string, net.Addr) {
rAddr := conn.RemoteAddr()
if ri == nil {
Expand Down
9 changes: 9 additions & 0 deletions pkg/remote/trans/netpoll/trans_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import (
"runtime/debug"
"sync"
"syscall"
"time"

"github.com/cloudwego/netpoll"

Expand Down Expand Up @@ -119,6 +120,14 @@ func (ts *transServer) Shutdown() (err error) {
if err != nil {
klog.Warnf("KITEX: server graceful shutdown error: %v", err)
}
// 3. wait some time to receive requests before closing idle conns
/*
When the netpoll eventloop shutdown, all idle connections will be closed.
At this time, these connections may just receive requests, and then the peer side will report an EOF error.
To reduce such cases, wait for some time to try to receive these requests as much as possible,
so that the closing of connections can be controlled by the upper-layer protocol and the EOF problem can be reduced.
*/
time.Sleep(100 * time.Millisecond)
}
}
if ts.evl != nil {
Expand Down
2 changes: 1 addition & 1 deletion pkg/retry/failure_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ func TestFixedBackOff_Wait(t *testing.T) {
bk.Wait(1)
waitTime := time.Since(startTime)
test.Assert(t, time.Millisecond*fix <= waitTime)
test.Assert(t, waitTime < time.Millisecond*(fix+5))
test.Assert(t, waitTime < time.Millisecond*(fix*2))
}

func TestFixedBackOff_String(t *testing.T) {
Expand Down
8 changes: 8 additions & 0 deletions pkg/transmeta/ttheader.go
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,11 @@ func (ch *clientTTHeaderHandler) ReadMeta(ctx context.Context, msg remote.Messag
if setter, ok := ri.Invocation().(rpcinfo.InvocationSetter); ok && bizErr != nil {
setter.SetBizStatusErr(bizErr)
}
if val, ok := strInfo[transmeta.HeaderConnectionReadyToReset]; ok {
if ei := rpcinfo.AsTaggable(ri.To()); ei != nil {
ei.SetTag(rpcinfo.ConnResetTag, val)
}
}
return ctx, nil
}

Expand Down Expand Up @@ -190,6 +195,9 @@ func (sh *serverTTHeaderHandler) WriteMeta(ctx context.Context, msg remote.Messa
strInfo[bizExtra], _ = utils.Map2JSONStr(bizErr.BizExtra())
}
}
if val, ok := ri.To().Tag(rpcinfo.ConnResetTag); ok {
strInfo[transmeta.HeaderConnectionReadyToReset] = val
}

return ctx, nil
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/transmeta/ttheader_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ func TestTTHeaderServerReadMetainfo(t *testing.T) {

func TestTTHeaderServerWriteMetainfo(t *testing.T) {
ctx := context.Background()
ri := rpcinfo.NewRPCInfo(nil, nil, rpcinfo.NewInvocation("", ""),
ri := rpcinfo.NewRPCInfo(nil, rpcinfo.NewEndpointInfo("", "mock", nil, nil), rpcinfo.NewInvocation("", ""),
rpcinfo.NewRPCConfig(), rpcinfo.NewRPCStats())
msg := remote.NewMessage(nil, mocks.ServiceInfo(), ri, remote.Call, remote.Client)

Expand Down
136 changes: 71 additions & 65 deletions server/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -979,80 +979,86 @@ func TestInvokeHandlerPanic(t *testing.T) {
}

func TestRegisterService(t *testing.T) {
svr := NewServer()
time.AfterFunc(time.Second, func() {
err := svr.Stop()
test.Assert(t, err == nil, err)
})
{
svr := NewServer()
time.AfterFunc(time.Second, func() {
err := svr.Stop()
test.Assert(t, err == nil, err)
})

svr.Run()
svr.Run()

test.PanicAt(t, func() {
_ = svr.RegisterService(mocks.ServiceInfo(), mocks.MyServiceHandler())
}, func(err interface{}) bool {
if errMsg, ok := err.(string); ok {
return strings.Contains(errMsg, "server is running")
}
return true
})
svr.Stop()
test.PanicAt(t, func() {
_ = svr.RegisterService(mocks.ServiceInfo(), mocks.MyServiceHandler())
}, func(err interface{}) bool {
if errMsg, ok := err.(string); ok {
return strings.Contains(errMsg, "server is running")
}
return true
})
svr.Stop()
}

svr = NewServer()
time.AfterFunc(time.Second, func() {
err := svr.Stop()
test.Assert(t, err == nil, err)
})
{
svr := NewServer()
time.AfterFunc(time.Second, func() {
err := svr.Stop()
test.Assert(t, err == nil, err)
})

test.PanicAt(t, func() {
_ = svr.RegisterService(nil, mocks.MyServiceHandler())
}, func(err interface{}) bool {
if errMsg, ok := err.(string); ok {
return strings.Contains(errMsg, "svcInfo is nil")
}
return true
})
test.PanicAt(t, func() {
_ = svr.RegisterService(nil, mocks.MyServiceHandler())
}, func(err interface{}) bool {
if errMsg, ok := err.(string); ok {
return strings.Contains(errMsg, "svcInfo is nil")
}
return true
})

test.PanicAt(t, func() {
_ = svr.RegisterService(mocks.ServiceInfo(), nil)
}, func(err interface{}) bool {
if errMsg, ok := err.(string); ok {
return strings.Contains(errMsg, "handler is nil")
}
return true
})
test.PanicAt(t, func() {
_ = svr.RegisterService(mocks.ServiceInfo(), nil)
}, func(err interface{}) bool {
if errMsg, ok := err.(string); ok {
return strings.Contains(errMsg, "handler is nil")
}
return true
})

test.PanicAt(t, func() {
_ = svr.RegisterService(mocks.ServiceInfo(), mocks.MyServiceHandler(), WithFallbackService())
_ = svr.RegisterService(mocks.ServiceInfo(), mocks.MyServiceHandler())
}, func(err interface{}) bool {
if errMsg, ok := err.(string); ok {
return strings.Contains(errMsg, "Service[MockService] is already defined")
}
return true
})
test.PanicAt(t, func() {
_ = svr.RegisterService(mocks.ServiceInfo(), mocks.MyServiceHandler(), WithFallbackService())
_ = svr.RegisterService(mocks.ServiceInfo(), mocks.MyServiceHandler())
}, func(err interface{}) bool {
if errMsg, ok := err.(string); ok {
return strings.Contains(errMsg, "Service[MockService] is already defined")
}
return true
})

test.PanicAt(t, func() {
_ = svr.RegisterService(mocks.Service2Info(), mocks.MyServiceHandler(), WithFallbackService())
}, func(err interface{}) bool {
if errMsg, ok := err.(string); ok {
return strings.Contains(errMsg, "multiple fallback services cannot be registered")
}
return true
})
svr.Stop()
test.PanicAt(t, func() {
_ = svr.RegisterService(mocks.Service2Info(), mocks.MyServiceHandler(), WithFallbackService())
}, func(err interface{}) bool {
if errMsg, ok := err.(string); ok {
return strings.Contains(errMsg, "multiple fallback services cannot be registered")
}
return true
})
svr.Stop()
}

svr = NewServer()
time.AfterFunc(time.Second, func() {
err := svr.Stop()
test.Assert(t, err == nil, err)
})
{
svr := NewServer()
time.AfterFunc(time.Second, func() {
err := svr.Stop()
test.Assert(t, err == nil, err)
})

_ = svr.RegisterService(mocks.ServiceInfo(), mocks.MyServiceHandler())
_ = svr.RegisterService(mocks.Service3Info(), mocks.MyServiceHandler())
err := svr.Run()
test.Assert(t, err != nil)
test.Assert(t, err.Error() == "method name [mock] is conflicted between services but no fallback service is specified")
svr.Stop()
_ = svr.RegisterService(mocks.ServiceInfo(), mocks.MyServiceHandler())
_ = svr.RegisterService(mocks.Service3Info(), mocks.MyServiceHandler())
err := svr.Run()
test.Assert(t, err != nil)
test.Assert(t, err.Error() == "method name [mock] is conflicted between services but no fallback service is specified")
svr.Stop()
}
}

func TestRegisterServiceWithMiddleware(t *testing.T) {
Expand Down

0 comments on commit 4f3c3fc

Please sign in to comment.