diff --git a/pkg/remote/trans/nphttp2/grpc/controlbuf.go b/pkg/remote/trans/nphttp2/grpc/controlbuf.go index 7dd4701124..a728cca685 100644 --- a/pkg/remote/trans/nphttp2/grpc/controlbuf.go +++ b/pkg/remote/trans/nphttp2/grpc/controlbuf.go @@ -451,19 +451,23 @@ func (c *controlBuffer) get(block bool) (interface{}, error) { select { case <-c.ch: case <-c.done: - c.finish() - return nil, ErrConnClosing + var err error + c.finish(errStatusConnClosing) + c.mu.Lock() + err = c.err + c.mu.Unlock() + return nil, err } } } -func (c *controlBuffer) finish() { +func (c *controlBuffer) finish(err error) { c.mu.Lock() if c.err != nil { c.mu.Unlock() return } - c.err = ErrConnClosing + c.err = err // There may be headers for streams in the control buffer. // These streams need to be cleaned out since the transport // is still not aware of these yet. @@ -473,7 +477,7 @@ func (c *controlBuffer) finish() { continue } if hdr.onOrphaned != nil { // It will be nil on the server-side. - hdr.onOrphaned(ErrConnClosing) + hdr.onOrphaned(err) } } c.mu.Unlock() diff --git a/pkg/remote/trans/nphttp2/grpc/controlbuf_test.go b/pkg/remote/trans/nphttp2/grpc/controlbuf_test.go index 643a4e3de5..d69e660f94 100644 --- a/pkg/remote/trans/nphttp2/grpc/controlbuf_test.go +++ b/pkg/remote/trans/nphttp2/grpc/controlbuf_test.go @@ -18,6 +18,7 @@ package grpc import ( "context" + "errors" "testing" "time" @@ -25,7 +26,7 @@ import ( ) func TestControlBuf(t *testing.T) { - ctx := context.Background() + ctx, cancel := context.WithCancel(context.Background()) cb := newControlBuffer(ctx.Done()) // test put() @@ -52,7 +53,8 @@ func TestControlBuf(t *testing.T) { test.Assert(t, !success, err) // test throttle() mock a lot of response frame so throttle() will block current goroutine - for i := 0; i < maxQueuedTransportResponseFrames+5; i++ { + exceedSize := 5 + for i := 0; i < maxQueuedTransportResponseFrames+exceedSize; i++ { err := cb.put(&ping{}) test.Assert(t, err == nil, err) } @@ -60,16 +62,44 @@ func TestControlBuf(t *testing.T) { // start a new goroutine to consume response frame go func() { time.Sleep(time.Millisecond * 100) - for { + for i := 0; i < exceedSize+1; i++ { it, err := cb.get(false) - if err != nil || it == nil { - break - } + test.Assert(t, err == nil, err) + test.Assert(t, it != nil) } }() cb.throttle() + // consumes all of the frames + for { + it, err := cb.get(false) + if err != nil || it == nil { + break + } + } + + finishErr := errors.New("finish") + go func() { + ticker := time.NewTicker(10 * time.Millisecond) + defer ticker.Stop() + for range ticker.C { + var block bool + cb.mu.Lock() + block = cb.consumerWaiting + cb.mu.Unlock() + if block { + cb.finish(finishErr) + cancel() + return + } + } + }() + item, err = cb.get(true) + test.Assert(t, err == finishErr, err) + test.Assert(t, item == nil, item) - // test finish() - cb.finish() + err = cb.put(testItem) + test.Assert(t, err == finishErr, err) + _, err = cb.get(false) + test.Assert(t, err == finishErr, err) } diff --git a/pkg/remote/trans/nphttp2/grpc/error_prompt.go b/pkg/remote/trans/nphttp2/grpc/error_prompt.go new file mode 100644 index 0000000000..77d3530701 --- /dev/null +++ b/pkg/remote/trans/nphttp2/grpc/error_prompt.go @@ -0,0 +1,52 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package grpc + +import ( + "fmt" +) + +const ( + remoteErrTpl = "[triggered by %s]" + remoteErrSuffix = "[triggered by remote service]" + handlerSideErrSuffix = "[triggered by handler side]" + sendRSTStreamFrameSuffix = "[send RST Stream Frame]" +) + +func remoteErrMsg(msg string) string { + return msg + remoteErrSuffix +} + +func remoteErrMsgf(msg string, a ...interface{}) string { + return fmt.Sprintf(msg+remoteErrSuffix, a...) +} + +func handlerSideErrMsg(msg string) string { + return msg + handlerSideErrSuffix +} + +func handlerSideErrMsgf(msg string, a ...interface{}) string { + return fmt.Sprintf(msg+handlerSideErrSuffix, a...) +} + +func sendRSTStreamFrameMsg(msg string) string { + return msg + sendRSTStreamFrameSuffix +} + +func sendRSTStreamFrameMsgf(msg string, a ...interface{}) string { + return fmt.Sprintf(msg+sendRSTStreamFrameSuffix, a...) +} diff --git a/pkg/remote/trans/nphttp2/grpc/error_prompt_test.go b/pkg/remote/trans/nphttp2/grpc/error_prompt_test.go new file mode 100644 index 0000000000..17799560bf --- /dev/null +++ b/pkg/remote/trans/nphttp2/grpc/error_prompt_test.go @@ -0,0 +1,32 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package grpc + +import ( + "testing" + + "github.com/cloudwego/kitex/internal/test" +) + +func TestErrMsg(t *testing.T) { + test.Assert(t, remoteErrMsg("test") == "test[triggered by remote service]") + test.Assert(t, remoteErrMsgf("test: %s", "val") == "test: val[triggered by remote service]") + test.Assert(t, handlerSideErrMsg("test") == "test[triggered by handler side]") + test.Assert(t, handlerSideErrMsgf("test: %s", "val") == "test: val[triggered by handler side]") + test.Assert(t, sendRSTStreamFrameMsg("test") == "test[send RST Stream Frame]") + test.Assert(t, sendRSTStreamFrameMsgf("test: %s", "val") == "test: val[send RST Stream Frame]") +} diff --git a/pkg/remote/trans/nphttp2/grpc/http2_client.go b/pkg/remote/trans/nphttp2/grpc/http2_client.go index 58b2eaba01..b825e9f23c 100644 --- a/pkg/remote/trans/nphttp2/grpc/http2_client.go +++ b/pkg/remote/trans/nphttp2/grpc/http2_client.go @@ -534,6 +534,7 @@ func (t *http2Client) CloseStream(s *Stream, err error) { if err != nil { rst = true rstCode = http2.ErrCodeCancel + klog.CtxInfof(s.ctx, sendRSTStreamFrameMsgf("KITEX: stream closed by user side ctx canceled, err: %v, code: %d", err, rstCode)) } t.closeStream(s, err, rst, rstCode, status.Convert(err), nil, false) } @@ -557,6 +558,15 @@ func (t *http2Client) closeStream(s *Stream, err error, rst bool, rstCode http2. // This will unblock reads eventually. s.write(recvMsg{err: err}) } + + // store closeStreamErr + if err == io.EOF { + err = st.Err() + } + if err != nil { + s.closeStreamErr.Store(err) + } + // If headerChan isn't closed, then close it. if atomic.CompareAndSwapUint32(&s.headerChanClosed, 0, 1) { s.noHeaders = true @@ -597,6 +607,9 @@ func (t *http2Client) closeStream(s *Stream, err error, rst bool, rstCode http2. // re-connected. This happens because t.onClose() begins reconnect logic at the // addrConn level and blocks until the addrConn is successfully connected. func (t *http2Client) Close(err error) error { + if rawErr, ok := err.(ConnectionError); ok { + err = status.Err(codes.Unavailable, rawErr.Desc) + } t.mu.Lock() // Make sure we only Close once. if t.state == closing { @@ -617,7 +630,7 @@ func (t *http2Client) Close(err error) error { t.kpDormancyCond.Signal() } t.mu.Unlock() - t.controlBuf.finish() + t.controlBuf.finish(err) t.cancel() cErr := t.conn.Close() @@ -656,10 +669,10 @@ func (t *http2Client) Write(s *Stream, hdr, data []byte, opts *Options) error { if opts.Last { // If it's the last message, update stream state. if !s.compareAndSwapState(streamActive, streamWriteDone) { - return errStreamDone + return s.getCloseStreamErr() } } else if s.getState() != streamActive { - return errStreamDone + return s.getCloseStreamErr() } df := newDataFrame() df.streamID = s.id @@ -670,7 +683,7 @@ func (t *http2Client) Write(s *Stream, hdr, data []byte, opts *Options) error { df.originD = df.d if hdr != nil || data != nil { // If it's not an empty data frame, check quota. if err := s.wq.get(int32(len(hdr) + len(data))); err != nil { - return err + return s.getCloseStreamErr() } } return t.controlBuf.put(df) @@ -766,6 +779,7 @@ func (t *http2Client) handleData(f *grpcframe.DataFrame) { } if size > 0 { if err := s.fc.onData(size); err != nil { + klog.CtxErrorf(s.ctx, sendRSTStreamFrameMsgf("KITEX: http2Client.handleData inflow control err: %v, code: %d", err, http2.ErrCodeFlowControl)) t.closeStream(s, io.EOF, true, http2.ErrCodeFlowControl, status.New(codes.Internal, err.Error()), nil, false) return } @@ -983,6 +997,7 @@ func (t *http2Client) operateHeaders(frame *grpcframe.MetaHeadersFrame) { if !initialHeader && !endStream { // As specified by gRPC over HTTP2, a HEADERS frame (and associated CONTINUATION frames) can only appear at the start or end of a stream. Therefore, second HEADERS frame must have EOS bit set. st := status.New(codes.Internal, "a HEADERS frame cannot appear in the middle of a stream") + klog.CtxErrorf(s.ctx, sendRSTStreamFrameMsgf("KITEX: http2Client.operateHeaders received HEADERS frame in the middle of a stream, code: %d", http2.ErrCodeProtocol)) t.closeStream(s, st.Err(), true, http2.ErrCodeProtocol, st, nil, false) return } @@ -991,6 +1006,7 @@ func (t *http2Client) operateHeaders(frame *grpcframe.MetaHeadersFrame) { // Initialize isGRPC value to be !initialHeader, since if a gRPC Response-Headers has already been received, then it means that the peer is speaking gRPC and we are in gRPC mode. state.data.isGRPC = !initialHeader if err := state.decodeHeader(frame); err != nil { + klog.CtxErrorf(s.ctx, sendRSTStreamFrameMsgf("KITEX: http2Client.operateHeaders decode HEADERS frame failed, err: %v, code: %d", err, http2.ErrCodeProtocol)) t.closeStream(s, err, true, http2.ErrCodeProtocol, status.Convert(err), nil, endStream) return } @@ -1034,7 +1050,7 @@ func (t *http2Client) reader() { // Check the validity of server preface. frame, err := t.framer.ReadFrame() if err != nil { - err = connectionErrorf(true, err, "error reading from server, remoteAddress=%s, error=%v", t.conn.RemoteAddr(), err) + err = connectionErrorf(true, err, "error reading from server, remoteAddress=%s, error=%v", t.conn.RemoteAddr(), err) t.Close(err) // this kicks off resetTransport, so must be last before return return } @@ -1073,12 +1089,13 @@ func (t *http2Client) reader() { if err != nil { msg = err.Error() } + klog.CtxErrorf(s.ctx, sendRSTStreamFrameMsgf("KITEX: http2Client.reader encountered http2.StreamError: %v, code: %d", se, http2.ErrCodeProtocol)) t.closeStream(s, status.New(code, msg).Err(), true, http2.ErrCodeProtocol, status.New(code, msg), nil, false) } continue } else { // Transport error. - err = connectionErrorf(true, err, "error reading from server, remoteAddress=%s, error=%v", t.conn.RemoteAddr(), err) + err = connectionErrorf(true, err, "error reading from server, remoteAddress=%s, error=%v", t.conn.RemoteAddr(), err) t.Close(err) return } diff --git a/pkg/remote/trans/nphttp2/grpc/http2_server.go b/pkg/remote/trans/nphttp2/grpc/http2_server.go index c2b84efe13..30597db1f2 100644 --- a/pkg/remote/trans/nphttp2/grpc/http2_server.go +++ b/pkg/remote/trans/nphttp2/grpc/http2_server.go @@ -53,20 +53,21 @@ import ( var ( // ErrIllegalHeaderWrite indicates that setting header is illegal because of // the stream's state. - ErrIllegalHeaderWrite = errors.New("transport: the stream is done or WriteHeader was already called") + ErrIllegalHeaderWrite = errors.New("transport: WriteHeader was already called") + errStatusIllegalHeaderWrite = status.Err(codes.Internal, handlerSideErrMsg(ErrIllegalHeaderWrite.Error())) // ErrHeaderListSizeLimitViolation indicates that the header list size is larger // than the limit set by peer. - ErrHeaderListSizeLimitViolation = errors.New("transport: trying to send header list size larger than the limit set by peer") + ErrHeaderListSizeLimitViolation = errors.New("transport: trying to send header list size larger than the limit set by peer") + errStatusHeaderListSizeLimitViolation = status.Err(codes.Internal, handlerSideErrMsg(ErrHeaderListSizeLimitViolation.Error())) // errors used for cancelling stream. // the code should be codes.Canceled coz it's NOT returned from remote - errConnectionEOF = status.New(codes.Canceled, "transport: connection EOF").Err() - errStreamClosing = status.New(codes.Canceled, "transport: stream is closing").Err() - errMaxStreamsExceeded = status.New(codes.Canceled, "transport: max streams exceeded").Err() - errNotReachable = status.New(codes.Canceled, "transport: server not reachable").Err() - errMaxAgeClosing = status.New(codes.Canceled, "transport: closing server transport due to maximum connection age").Err() - errIdleClosing = status.New(codes.Canceled, "transport: closing server transport due to idleness").Err() + errConnectionEOF = status.Err(codes.Canceled, remoteErrMsg("transport: connection EOF")) + errMaxStreamsExceeded = status.Err(codes.Canceled, remoteErrMsg("transport: max streams exceeded")) + errNotReachable = status.Err(codes.Canceled, remoteErrMsg("transport: server not reachable")) + errMaxAgeClosing = status.Err(codes.Canceled, "transport: closing server transport due to maximum connection age") + errIdleClosing = status.Err(codes.Canceled, remoteErrMsg("transport: closing server transport due to idleness")) ) func init() { @@ -299,10 +300,12 @@ func (t *http2Server) operateHeaders(frame *grpcframe.MetaHeadersFrame, handle f } if err := state.decodeHeader(frame); err != nil { if se, ok := status.FromError(err); ok { + rstCode := statusCodeConvTab[se.Code()] + klog.CtxErrorf(t.ctx, sendRSTStreamFrameMsgf("KITEX: http2Server.operateHeaders failed to decode header frame, err=%v, code: %d", err, rstCode)) t.controlBuf.put(&cleanupStream{ streamID: streamID, rst: true, - rstCode: statusCodeConvTab[se.Code()], + rstCode: rstCode, onWrite: func() {}, }) } @@ -344,13 +347,13 @@ func (t *http2Server) operateHeaders(frame *grpcframe.MetaHeadersFrame, handle f } if uint32(len(t.activeStreams)) >= t.maxStreams { t.mu.Unlock() + klog.CtxErrorf(t.ctx, sendRSTStreamFrameMsgf("KITEX: http2Server.operateHeaders failed to create stream, err=%v, code: %d", errMaxStreamsExceeded, http2.ErrCodeRefusedStream)) t.controlBuf.put(&cleanupStream{ streamID: streamID, rst: true, rstCode: http2.ErrCodeRefusedStream, onWrite: func() {}, }) - s.cancel(errMaxStreamsExceeded) return nil } if streamID%2 != 1 || streamID <= t.maxStreamID { @@ -405,11 +408,10 @@ func (t *http2Server) HandleStreams(handle func(*Stream), traceCtx func(context. s := t.activeStreams[se.StreamID] t.mu.Unlock() if s != nil { - // it will be codes.Internal error for GRPC - // TODO: map http2.StreamError to status.Error? - s.cancel(err) - t.closeStream(s, true, se.Code, false) + klog.CtxErrorf(s.ctx, sendRSTStreamFrameMsgf("KITEX: http2Server.HandleStreams encountered http2.StreamError, err=%v, code: %d", err, se.Code)) + t.closeStream(s, status.Err(codes.Canceled, s.getRemoteErrMsgf("transport: ReadFrame encountered http2.StreamError: %v", err)), true, se.Code, false) } else { + klog.CtxErrorf(t.ctx, sendRSTStreamFrameMsgf("KITEX: http2Server.HandleStreams failed to ReadFrame, err=%v, code: %d", err, se.Code)) t.controlBuf.put(&cleanupStream{ streamID: se.StreamID, rst: true, @@ -424,7 +426,7 @@ func (t *http2Server) HandleStreams(handle func(*Stream), traceCtx func(context. return } klog.CtxWarnf(t.ctx, "transport: http2Server.HandleStreams failed to read frame: %v", err) - t.closeWithErr(err) + t.closeWithErr(status.Err(codes.Canceled, remoteErrMsgf("transport: ReadFrame encountered err: %v", err))) return } switch frame := frame.(type) { @@ -551,7 +553,8 @@ func (t *http2Server) handleData(f *grpcframe.DataFrame) { } if size > 0 { if err := s.fc.onData(size); err != nil { - t.closeStream(s, true, http2.ErrCodeFlowControl, false) + klog.CtxErrorf(s.ctx, sendRSTStreamFrameMsgf("KITEX: http2Server.handleData inflow control err: %v, code: %d", err, http2.ErrCodeFlowControl)) + t.closeStream(s, status.Err(codes.Canceled, s.getRemoteErrMsgf("transport: inflow control err: %v", err)), true, http2.ErrCodeFlowControl, false) return } if f.Header().Flags.Has(http2.FlagDataPadded) { @@ -579,7 +582,7 @@ func (t *http2Server) handleData(f *grpcframe.DataFrame) { func (t *http2Server) handleRSTStream(f *http2.RSTStreamFrame) { // If the stream is not deleted from the transport's active streams map, then do a regular close stream. if s, ok := t.getStream(f); ok { - t.closeStream(s, false, 0, false) + t.closeStream(s, status.Err(codes.Canceled, s.getRemoteErrMsgf("transport: RSTStream Frame received with error code: %v", f.ErrCode)), false, 0, false) return } // If the stream is already deleted from the active streams map, then put a cleanupStream item into controlbuf to delete the stream from loopy writer's established streams map. @@ -589,6 +592,7 @@ func (t *http2Server) handleRSTStream(f *http2.RSTStreamFrame) { rstCode: 0, onWrite: func() {}, }) + // since we do not need to send RstStream Frame, do not add log here } func (t *http2Server) handleSettings(f *grpcframe.SettingsFrame) { @@ -711,8 +715,11 @@ func (t *http2Server) checkForHeaderListSize(it interface{}) bool { // WriteHeader sends the header metadata md back to the client. func (t *http2Server) WriteHeader(s *Stream, md metadata.MD) error { - if s.updateHeaderSent() || s.getState() == streamDone { - return ErrIllegalHeaderWrite + if s.updateHeaderSent() { + return errStatusIllegalHeaderWrite + } + if s.getState() == streamDone { + return ContextErr(s.ctx.Err()) } s.hdrMu.Lock() if md.Len() > 0 { @@ -756,8 +763,9 @@ func (t *http2Server) writeHeaderLocked(s *Stream) error { if err != nil { return err } - t.closeStream(s, true, http2.ErrCodeInternal, false) - return ErrHeaderListSizeLimitViolation + klog.CtxErrorf(s.ctx, sendRSTStreamFrameMsgf("KITEX: http2Server.writeHeaderLocked checkForHeaderListSize failed, code: %d", http2.ErrCodeInternal)) + t.closeStream(s, errStatusHeaderListSizeLimitViolation, true, http2.ErrCodeInternal, false) + return errStatusHeaderListSizeLimitViolation } return nil } @@ -819,7 +827,7 @@ func (t *http2Server) WriteStatus(s *Stream, st *status.Status) error { if err != nil { return err } - t.closeStream(s, true, http2.ErrCodeInternal, false) + t.closeStream(s, errStatusHeaderListSizeLimitViolation, true, http2.ErrCodeInternal, false) return ErrHeaderListSizeLimitViolation } // Send a RST_STREAM after the trailers if the client has not already half-closed. @@ -838,13 +846,6 @@ func (t *http2Server) Write(s *Stream, hdr, data []byte, opts *Options) error { } else { // Writing headers checks for this condition. if s.getState() == streamDone { - // TODO(mmukhi, dfawley): Should the server write also return io.EOF? - s.cancel(errStreamClosing) - select { - case <-t.done: - return ErrConnClosing - default: - } return ContextErr(s.ctx.Err()) } } @@ -856,11 +857,6 @@ func (t *http2Server) Write(s *Stream, hdr, data []byte, opts *Options) error { df.originD = df.d df.resetPingStrikes = &t.resetPingStrikes if err := s.wq.get(int32(len(hdr) + len(data))); err != nil { - select { - case <-t.done: - return ErrConnClosing - default: - } return ContextErr(s.ctx.Err()) } return t.controlBuf.put(df) @@ -976,7 +972,7 @@ func (t *http2Server) closeWithErr(reason error) error { streams := t.activeStreams t.activeStreams = nil t.mu.Unlock() - t.controlBuf.finish() + t.controlBuf.finish(reason) close(t.done) err := t.conn.Close() @@ -990,11 +986,6 @@ func (t *http2Server) closeWithErr(reason error) error { // deleteStream deletes the stream s from transport's active streams. func (t *http2Server) deleteStream(s *Stream, eosReceived bool) { - // In case stream sending and receiving are invoked in separate - // goroutines (e.g., bi-directional streaming), cancel needs to be - // called to interrupt the potential blocking on other goroutines. - s.cancel(nil) // more details about the reason? - t.mu.Lock() if _, ok := t.activeStreams[s.id]; ok { delete(t.activeStreams, s.id) @@ -1012,6 +1003,7 @@ func (t *http2Server) finishStream(s *Stream, rst bool, rstCode http2.ErrCode, h // If the stream was already done, return. return } + s.cancel(nil) hdr.cleanup = &cleanupStream{ streamID: s.id, @@ -1025,7 +1017,11 @@ func (t *http2Server) finishStream(s *Stream, rst bool, rstCode http2.ErrCode, h } // closeStream clears the footprint of a stream when the stream is not needed any more. -func (t *http2Server) closeStream(s *Stream, rst bool, rstCode http2.ErrCode, eosReceived bool) { +func (t *http2Server) closeStream(s *Stream, err error, rst bool, rstCode http2.ErrCode, eosReceived bool) { + // In case stream sending and receiving are invoked in separate + // goroutines (e.g., bi-directional streaming), cancel needs to be + // called to interrupt the potential blocking on other goroutines. + s.cancel(err) s.swapState(streamDone) t.deleteStream(s, eosReceived) diff --git a/pkg/remote/trans/nphttp2/grpc/stream_test.go b/pkg/remote/trans/nphttp2/grpc/stream_test.go new file mode 100644 index 0000000000..fedb24e06e --- /dev/null +++ b/pkg/remote/trans/nphttp2/grpc/stream_test.go @@ -0,0 +1,39 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package grpc + +import ( + "testing" + + "github.com/cloudwego/kitex/internal/test" +) + +func TestStream_SetSourceService(t *testing.T) { + s := new(Stream) + testSvc := "test service" + testErrMsg := "test err" + testErrMsgTpl := "test err: %s" + testErrMsgVal := "val" + test.Assert(t, s.getSourceService() == "remote service") + test.Assert(t, s.getRemoteErrMsg(testErrMsg) == "test err[triggered by remote service]") + test.Assert(t, s.getRemoteErrMsgf(testErrMsgTpl, testErrMsgVal) == "test err: val[triggered by remote service]") + + s.SetSourceService(testSvc) + test.Assert(t, s.getSourceService() == "test service") + test.Assert(t, s.getRemoteErrMsg(testErrMsg) == "test err[triggered by test service]") + test.Assert(t, s.getRemoteErrMsgf(testErrMsgTpl, testErrMsgVal) == "test err: val[triggered by test service]") +} diff --git a/pkg/remote/trans/nphttp2/grpc/transport.go b/pkg/remote/trans/nphttp2/grpc/transport.go index e2b5f0ef1f..6b43b682f5 100644 --- a/pkg/remote/trans/nphttp2/grpc/transport.go +++ b/pkg/remote/trans/nphttp2/grpc/transport.go @@ -285,6 +285,11 @@ type Stream struct { // contentSubtype is the content-subtype for requests. // this must be lowercase or the behavior is undefined. contentSubtype string + + // closeStreamErr is used to store the error when stream is closed + closeStreamErr atomic.Value + // sourceService is the source service name of this stream + sourceService atomic.Value } // isHeaderSent is only valid on the server-side. @@ -479,6 +484,34 @@ func (s *Stream) Read(p []byte) (n int, err error) { return io.ReadFull(s.trReader, p) } +func (s *Stream) getCloseStreamErr() error { + rawErr := s.closeStreamErr.Load() + if rawErr != nil { + return rawErr.(error) + } + return errStatusStreamDone +} + +func (s *Stream) SetSourceService(svc string) { + s.sourceService.Store(svc) +} + +func (s *Stream) getSourceService() string { + rawSvc := s.sourceService.Load() + if rawSvc != nil { + return rawSvc.(string) + } + return "remote service" +} + +func (s *Stream) getRemoteErrMsg(msg string) string { + return msg + fmt.Sprintf(remoteErrTpl, s.getSourceService()) +} + +func (s *Stream) getRemoteErrMsgf(msg string, a ...interface{}) string { + return fmt.Sprintf(msg, a...) + fmt.Sprintf(remoteErrTpl, s.getSourceService()) +} + // StreamWrite only used for unit test func StreamWrite(s *Stream, buffer *bytes.Buffer) { s.write(recvMsg{buffer: buffer}) @@ -496,10 +529,8 @@ func CreateStream(ctx context.Context, id uint32, requestRead func(i int), metho }, windowHandler: func(i int) {}, } - stream := &Stream{ id: id, - ctx: ctx, method: method, buf: recvBuffer, trReader: trReader, @@ -508,6 +539,9 @@ func CreateStream(ctx context.Context, id uint32, requestRead func(i int), metho hdrMu: sync.Mutex{}, } + ctx, cancel := context.WithCancel(ctx) + stream.ctx, stream.cancel = newContextWithCancelReason(ctx, cancel) + return stream } @@ -762,11 +796,13 @@ func (e ConnectionError) Origin() error { var ( // ErrConnClosing indicates that the transport is closing. - ErrConnClosing = connectionErrorf(true, nil, "transport is closing") + ErrConnClosing = connectionErrorf(true, nil, "transport is closing") + errStatusConnClosing = status.Err(codes.Unavailable, ErrConnClosing.Desc) // errStreamDone is returned from write at the client side to indicate application // layer of an error. - errStreamDone = errors.New("the stream is done") + errStreamDone = errors.New("the stream is done") + errStatusStreamDone = status.Err(codes.Internal, errStreamDone.Error()) // errStreamDrain indicates that the stream is rejected because the // connection is draining. This could be caused by goaway or balancer diff --git a/pkg/remote/trans/nphttp2/grpc/transport_test.go b/pkg/remote/trans/nphttp2/grpc/transport_test.go index ec98259a52..e853db7ab6 100644 --- a/pkg/remote/trans/nphttp2/grpc/transport_test.go +++ b/pkg/remote/trans/nphttp2/grpc/transport_test.go @@ -858,9 +858,7 @@ func TestLargeMessageSuspension(t *testing.T) { msg := make([]byte, initialWindowSize*8) ct.Write(s, nil, msg, &Options{}) err = ct.Write(s, nil, msg, &Options{Last: true}) - if err != errStreamDone { - t.Fatalf("write got %v, want io.EOF", err) - } + test.Assert(t, errors.Is(err, ContextErr(ctx.Err())), err) expectedErr := status.Err(codes.DeadlineExceeded, context.DeadlineExceeded.Error()) if _, err := s.Read(make([]byte, 8)); err.Error() != expectedErr.Error() { t.Fatalf("Read got %v of type %T, want %v", err, err, expectedErr) diff --git a/pkg/remote/trans/nphttp2/server_handler.go b/pkg/remote/trans/nphttp2/server_handler.go index f249f84242..9ed3d6e390 100644 --- a/pkg/remote/trans/nphttp2/server_handler.go +++ b/pkg/remote/trans/nphttp2/server_handler.go @@ -185,6 +185,7 @@ func (t *svrTransHandler) handleFunc(s *grpcTransport.Stream, svrTrans *SvrTrans return } } + s.SetSourceService(ri.From().ServiceName()) rCtx = t.startTracer(rCtx, ri) defer func() { panicErr := recover()