From 3cc884ac895b118d1c0c7b43a8f629721e331d52 Mon Sep 17 00:00:00 2001 From: "yuxuan.wang1" Date: Fri, 13 Sep 2024 16:38:28 +0800 Subject: [PATCH] feat: optimize gRPC error handling --- .../remote/trans/grpc}/status/mock_test.go | 0 internal/remote/trans/grpc/status/status.go | 273 +++++++++++++++ .../remote/trans/grpc}/status/status_test.go | 33 +- pkg/kerrors/kerrors.go | 54 +-- pkg/kerrors/kerrors_test.go | 23 +- pkg/kerrors/streaming_errors.go | 51 +++ pkg/kerrors/streaming_errors_test.go | 53 +++ pkg/remote/trans/nphttp2/codes_test.go | 8 + pkg/remote/trans/nphttp2/conn_pool_test.go | 4 +- pkg/remote/trans/nphttp2/grpc/controlbuf.go | 14 +- .../trans/nphttp2/grpc/controlbuf_test.go | 46 ++- .../trans/nphttp2/grpc/err_handling_test.go | 319 ++++++++++++++++++ pkg/remote/trans/nphttp2/grpc/errors.go | 115 +++++++ pkg/remote/trans/nphttp2/grpc/errors_test.go | 63 ++++ pkg/remote/trans/nphttp2/grpc/http2_client.go | 120 ++++--- pkg/remote/trans/nphttp2/grpc/http2_server.go | 170 ++++++---- pkg/remote/trans/nphttp2/grpc/http_util.go | 171 +++++++++- .../trans/nphttp2/grpc/http_util_test.go | 279 +++++++++++++++ pkg/remote/trans/nphttp2/grpc/transport.go | 103 +++++- .../trans/nphttp2/grpc/transport_test.go | 94 ++++-- pkg/remote/trans/nphttp2/status/status.go | 147 +------- pkg/streamx/provider/grpc/gerrors/gerrors.go | 43 +++ 22 files changed, 1822 insertions(+), 361 deletions(-) rename {pkg/remote/trans/nphttp2 => internal/remote/trans/grpc}/status/mock_test.go (100%) create mode 100644 internal/remote/trans/grpc/status/status.go rename {pkg/remote/trans/nphttp2 => internal/remote/trans/grpc}/status/status_test.go (74%) create mode 100644 pkg/kerrors/streaming_errors.go create mode 100644 pkg/kerrors/streaming_errors_test.go create mode 100644 pkg/remote/trans/nphttp2/grpc/err_handling_test.go create mode 100644 pkg/remote/trans/nphttp2/grpc/errors.go create mode 100644 pkg/remote/trans/nphttp2/grpc/errors_test.go create mode 100644 
pkg/streamx/provider/grpc/gerrors/gerrors.go diff --git a/pkg/remote/trans/nphttp2/status/mock_test.go b/internal/remote/trans/grpc/status/mock_test.go similarity index 100% rename from pkg/remote/trans/nphttp2/status/mock_test.go rename to internal/remote/trans/grpc/status/mock_test.go diff --git a/internal/remote/trans/grpc/status/status.go b/internal/remote/trans/grpc/status/status.go new file mode 100644 index 0000000000..8dab33a5e6 --- /dev/null +++ b/internal/remote/trans/grpc/status/status.go @@ -0,0 +1,273 @@ +/* + * + * Copyright 2020 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * This file may have been modified by CloudWeGo authors. All CloudWeGo + * Modifications are Copyright 2021 CloudWeGo Authors. + */ + +// Package status implements errors returned by gRPC. These errors are +// serialized and transmitted on the wire between server and client, and allow +// for additional data to be transmitted via the Details field in the status +// proto. gRPC service handlers should return an error created by this +// package, and gRPC clients should expect a corresponding error to be +// returned from the RPC call. +// +// This package upholds the invariants that a non-nil error may not +// contain an OK code, and an OK code must result in a nil error. 
+package status + +import ( + "context" + "errors" + "fmt" + + spb "google.golang.org/genproto/googleapis/rpc/status" + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/types/known/anypb" + + "github.com/cloudwego/kitex/pkg/remote/trans/nphttp2/codes" +) + +type Iface interface { + GRPCStatus() *Status +} + +// Status represents an RPC status code, message, and details. It is immutable +// and should be created with New, Newf, or FromProto. +type Status struct { + s *spb.Status + // kerr is the Kitex custom error that status maps to + kerr error +} + +// New returns a Status representing c and msg. +func New(c codes.Code, msg string) *Status { + return &Status{s: &spb.Status{Code: int32(c), Message: msg}} +} + +// NewWithMappingErr returns a Status representing c and msg with the mapping Kitex error kerr attached. +func NewWithMappingErr(c codes.Code, kerr error, msg string) *Status { + st := New(c, msg) + st.kerr = kerr + return st +} + +// Newf returns New(c, fmt.Sprintf(format, a...)). +func Newf(c codes.Code, format string, a ...interface{}) *Status { + return New(c, fmt.Sprintf(format, a...)) +} + +// NewfWithMappingErr returns the Newf result with the mapping Kitex error kerr attached. +func NewfWithMappingErr(c codes.Code, kerr error, format string, a ...interface{}) *Status { + st := Newf(c, format, a...) + st.kerr = kerr + return st +} + +// InjectMappingErr injects the mapping Kitex error kerr into st. +func InjectMappingErr(st *Status, kerr error) { + st.kerr = kerr +} + +// ErrorProto returns an error representing s. If s.Code is OK, returns nil. +func ErrorProto(s *spb.Status) error { + return FromProto(s).Err() +} + +// FromProto returns a Status representing s. +func FromProto(s *spb.Status) *Status { + return &Status{s: proto.Clone(s).(*spb.Status)} +} + +// Err returns an error representing c and msg. If c is OK, returns nil. +func Err(c codes.Code, msg string) error { + return New(c, msg).Err() +} + +// Errorf returns Err(c, fmt.Sprintf(format, a...)).
+func Errorf(c codes.Code, format string, a ...interface{}) error { + return Err(c, fmt.Sprintf(format, a...)) +} + +// Code returns the status code contained in s. +func (s *Status) Code() codes.Code { + if s == nil || s.s == nil { + return codes.OK + } + return codes.Code(s.s.Code) +} + +// Message returns the message contained in s. +func (s *Status) Message() string { + if s == nil || s.s == nil { + return "" + } + return s.s.Message +} + +// AppendMessage appends extraMsg to the message of s in place, separated by a space, and returns s. +func (s *Status) AppendMessage(extraMsg string) *Status { + if s == nil || s.s == nil || extraMsg == "" { + return s + } + s.s.Message = fmt.Sprintf("%s %s", s.s.Message, extraMsg) + return s +} + +// Proto returns s's status as an spb.Status proto message. +func (s *Status) Proto() *spb.Status { + if s == nil { + return nil + } + return proto.Clone(s.s).(*spb.Status) +} + +// Err returns an immutable error representing s; returns nil if s.Code() is OK. +func (s *Status) Err() error { + if s.Code() == codes.OK { + return nil + } + return &Error{e: s.Proto(), kerr: s.kerr} +} + +// WithDetails returns a new status with the provided details messages appended to the status; the mapping Kitex error of s is not carried over. +// If any errors are encountered, it returns nil and the first error encountered. +func (s *Status) WithDetails(details ...proto.Message) (*Status, error) { + if s.Code() == codes.OK { + return nil, errors.New("no error details for status with code OK") + } + // s.Code() != OK implies that s.Proto() != nil. + p := s.Proto() + for _, detail := range details { + any, err := anypb.New(detail) + if err != nil { + return nil, err + } + p.Details = append(p.Details, any) + } + return &Status{s: p}, nil +} + +// Details returns a slice of details messages attached to the status. +// If a detail cannot be decoded, the error is returned in place of the detail.
+func (s *Status) Details() []interface{} { + if s == nil || s.s == nil { + return nil + } + details := make([]interface{}, 0, len(s.s.Details)) + for _, any := range s.s.Details { + detail, err := any.UnmarshalNew() + if err != nil { + details = append(details, err) + continue + } + details = append(details, detail) + } + return details +} + +// Error wraps a pointer of a status proto. It implements error and Iface, +// and a nil *Error should never be returned by this package. +type Error struct { + e *spb.Status + // kerr is the Kitex custom error that status maps to + kerr error +} + +// GetMappingErr returns the Kitex custom error that this Error maps to. +func (e *Error) GetMappingErr() error { + return e.kerr +} + +func (e *Error) Error() string { + str := fmt.Sprintf("rpc error: code = %d desc = %s", codes.Code(e.e.GetCode()), e.e.GetMessage()) + if e.kerr == nil { + return str + } + return fmt.Sprintf("[%s] %s", e.kerr.Error(), str) +} + +// GRPCStatus returns the Status represented by e. +func (e *Error) GRPCStatus() *Status { + st := FromProto(e.e) + st.kerr = e.kerr + return st +} + +// Is implements errors.Is functionality. +// An Error matches another *Error if their status protos are equal, +// or matches any other target if the underlying mapped kitex error conforms to errors.Is. +func (e *Error) Is(target error) bool { + tse, ok := target.(*Error) + if ok { + return proto.Equal(e.e, tse.e) + } + if e.kerr != nil { + return errors.Is(e.kerr, target) + } + return false +} + +// FromError returns a Status representing err if it was produced from this +// package or has a method `GRPCStatus() *Status`. Otherwise, ok is false and a +// Status is returned with codes.Unknown and the original error message.
+func FromError(err error) (s *Status, ok bool) { + if err == nil { + return nil, true + } + var se Iface + if errors.As(err, &se) { + return se.GRPCStatus(), true + } + return New(codes.Unknown, err.Error()), false +} + +// Convert is a convenience function which removes the need to handle the +// boolean return value from FromError. +func Convert(err error) *Status { + s, _ := FromError(err) + return s +} + +// Code returns the Code of the error if it is a Status error, codes.OK if err +// is nil, or codes.Unknown otherwise. +func Code(err error) codes.Code { + // Don't use FromError to avoid allocation of OK status. + if err == nil { + return codes.OK + } + var se Iface + if errors.As(err, &se) { + return se.GRPCStatus().Code() + } + return codes.Unknown +} + +// FromContextError converts a context error into a Status. It returns +// nil if err is nil, or a Status with codes.Unknown if err is +// non-nil and not a context error. +func FromContextError(err error) *Status { + switch err { + case nil: + return nil + case context.DeadlineExceeded: + return New(codes.DeadlineExceeded, err.Error()) + case context.Canceled: + return New(codes.Canceled, err.Error()) + default: + return New(codes.Unknown, err.Error()) + } +} diff --git a/pkg/remote/trans/nphttp2/status/status_test.go b/internal/remote/trans/grpc/status/status_test.go similarity index 74% rename from pkg/remote/trans/nphttp2/status/status_test.go rename to internal/remote/trans/grpc/status/status_test.go index 08cd55d82f..a7ad8324a7 100644 --- a/pkg/remote/trans/nphttp2/status/status_test.go +++ b/internal/remote/trans/grpc/status/status_test.go @@ -18,7 +18,10 @@ package status import ( "context" + "errors" "fmt" + "reflect" + "strings" "testing" spb "google.golang.org/genproto/googleapis/rpc/status" @@ -63,6 +66,22 @@ func TestStatus(t *testing.T) { statusNilErr, ok := FromError(nil) test.Assert(t, ok) test.Assert(t, statusNilErr == nil) + + mappingErr := errors.New("mappingErr") +
oriSt := NewWithMappingErr(codes.Internal, mappingErr, "withMappingErr test") + rawStErr := oriSt.Err() + test.Assert(t, strings.Contains(rawStErr.Error(), mappingErr.Error()), rawStErr) + test.Assert(t, errors.Is(rawStErr, mappingErr), rawStErr) + stErr, ok := rawStErr.(*Error) + test.Assert(t, ok) + test.Assert(t, stErr.GetMappingErr() == mappingErr, stErr.GetMappingErr()) + st0 := stErr.GRPCStatus() + test.Assert(t, reflect.DeepEqual(st0, oriSt), st0) + st1, ok := FromError(rawStErr) + test.Assert(t, ok) + test.Assert(t, reflect.DeepEqual(st1, oriSt), st1) + st2 := Convert(rawStErr) + test.Assert(t, reflect.DeepEqual(st2, oriSt), st1) } func TestError(t *testing.T) { @@ -70,17 +89,19 @@ func TestError(t *testing.T) { s.Code = 1 s.Message = "test err" - er := &Error{s} + kerr := errors.New("kerr") + er := &Error{e: s, kerr: kerr} test.Assert(t, len(er.Error()) > 0) + test.Assert(t, strings.Contains(er.Error(), s.Message), er.Error()) + test.Assert(t, strings.Contains(er.Error(), kerr.Error()), er.Error()) status := er.GRPCStatus() test.Assert(t, status.Message() == s.Message) - is := er.Is(context.Canceled) - test.Assert(t, !is) + test.Assert(t, !er.Is(context.Canceled)) - is = er.Is(er) - test.Assert(t, is) + test.Assert(t, er.Is(er)) + test.Assert(t, er.Is(kerr)) } func TestFromContextError(t *testing.T) { @@ -101,7 +122,7 @@ func TestFromContextError(t *testing.T) { s := new(spb.Status) s.Code = 1 s.Message = "test err" - grpcErr := &Error{s} + grpcErr := &Error{e: s} // grpc err codeGrpcErr := Code(grpcErr) test.Assert(t, codeGrpcErr == codes.Canceled) diff --git a/pkg/kerrors/kerrors.go b/pkg/kerrors/kerrors.go index 783658f27e..e25b2d4407 100644 --- a/pkg/kerrors/kerrors.go +++ b/pkg/kerrors/kerrors.go @@ -26,28 +26,28 @@ import ( // Basic error types var ( - ErrInternalException = &basicError{"internal exception"} - ErrServiceDiscovery = &basicError{"service discovery error"} - ErrGetConnection = &basicError{"get connection error"} - ErrLoadbalance = 
&basicError{"loadbalance error"} - ErrNoMoreInstance = &basicError{"no more instances to retry"} - ErrRPCTimeout = &basicError{"rpc timeout"} - ErrCanceledByBusiness = &basicError{"canceled by business"} - ErrTimeoutByBusiness = &basicError{"timeout by business"} - ErrACL = &basicError{"request forbidden"} - ErrCircuitBreak = &basicError{"forbidden by circuitbreaker"} - ErrRemoteOrNetwork = &basicError{"remote or network error"} - ErrOverlimit = &basicError{"request over limit"} - ErrPanic = &basicError{"panic"} - ErrBiz = &basicError{"biz error"} - - ErrRetry = &basicError{"retry error"} + ErrInternalException = &basicError{message: "internal exception"} + ErrServiceDiscovery = &basicError{message: "service discovery error"} + ErrGetConnection = &basicError{message: "get connection error"} + ErrLoadbalance = &basicError{message: "loadbalance error"} + ErrNoMoreInstance = &basicError{message: "no more instances to retry"} + ErrRPCTimeout = &basicError{message: "rpc timeout"} + ErrCanceledByBusiness = &basicError{message: "canceled by business"} + ErrTimeoutByBusiness = &basicError{message: "timeout by business"} + ErrACL = &basicError{message: "request forbidden"} + ErrCircuitBreak = &basicError{message: "forbidden by circuitbreaker"} + ErrRemoteOrNetwork = &basicError{message: "remote or network error"} + ErrOverlimit = &basicError{message: "request over limit"} + ErrPanic = &basicError{message: "panic"} + ErrBiz = &basicError{message: "biz error"} + + ErrRetry = &basicError{message: "retry error"} // ErrRPCFinish happens when retry enabled and there is one call has finished - ErrRPCFinish = &basicError{"rpc call finished"} + ErrRPCFinish = &basicError{message: "rpc call finished"} // ErrRoute happens when router fail to route this call - ErrRoute = &basicError{"rpc route failed"} + ErrRoute = &basicError{message: "rpc route failed"} // ErrPayloadValidation happens when payload validation failed - ErrPayloadValidation = &basicError{"payload validation error"} + 
ErrPayloadValidation = &basicError{message: "payload validation error"} ) // More detailed error types @@ -67,11 +67,20 @@ var ( type basicError struct { message string + // parent basicError + parent error } // Error implements the error interface. func (be *basicError) Error() string { - return be.message + if be.parent == nil { + return be.message + } + return fmt.Sprintf("[%s] %s", be.parent, be.message) +} + +func (be *basicError) Is(target error) bool { + return be == target || errors.Is(be.parent, target) } // WithCause creates a detailed error which attach the given cause to current error. @@ -141,7 +150,7 @@ func (de *DetailedError) Unwrap() error { // Is returns if the given error matches the current error. func (de *DetailedError) Is(target error) bool { - return de == target || de.basic == target || errors.Is(de.cause, target) + return de == target || errors.Is(de.basic, target) || errors.Is(de.cause, target) } // As returns if the given target matches the current error, if so sets @@ -190,7 +199,8 @@ func IsKitexError(err error) bool { if _, ok := err.(*DetailedError); ok { return true } - return false + + return IsStreamingError(err) } // TimeoutCheckFunc is used to check whether the given err is a timeout error. 
diff --git a/pkg/kerrors/kerrors_test.go b/pkg/kerrors/kerrors_test.go index 1cd65cb257..6a5053304c 100644 --- a/pkg/kerrors/kerrors_test.go +++ b/pkg/kerrors/kerrors_test.go @@ -47,6 +47,17 @@ func TestIsKitexError(t *testing.T) { ErrNoMoreInstance, ErrConnOverLimit, ErrQPSOverLimit, + // streaming errors + errStreaming, + ErrStreamingProtocol, + errStreamingTimeout, + ErrStreamTimeout, + ErrStreamingCanceled, + ErrBizCanceled, + ErrGracefulShutdown, + ErrHandlerReturn, + errStreamingMeta, + ErrMetaSizeExceeded, } for _, e := range errs { test.Assert(t, IsKitexError(e)) @@ -76,7 +87,7 @@ func TestIs(t *testing.T) { func TestError(t *testing.T) { basic := "basic" extra := "extra" - be := &basicError{basic} + be := &basicError{message: basic} test.Assert(t, be.Error() == basic) detailedMsg := appendErrMsg(basic, extra) test.Assert(t, (&DetailedError{basic: be, extraMsg: extra}).Error() == detailedMsg) @@ -84,7 +95,7 @@ func TestError(t *testing.T) { func TestWithCause(t *testing.T) { ae := errors.New("any error") - be := &basicError{"basic"} + be := &basicError{message: "basic"} de := be.WithCause(ae) test.Assert(t, be.Error() == "basic") @@ -102,7 +113,7 @@ func TestWithCause(t *testing.T) { func TestWithCauseAndStack(t *testing.T) { ae := errors.New("any error") - be := &basicError{"basic"} + be := &basicError{message: "basic"} stack := string(debug.Stack()) de := be.WithCauseAndStack(ae, stack) @@ -135,7 +146,7 @@ func TestTimeout(t *testing.T) { return os.IsTimeout(err) } - ke = &basicError{"non-timeout"} + ke = &basicError{message: "non-timeout"} TimeoutCheckFunc = osCheck test.Assert(t, !IsTimeoutError(ke)) TimeoutCheckFunc = nil @@ -173,7 +184,7 @@ func TestTimeout(t *testing.T) { } func TestWithCause1(t *testing.T) { - ae := &basicError{"basic"} + ae := &basicError{message: "basic"} be := ErrRPCTimeout.WithCause(ae) if e2, ok := be.(*DetailedError); ok { e2.WithExtraMsg("retry circuite break") @@ -218,7 +229,7 @@ func BenchmarkWithCause3(b *testing.B) { 
b.ReportAllocs() b.ResetTimer() for i := 0; i < b.N; i++ { - ae := &basicError{"basic"} + ae := &basicError{message: "basic"} be := ErrRPCTimeout.WithCause(ae) if e2, ok := be.(*DetailedError); ok { e2.WithExtraMsg("测试") diff --git a/pkg/kerrors/streaming_errors.go b/pkg/kerrors/streaming_errors.go new file mode 100644 index 0000000000..55f78188b9 --- /dev/null +++ b/pkg/kerrors/streaming_errors.go @@ -0,0 +1,51 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kerrors + +import "errors" + +var ( + // errStreaming is the parent type of all streaming errors. + errStreaming = &basicError{message: "Streaming"} + + // ErrStreamingProtocol is the parent type of all streaming protocol(e.g. gRPC, TTHeader Streaming) + // related but not user-aware errors. + ErrStreamingProtocol = &basicError{message: "protocol error", parent: errStreaming} + + // errStreamingTimeout is the parent type of all streaming timeout errors. + errStreamingTimeout = &basicError{message: "timeout error", parent: errStreaming} + // ErrStreamTimeout denotes the timeout of the whole stream. + ErrStreamTimeout = errStreamingTimeout.WithCause(errors.New("stream timeout")) + + // ErrStreamingCanceled is the parent type of all streaming canceled errors. + ErrStreamingCanceled = &basicError{message: "canceled error", parent: errStreaming} + // ErrBizCanceled denotes the stream is canceled by the biz code invoking cancel(). 
+ ErrBizCanceled = ErrStreamingCanceled.WithCause(errors.New("business canceled")) + // ErrGracefulShutdown denotes the stream is canceled due to graceful shutdown. + ErrGracefulShutdown = ErrStreamingCanceled.WithCause(errors.New("graceful shutdown")) + ErrHandlerReturn = ErrStreamingCanceled.WithCause(errors.New("handler return")) + + // errStreamingMeta is the parent type of all streaming meta errors. + errStreamingMeta = &basicError{message: "meta error", parent: errStreaming} + // ErrMetaSizeExceeded denotes the streaming meta size exceeds the limit. + ErrMetaSizeExceeded = errStreamingMeta.WithCause(errors.New("size exceeds limit")) +) + +// IsStreamingError reports whether the given err is a streaming err +func IsStreamingError(err error) bool { + return errors.Is(err, errStreaming) +} diff --git a/pkg/kerrors/streaming_errors_test.go b/pkg/kerrors/streaming_errors_test.go new file mode 100644 index 0000000000..4e6ae746f4 --- /dev/null +++ b/pkg/kerrors/streaming_errors_test.go @@ -0,0 +1,53 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package kerrors + +import ( + "errors" + "testing" + + "github.com/cloudwego/kitex/internal/test" +) + +func TestIsStreamingError(t *testing.T) { + errs := []error{ + errStreaming, + ErrStreamingProtocol, + errStreamingTimeout, + ErrStreamTimeout, + ErrStreamingCanceled, + ErrBizCanceled, + ErrGracefulShutdown, + ErrHandlerReturn, + errStreamingMeta, + ErrMetaSizeExceeded, + } + for _, err := range errs { + test.Assert(t, IsStreamingError(err), err) + } +} + +func Test_streamingErr_inheritance(t *testing.T) { + // streaming timeout + test.Assert(t, errors.Is(ErrStreamTimeout, errStreamingTimeout)) + // streaming canceled + test.Assert(t, errors.Is(ErrBizCanceled, ErrStreamingCanceled)) + test.Assert(t, errors.Is(ErrGracefulShutdown, ErrStreamingCanceled)) + test.Assert(t, errors.Is(ErrHandlerReturn, ErrStreamingCanceled)) + // streaming meta + test.Assert(t, errors.Is(ErrMetaSizeExceeded, errStreamingMeta)) +} diff --git a/pkg/remote/trans/nphttp2/codes_test.go b/pkg/remote/trans/nphttp2/codes_test.go index 49dbc4f8be..e49d6d04a9 100644 --- a/pkg/remote/trans/nphttp2/codes_test.go +++ b/pkg/remote/trans/nphttp2/codes_test.go @@ -18,8 +18,10 @@ package nphttp2 import ( "errors" + "reflect" "testing" + istatus "github.com/cloudwego/kitex/internal/remote/trans/grpc/status" "github.com/cloudwego/kitex/internal/test" "github.com/cloudwego/kitex/pkg/kerrors" "github.com/cloudwego/kitex/pkg/remote" @@ -101,4 +103,10 @@ func TestConvertStatus(t *testing.T) { t.Run("user-defined-error", func(t *testing.T) { test.Assert(t, convertStatus(&mockError{}).Code() == codes.Internal) }) + + t.Run("GRPCStatusError with mapping Err", func(t *testing.T) { + oriSt := istatus.NewWithMappingErr(codes.Internal, errors.New("test"), "test") + st := convertStatus(oriSt.Err()) + test.Assert(t, reflect.DeepEqual(st, oriSt), st) + }) } diff --git a/pkg/remote/trans/nphttp2/conn_pool_test.go b/pkg/remote/trans/nphttp2/conn_pool_test.go index 21d20867a7..bc5f093f03 100644 --- 
a/pkg/remote/trans/nphttp2/conn_pool_test.go +++ b/pkg/remote/trans/nphttp2/conn_pool_test.go @@ -24,6 +24,8 @@ import ( "github.com/cloudwego/kitex/internal/test" "github.com/cloudwego/kitex/pkg/remote" + "github.com/cloudwego/kitex/pkg/remote/trans/nphttp2/codes" + "github.com/cloudwego/kitex/pkg/remote/trans/nphttp2/status" ) func TestConnPool(t *testing.T) { @@ -77,7 +79,7 @@ func TestReleaseConn(t *testing.T) { // close stream to ensure no active stream on this connection, // which will be released when put back to the connection pool and closed by GracefulClose s := conn.(*clientConn).s - conn.(*clientConn).tr.CloseStream(s, nil) + conn.(*clientConn).tr.CloseStream(s, status.Err(codes.Internal, "test")) test.Assert(t, err == nil, err) time.Sleep(100 * time.Millisecond) shortCP.Put(conn) diff --git a/pkg/remote/trans/nphttp2/grpc/controlbuf.go b/pkg/remote/trans/nphttp2/grpc/controlbuf.go index 7dd4701124..b77fb4ccee 100644 --- a/pkg/remote/trans/nphttp2/grpc/controlbuf.go +++ b/pkg/remote/trans/nphttp2/grpc/controlbuf.go @@ -451,19 +451,20 @@ func (c *controlBuffer) get(block bool) (interface{}, error) { select { case <-c.ch: case <-c.done: - c.finish() - return nil, ErrConnClosing + return nil, c.finish(errStatusControlBufFinished) } } } -func (c *controlBuffer) finish() { +func (c *controlBuffer) finish(err error) (rErr error) { c.mu.Lock() if c.err != nil { + rErr = c.err c.mu.Unlock() return } - c.err = ErrConnClosing + c.err = err + rErr = err // There may be headers for streams in the control buffer. // These streams need to be cleaned out since the transport // is still not aware of these yet. @@ -473,10 +474,11 @@ func (c *controlBuffer) finish() { continue } if hdr.onOrphaned != nil { // It will be nil on the server-side. 
- hdr.onOrphaned(ErrConnClosing) + hdr.onOrphaned(err) } } c.mu.Unlock() + return } type side int @@ -696,7 +698,7 @@ func (l *loopyWriter) originateStream(str *outStream) error { if err == ErrConnClosing { return err } - // Other errors(errStreamDrain) need not close transport. + // Other errors(errStatusStreamDrain) need not close transport. return nil } if err := l.writeHeader(str.id, hdr.endStream, hdr.hf, hdr.onWrite); err != nil { diff --git a/pkg/remote/trans/nphttp2/grpc/controlbuf_test.go b/pkg/remote/trans/nphttp2/grpc/controlbuf_test.go index 643a4e3de5..d69e660f94 100644 --- a/pkg/remote/trans/nphttp2/grpc/controlbuf_test.go +++ b/pkg/remote/trans/nphttp2/grpc/controlbuf_test.go @@ -18,6 +18,7 @@ package grpc import ( "context" + "errors" "testing" "time" @@ -25,7 +26,7 @@ import ( ) func TestControlBuf(t *testing.T) { - ctx := context.Background() + ctx, cancel := context.WithCancel(context.Background()) cb := newControlBuffer(ctx.Done()) // test put() @@ -52,7 +53,8 @@ func TestControlBuf(t *testing.T) { test.Assert(t, !success, err) // test throttle() mock a lot of response frame so throttle() will block current goroutine - for i := 0; i < maxQueuedTransportResponseFrames+5; i++ { + exceedSize := 5 + for i := 0; i < maxQueuedTransportResponseFrames+exceedSize; i++ { err := cb.put(&ping{}) test.Assert(t, err == nil, err) } @@ -60,16 +62,44 @@ func TestControlBuf(t *testing.T) { // start a new goroutine to consume response frame go func() { time.Sleep(time.Millisecond * 100) - for { + for i := 0; i < exceedSize+1; i++ { it, err := cb.get(false) - if err != nil || it == nil { - break - } + test.Assert(t, err == nil, err) + test.Assert(t, it != nil) } }() cb.throttle() + // consumes all of the frames + for { + it, err := cb.get(false) + if err != nil || it == nil { + break + } + } + + finishErr := errors.New("finish") + go func() { + ticker := time.NewTicker(10 * time.Millisecond) + defer ticker.Stop() + for range ticker.C { + var block bool + 
cb.mu.Lock() + block = cb.consumerWaiting + cb.mu.Unlock() + if block { + cb.finish(finishErr) + cancel() + return + } + } + }() + item, err = cb.get(true) + test.Assert(t, err == finishErr, err) + test.Assert(t, item == nil, item) - // test finish() - cb.finish() + err = cb.put(testItem) + test.Assert(t, err == finishErr, err) + _, err = cb.get(false) + test.Assert(t, err == finishErr, err) } diff --git a/pkg/remote/trans/nphttp2/grpc/err_handling_test.go b/pkg/remote/trans/nphttp2/grpc/err_handling_test.go new file mode 100644 index 0000000000..7ca0fad4e8 --- /dev/null +++ b/pkg/remote/trans/nphttp2/grpc/err_handling_test.go @@ -0,0 +1,319 @@ +/* + * + * Copyright 2024 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * This file may have been modified by CloudWeGo authors. All CloudWeGo + * Modifications are Copyright 2021 CloudWeGo Authors. 
+ */ + +package grpc + +import ( + "context" + "errors" + "io" + "net" + "sync" + "testing" + "time" + + "github.com/cloudwego/netpoll" + "golang.org/x/net/http2/hpack" + + istatus "github.com/cloudwego/kitex/internal/remote/trans/grpc/status" + "github.com/cloudwego/kitex/internal/test" + "github.com/cloudwego/kitex/pkg/kerrors" + "github.com/cloudwego/kitex/pkg/remote/trans/nphttp2/codes" + "github.com/cloudwego/kitex/pkg/remote/trans/nphttp2/metadata" +) + +const ( + testTypeKey string = "testType" + errBizCanceledVal string = "errBizCanceledVal" + errMiddleHeaderVal string = "errMiddleHeaderVal" + errDecodeHeaderVal string = "errDecodeHeaderVal" + errHTTP2StreamVal string = "errHTTP2StreamVal" + errClosedWithoutTrailerVal string = "errClosedWithoutTrailerVal" + errRecvRstStreamVal string = "errRecvRstStreamVal" +) + +type expectedErrs map[string]struct { + cliErr error + srvErr error +} + +func (errs expectedErrs) getClientExpectedErr(testType string) error { + return errs[testType].cliErr +} + +func (errs expectedErrs) getServerExpectedErr(testType string) error { + return errs[testType].srvErr +} + +func TestErrorHandling(t *testing.T) { + t.Run("close stream", func(t *testing.T) { + testcases := []struct { + desc string + setup func(t *testing.T) + clean func(t *testing.T) + customRstCodeMapping map[uint32]error + errs expectedErrs + }{ + { + desc: "normal RstCode", + errs: expectedErrs{ + errBizCanceledVal: {kerrors.ErrBizCanceled, errRecvRstStream}, + errMiddleHeaderVal: {errMiddleHeader, errRecvRstStream}, + errDecodeHeaderVal: {errDecodeHeader, errRecvRstStream}, + errHTTP2StreamVal: {errHTTP2Stream, errRecvRstStream}, + errClosedWithoutTrailerVal: {errClosedWithoutTrailer, nil}, + errRecvRstStreamVal: {errRecvRstStream, nil}, + }, + }, + { + desc: "custom RstCode", + setup: customSetup, + clean: customClean, + errs: expectedErrs{ + errBizCanceledVal: {kerrors.ErrBizCanceled, kerrors.ErrBizCanceled}, + errMiddleHeaderVal: {errMiddleHeader, 
errMiddleHeader}, + errDecodeHeaderVal: {errDecodeHeader, errDecodeHeader}, + errHTTP2StreamVal: {errHTTP2Stream, errHTTP2Stream}, + errClosedWithoutTrailerVal: {errClosedWithoutTrailer, nil}, + errRecvRstStreamVal: {kerrors.ErrGracefulShutdown, nil}, + }, + }, + } + + for _, tc := range testcases { + t.Run(tc.desc, func(t *testing.T) { + if tc.setup != nil { + tc.setup(t) + } + if tc.clean != nil { + defer func() { + tc.clean(t) + }() + } + lis, err := netpoll.CreateListener("tcp", "localhost:0") + test.Assert(t, err == nil, err) + _, port, err := net.SplitHostPort(lis.Addr().String()) + test.Assert(t, err == nil, err) + cfg := &ServerConfig{} + var wg sync.WaitGroup + var onConnect netpoll.OnConnect = func(ctx context.Context, conn netpoll.Connection) context.Context { + rawSrv, err := newHTTP2Server(ctx, conn, cfg) + srv := rawSrv.(*http2Server) + test.Assert(t, err == nil, err) + srv.HandleStreams(func(stream *Stream) { + wg.Add(1) + go func() { + defer wg.Done() + md, ok := metadata.FromIncomingContext(stream.Context()) + test.Assert(t, ok) + vals := md.Get(testTypeKey) + test.Assert(t, len(vals) == 1, md) + testType := vals[0] + expectedErr := tc.errs.getServerExpectedErr(testType) + switch testType { + case errMiddleHeaderVal: + errMiddleHeaderHandler(t, srv, stream, expectedErr) + case errDecodeHeaderVal: + errDecodeHeaderHandler(t, srv, stream, expectedErr) + case errHTTP2StreamVal: + errHTTP2StreamHandler(t, srv, stream, expectedErr) + case errBizCanceledVal: + errBizCanceledHandler(t, srv, stream, expectedErr) + case errClosedWithoutTrailerVal: + errClosedWithoutTrailerHandler(t, srv, stream, expectedErr) + case errRecvRstStreamVal: + errRecvRstStreamHandler(t, srv, stream, expectedErr) + } + }() + }, func(ctx context.Context, s string) context.Context { + return ctx + }) + return nil + } + eventloop, err := netpoll.NewEventLoop(nil, + netpoll.WithOnConnect(onConnect), + netpoll.WithIdleTimeout(10*time.Second), + ) + test.Assert(t, err == nil, err) + go 
func() { + eventloop.Serve(lis) + }() + // create http2Client + conn, dErr := netpoll.NewDialer().DialTimeout("tcp", "localhost:"+port, time.Second) + test.Assert(t, dErr == nil, dErr) + cli, cErr := newHTTP2Client(context.Background(), conn.(netpoll.Connection), ConnectOptions{}, "", func(GoAwayReason) {}, func() {}) + test.Assert(t, cErr == nil, cErr) + defer func() { + wg.Wait() + cli.Close(istatus.Err(codes.Internal, "test")) + eventloop.Shutdown(context.Background()) + }() + callHdr := &CallHdr{ + Host: "host", + Method: "method", + } + buf := make([]byte, 1) + t.Run("Headers Frame appeared in the middle of the stream", func(t *testing.T) { + testType := errMiddleHeaderVal + expectedErr := tc.errs.getClientExpectedErr(testType) + ctx := metadata.AppendToOutgoingContext(context.Background(), testTypeKey, testType) + stream, err := cli.NewStream(ctx, callHdr) + test.Assert(t, err == nil, err) + _, err = stream.Header() + test.Assert(t, err == nil, err) + _, recvErr := stream.Read(buf) + test.Assert(t, errors.Is(recvErr, expectedErr), recvErr) + }) + t.Run("Decode Headers Frame failed", func(t *testing.T) { + testType := errDecodeHeaderVal + expectedErr := tc.errs.getClientExpectedErr(testType) + ctx := metadata.AppendToOutgoingContext(context.Background(), testTypeKey, testType) + stream, err := cli.NewStream(ctx, callHdr) + test.Assert(t, err == nil, err) + _, recvErr := stream.Read(buf) + test.Assert(t, errors.Is(recvErr, expectedErr), recvErr) + }) + t.Run("HTTP2Stream err when parsing frame", func(t *testing.T) { + testType := errHTTP2StreamVal + expectedErr := tc.errs.getClientExpectedErr(testType) + ctx := metadata.AppendToOutgoingContext(context.Background(), testTypeKey, testType) + stream, err := cli.NewStream(ctx, callHdr) + test.Assert(t, err == nil, err) + _, recvErr := stream.Read(buf) + test.Assert(t, errors.Is(recvErr, expectedErr), recvErr) + }) + t.Run("Biz context canceled", func(t *testing.T) { + testType := errBizCanceledVal + 
expectedErr := tc.errs.getClientExpectedErr(testType) + ctx := metadata.AppendToOutgoingContext(context.Background(), testTypeKey, testType) + ctx, cancel := context.WithCancel(ctx) + stream, err := cli.NewStream(ctx, callHdr) + test.Assert(t, err == nil, err) + _, err = stream.Header() + test.Assert(t, err == nil, err) + cancel() + _, recvErr := stream.Read(buf) + test.Assert(t, errors.Is(recvErr, expectedErr), recvErr) + }) + t.Run("Stream closed without trailer frame", func(t *testing.T) { + testType := errClosedWithoutTrailerVal + expectedErr := tc.errs.getClientExpectedErr(testType) + ctx := metadata.AppendToOutgoingContext(context.Background(), testTypeKey, testType) + stream, err := cli.NewStream(ctx, callHdr) + test.Assert(t, err == nil, err) + _, err = stream.Header() + test.Assert(t, err == nil, err) + _, recvErr := stream.Read(buf) + if errors.Is(recvErr, io.EOF) { + recvErr = stream.Status().Err() + } + test.Assert(t, errors.Is(recvErr, expectedErr), recvErr) + }) + t.Run("Receive RstStream Frame", func(t *testing.T) { + testType := errRecvRstStreamVal + expectedErr := tc.errs.getClientExpectedErr(testType) + ctx := metadata.AppendToOutgoingContext(context.Background(), testTypeKey, testType) + stream, err := cli.NewStream(ctx, callHdr) + test.Assert(t, err == nil, err) + _, err = stream.Header() + test.Assert(t, err == nil, err) + _, recvErr := stream.Read(buf) + if errors.Is(recvErr, io.EOF) { + recvErr = stream.Status().Err() + } + test.Assert(t, errors.Is(recvErr, expectedErr), recvErr) + }) + }) + } + }) +} + +func errMiddleHeaderHandler(t *testing.T, srv *http2Server, stream *Stream, expectedErr error) { + var err error + buf := make([]byte, 5) + err = stream.SendHeader(nil) + test.Assert(t, err == nil, err) + err = srv.controlBuf.put(&headerFrame{ + streamID: stream.id, + endStream: false, + }) + test.Assert(t, err == nil, err) + _, recvErr := stream.Read(buf) + test.Assert(t, errors.Is(recvErr, expectedErr), recvErr) +} + +func errDecodeHeaderHandler(t 
*testing.T, srv *http2Server, stream *Stream, expectedErr error) { + var err error + buf := make([]byte, 5) + err = srv.controlBuf.put(&headerFrame{ + streamID: stream.id, + }) + test.Assert(t, err == nil, err) + _, recvErr := stream.Read(buf) + test.Assert(t, errors.Is(recvErr, expectedErr), recvErr) +} + +func errHTTP2StreamHandler(t *testing.T, srv *http2Server, stream *Stream, expectedErr error) { + var err error + buf := make([]byte, 5) + err = srv.controlBuf.put(&headerFrame{ + streamID: stream.id, + // regular header field is previous to pseudo header field + hf: []hpack.HeaderField{ + {Name: "key", Value: "val"}, + {Name: ":status", Value: "200"}, + }, + }) + test.Assert(t, err == nil, err) + _, recvErr := stream.Read(buf) + test.Assert(t, errors.Is(recvErr, expectedErr), recvErr) +} + +func errBizCanceledHandler(t *testing.T, srv *http2Server, stream *Stream, expectedErr error) { + var err error + buf := make([]byte, 5) + err = stream.SendHeader(nil) + test.Assert(t, err == nil, err) + _, recvErr := stream.Read(buf) + test.Assert(t, errors.Is(recvErr, expectedErr), recvErr) +} + +func errClosedWithoutTrailerHandler(t *testing.T, srv *http2Server, stream *Stream, expectedErr error) { + var err error + err = stream.SendHeader(nil) + test.Assert(t, err == nil, err) + err = srv.controlBuf.put(&dataFrame{ + streamID: stream.id, + endStream: true, + }) + test.Assert(t, err == nil, err) +} + +func errRecvRstStreamHandler(t *testing.T, srv *http2Server, stream *Stream, expectedErr error) { + err := stream.SendHeader(nil) + test.Assert(t, err == nil, err) + err = srv.controlBuf.put(&cleanupStream{ + streamID: stream.id, + rst: true, + rstCode: getRstCode(newStatus(codes.Unavailable, kerrors.ErrGracefulShutdown, "test").Err()), + onWrite: func() {}, + }) + test.Assert(t, err == nil, err) +} diff --git a/pkg/remote/trans/nphttp2/grpc/errors.go b/pkg/remote/trans/nphttp2/grpc/errors.go new file mode 100644 index 0000000000..83e5410fbc --- /dev/null +++ 
b/pkg/remote/trans/nphttp2/grpc/errors.go @@ -0,0 +1,115 @@ +/* + * + * Copyright 2024 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * This file may have been modified by CloudWeGo authors. All CloudWeGo + * Modifications are Copyright 2021 CloudWeGo Authors. + */ + +package grpc + +import ( + "fmt" + + "github.com/cloudwego/kitex/pkg/kerrors" +) + +type ErrType = uint32 + +const ( + ErrTypeGracefulShutdown = iota + 1 + ErrTypeBizCanceled + ErrTypeHandlerReturn + ErrTypeStreamingCanceled + ErrTypeStreamTimeout + ErrTypeMetaSizeExceeded + + ErrTypeHTTP2Stream + ErrTypeClosedWithoutTrailer + ErrTypeMiddleHeader + ErrTypeDecodeHeader + ErrTypeRecvRstStream + ErrTypeStreamDrain + ErrTypeStreamFlowControl + ErrTypeIllegalHeaderWrite + ErrTypeStreamIsDone + ErrTypeMaxStreamExceeded + + ErrTypeHTTP2Connection + ErrTypeEstablishConnection + ErrTypeHandleGoAway + ErrTypeKeepAlive + ErrTypeOperateHeaders + ErrTypeNoActiveStream + ErrTypeControlBufFinished + ErrTypeNotReachable + ErrTypeConnectionIsClosing +) + +// This file contains all the errors suitable for Kitex errors model. 
+var ( + // stream error + errHTTP2Stream = fmt.Errorf("%w - %s", kerrors.ErrStreamingProtocol, "HTTP2Stream err when parsing HTTP2 frame") + errClosedWithoutTrailer = fmt.Errorf("%w - %s", kerrors.ErrStreamingProtocol, "client received Data frame with END_STREAM flag") + errMiddleHeader = fmt.Errorf("%w - %s", kerrors.ErrStreamingProtocol, "Headers frame appeared in the middle of a stream") + errDecodeHeader = fmt.Errorf("%w - %s", kerrors.ErrStreamingProtocol, "decoded Headers frame failed") + errRecvRstStream = fmt.Errorf("%w - %s", kerrors.ErrStreamingProtocol, "received RstStream frame") + errStreamDrain = fmt.Errorf("%w - %s", kerrors.ErrStreamingProtocol, "stream rejected by draining connection") + errStreamFlowControl = fmt.Errorf("%w - %s", kerrors.ErrStreamingProtocol, "stream-level flow control") + errIllegalHeaderWrite = fmt.Errorf("%w - %s", kerrors.ErrStreamingProtocol, "Headers frame has been already sent by server") + errStreamIsDone = fmt.Errorf("%w - %s", kerrors.ErrStreamingProtocol, "stream is done") + errMaxStreamExceeded = fmt.Errorf("%w - %s", kerrors.ErrStreamingProtocol, "max stream exceeded") + + // connection error + errHTTP2Connection = fmt.Errorf("%w - %s", kerrors.ErrStreamingProtocol, "HTTP2Connection err when parsing HTTP2 frame") + errEstablishConnection = fmt.Errorf("%w - %s", kerrors.ErrStreamingProtocol, "established connection failed") + errHandleGoAway = fmt.Errorf("%w - %s", kerrors.ErrStreamingProtocol, "handled GoAway Frame failed") + errKeepAlive = fmt.Errorf("%w - %s", kerrors.ErrStreamingProtocol, "keepalive failed") + errOperateHeaders = fmt.Errorf("%w - %s", kerrors.ErrStreamingProtocol, "operated Headers Frame failed") + errNoActiveStream = fmt.Errorf("%w - %s", kerrors.ErrStreamingProtocol, "no active stream") + errControlBufFinished = fmt.Errorf("%w - %s", kerrors.ErrStreamingProtocol, "controlbuf finished") + errNotReachable = fmt.Errorf("%w - %s", kerrors.ErrStreamingProtocol, "server transport is not reachable") + 
errConnectionIsClosing = fmt.Errorf("%w - %s", kerrors.ErrStreamingProtocol, "connection is closing") +) + +var errType2ErrMap = map[ErrType]error{ + ErrTypeGracefulShutdown: kerrors.ErrGracefulShutdown, + ErrTypeBizCanceled: kerrors.ErrBizCanceled, + ErrTypeHandlerReturn: kerrors.ErrHandlerReturn, + ErrTypeStreamingCanceled: kerrors.ErrStreamingCanceled, + ErrTypeStreamTimeout: kerrors.ErrStreamTimeout, + ErrTypeMetaSizeExceeded: kerrors.ErrMetaSizeExceeded, + // stream error type + ErrTypeHTTP2Stream: errHTTP2Stream, + ErrTypeClosedWithoutTrailer: errClosedWithoutTrailer, + ErrTypeMiddleHeader: errMiddleHeader, + ErrTypeDecodeHeader: errDecodeHeader, + ErrTypeRecvRstStream: errRecvRstStream, + ErrTypeStreamDrain: errStreamDrain, + ErrTypeStreamFlowControl: errStreamFlowControl, + ErrTypeIllegalHeaderWrite: errIllegalHeaderWrite, + ErrTypeStreamIsDone: errStreamIsDone, + ErrTypeMaxStreamExceeded: errMaxStreamExceeded, + // connection error type + ErrTypeHTTP2Connection: errHTTP2Connection, + ErrTypeEstablishConnection: errEstablishConnection, + ErrTypeHandleGoAway: errHandleGoAway, + ErrTypeKeepAlive: errKeepAlive, + ErrTypeOperateHeaders: errOperateHeaders, + ErrTypeNoActiveStream: errNoActiveStream, + ErrTypeControlBufFinished: errControlBufFinished, + ErrTypeNotReachable: errNotReachable, + ErrTypeConnectionIsClosing: errConnectionIsClosing, +} diff --git a/pkg/remote/trans/nphttp2/grpc/errors_test.go b/pkg/remote/trans/nphttp2/grpc/errors_test.go new file mode 100644 index 0000000000..aa5e005e36 --- /dev/null +++ b/pkg/remote/trans/nphttp2/grpc/errors_test.go @@ -0,0 +1,63 @@ +/* + * + * Copyright 2024 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * This file may have been modified by CloudWeGo authors. All CloudWeGo + * Modifications are Copyright 2021 CloudWeGo Authors. + */ + +package grpc + +import ( + "errors" + "strings" + "testing" + + "github.com/cloudwego/kitex/internal/test" + "github.com/cloudwego/kitex/pkg/kerrors" +) + +var errs = []error{ + // stream error + errHTTP2Stream, + errClosedWithoutTrailer, + errMiddleHeader, + errDecodeHeader, + errRecvRstStream, + errStreamDrain, + errStreamFlowControl, + errIllegalHeaderWrite, + errStreamIsDone, + errMaxStreamExceeded, + // connection error + errHTTP2Connection, + errEstablishConnection, + errHandleGoAway, + errKeepAlive, + errOperateHeaders, + errNoActiveStream, + errControlBufFinished, + errNotReachable, + errConnectionIsClosing, +} + +func Test_err(t *testing.T) { + for _, err := range errs { + test.Assert(t, errors.Is(err, kerrors.ErrStreamingProtocol), err) + test.Assert(t, kerrors.IsKitexError(err), err) + test.Assert(t, kerrors.IsStreamingError(err), err) + test.Assert(t, strings.Contains(err.Error(), kerrors.ErrStreamingProtocol.Error()), err) + } +} diff --git a/pkg/remote/trans/nphttp2/grpc/http2_client.go b/pkg/remote/trans/nphttp2/grpc/http2_client.go index 58b2eaba01..afc524240c 100644 --- a/pkg/remote/trans/nphttp2/grpc/http2_client.go +++ b/pkg/remote/trans/nphttp2/grpc/http2_client.go @@ -36,14 +36,15 @@ import ( "golang.org/x/net/http2" "golang.org/x/net/http2/hpack" + istatus "github.com/cloudwego/kitex/internal/remote/trans/grpc/status" "github.com/cloudwego/kitex/pkg/gofunc" + 
"github.com/cloudwego/kitex/pkg/kerrors" "github.com/cloudwego/kitex/pkg/klog" "github.com/cloudwego/kitex/pkg/remote/trans/nphttp2/codes" "github.com/cloudwego/kitex/pkg/remote/trans/nphttp2/grpc/grpcframe" "github.com/cloudwego/kitex/pkg/remote/trans/nphttp2/grpc/syscall" "github.com/cloudwego/kitex/pkg/remote/trans/nphttp2/metadata" "github.com/cloudwego/kitex/pkg/remote/trans/nphttp2/peer" - "github.com/cloudwego/kitex/pkg/remote/trans/nphttp2/status" ) // http2Client implements the ClientTransport interface with HTTP2. @@ -219,12 +220,14 @@ func newHTTP2Client(ctx context.Context, conn net.Conn, opts ConnectOptions, // Send connection preface to server. n, err := t.conn.Write(ClientPreface) if err != nil { - err = connectionErrorf(true, err, "transport: failed to write client preface: %v", err) + err = newStatusf(codes.Unavailable, errEstablishConnection, "transport: failed to write client preface: %v", err). + Err() t.Close(err) return nil, err } if n != ClientPrefaceLen { - err = connectionErrorf(true, err, "transport: preface mismatch, wrote %d bytes; want %d", n, ClientPrefaceLen) + err = newStatusf(codes.Unavailable, errEstablishConnection, "transport: preface mismatch, wrote %d bytes; want %d", n, ClientPrefaceLen). + Err() t.Close(err) return nil, err } @@ -237,7 +240,8 @@ func newHTTP2Client(ctx context.Context, conn net.Conn, opts ConnectOptions, } err = t.framer.WriteSettings(ss...) if err != nil { - err = connectionErrorf(true, err, "transport: failed to write initial settings frame: %v", err) + err = newStatusf(codes.Unavailable, errEstablishConnection, "transport: failed to write initial settings frame: %v", err). + Err() t.Close(err) return nil, err } @@ -245,7 +249,7 @@ func newHTTP2Client(ctx context.Context, conn net.Conn, opts ConnectOptions, // Adjust the connection flow control window if needed. 
if delta := uint32(icwz - defaultWindowSize); delta > 0 { if err := t.framer.WriteWindowUpdate(0, delta); err != nil { - err = connectionErrorf(true, err, "transport: failed to write window update: %v", err) + err = newStatusf(codes.Unavailable, errEstablishConnection, "transport: failed to write window update: %v", err).Err() t.Close(err) return nil, err } @@ -431,9 +435,12 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea if state := t.state; state != reachable { t.mu.Unlock() // Do a quick cleanup. - err := error(errStreamDrain) + err := errStatusStreamDrain if state == closing { - err = ErrConnClosing + err = errStatusConnClosing + // make sure the error exposed to users is *status.Error + cleanup(err) + return ErrConnClosing } cleanup(err) return err @@ -491,7 +498,7 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea var sz int64 for _, f := range hdrFrame.hf { if sz += int64(f.Size()); sz > int64(*t.maxSendHeaderListSize) { - hdrListSizeErr = status.Errorf(codes.Internal, "header list size to send violates the maximum size (%d bytes) set by server", *t.maxSendHeaderListSize) + hdrListSizeErr = newStatusf(codes.Internal, kerrors.ErrMetaSizeExceeded, "header list size to send violates the maximum size (%d bytes) set by server", *t.maxSendHeaderListSize).Err() return false } } @@ -516,7 +523,7 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea case <-s.ctx.Done(): return nil, ContextErr(s.ctx.Err()) case <-t.goAway: - return nil, errStreamDrain + return nil, errStatusStreamDrain case <-t.ctx.Done(): return nil, ErrConnClosing } @@ -534,11 +541,12 @@ func (t *http2Client) CloseStream(s *Stream, err error) { if err != nil { rst = true rstCode = http2.ErrCodeCancel + klog.CtxInfof(s.ctx, "KITEX: stream closed by ctx canceled, err: %v"+sendRSTStreamFrameSuffix, err) } - t.closeStream(s, err, rst, rstCode, status.Convert(err), nil, false) + t.closeStream(s, err, rst, 
rstCode, istatus.Convert(err), nil, false) } -func (t *http2Client) closeStream(s *Stream, err error, rst bool, rstCode http2.ErrCode, st *status.Status, mdata map[string][]string, eosReceived bool) { +func (t *http2Client) closeStream(s *Stream, err error, rst bool, rstCode http2.ErrCode, st *istatus.Status, mdata map[string][]string, eosReceived bool) { // Set stream status to done. if s.swapState(streamDone) == streamDone { // If it was already done, return. If multiple closeStream calls @@ -557,11 +565,24 @@ func (t *http2Client) closeStream(s *Stream, err error, rst bool, rstCode http2. // This will unblock reads eventually. s.write(recvMsg{err: err}) } + + // store closeStreamErr + storeErr := err + if err == io.EOF { + storeErr = st.Err() + } + if storeErr != nil { + s.closeStreamErr.Store(storeErr) + } + // If headerChan isn't closed, then close it. if atomic.CompareAndSwapUint32(&s.headerChanClosed, 0, 1) { s.noHeaders = true close(s.headerChan) } + if rst && isCustomRstCodeEnabled() { + rstCode = getRstCode(err) + } cleanup := &cleanupStream{ streamID: s.id, onWrite: func() { @@ -617,13 +638,13 @@ func (t *http2Client) Close(err error) error { t.kpDormancyCond.Signal() } t.mu.Unlock() - t.controlBuf.finish() + t.controlBuf.finish(err) t.cancel() cErr := t.conn.Close() // Notify all active streams. 
for _, s := range streams { - t.closeStream(s, err, false, http2.ErrCodeNo, status.New(codes.Unavailable, ErrConnClosing.Desc), nil, false) + t.closeStream(s, err, false, http2.ErrCodeNo, istatus.New(codes.Unavailable, ErrConnClosing.Desc), nil, false) } return cErr } @@ -644,7 +665,7 @@ func (t *http2Client) GracefulClose() { active := len(t.activeStreams) t.mu.Unlock() if active == 0 { - t.Close(connectionErrorf(true, nil, "no active streams left to process while draining")) + t.Close(newStatus(codes.Unavailable, errNoActiveStream, "no active streams left to process while draining").Err()) return } t.controlBuf.put(&incomingGoAway{}) @@ -656,10 +677,10 @@ func (t *http2Client) Write(s *Stream, hdr, data []byte, opts *Options) error { if opts.Last { // If it's the last message, update stream state. if !s.compareAndSwapState(streamActive, streamWriteDone) { - return errStreamDone + return s.getCloseStreamErr() } } else if s.getState() != streamActive { - return errStreamDone + return s.getCloseStreamErr() } df := newDataFrame() df.streamID = s.id @@ -670,7 +691,7 @@ func (t *http2Client) Write(s *Stream, hdr, data []byte, opts *Options) error { df.originD = df.d if hdr != nil || data != nil { // If it's not an empty data frame, check quota. 
if err := s.wq.get(int32(len(hdr) + len(data))); err != nil { - return err + return s.getCloseStreamErr() } } return t.controlBuf.put(df) @@ -766,7 +787,9 @@ func (t *http2Client) handleData(f *grpcframe.DataFrame) { } if size > 0 { if err := s.fc.onData(size); err != nil { - t.closeStream(s, io.EOF, true, http2.ErrCodeFlowControl, status.New(codes.Internal, err.Error()), nil, false) + klog.CtxErrorf(s.ctx, "KITEX: http2Client.handleData inflow control err: %v, code: %d"+sendRSTStreamFrameSuffix, err, http2.ErrCodeFlowControl) + st := newStatus(codes.Internal, errStreamFlowControl, err.Error()) + t.closeStream(s, io.EOF, true, http2.ErrCodeFlowControl, st, nil, false) return } if f.Header().Flags.Has(http2.FlagDataPadded) { @@ -787,7 +810,8 @@ func (t *http2Client) handleData(f *grpcframe.DataFrame) { // The server has closed the stream without sending trailers. Record that // the read direction is closed, and set the status appropriately. if f.FrameHeader.Flags.Has(http2.FlagDataEndStream) { - t.closeStream(s, io.EOF, false, http2.ErrCodeNo, status.New(codes.Internal, "server closed the stream without sending trailers"), nil, true) + st := newStatus(codes.Internal, errClosedWithoutTrailer, "server closed the stream without sending trailers") + t.closeStream(s, io.EOF, false, http2.ErrCodeNo, st, nil, true) } } @@ -800,19 +824,11 @@ func (t *http2Client) handleRSTStream(f *http2.RSTStreamFrame) { // The stream was unprocessed by the server. atomic.StoreUint32(&s.unprocessed, 1) } - statusCode, ok := http2ErrConvTab[f.ErrCode] - if !ok { - klog.Warnf("transport: http2Client.handleRSTStream found no mapped gRPC status for the received nhttp2 error %v", f.ErrCode) - statusCode = codes.Unknown - } - if statusCode == codes.Canceled { - if d, ok := s.ctx.Deadline(); ok && !d.After(time.Now()) { - // Our deadline was already exceeded, and that was likely the cause - // of this cancelation. Alter the status code accordingly. 
- statusCode = codes.DeadlineExceeded - } - } - t.closeStream(s, io.EOF, false, http2.ErrCodeNo, status.Newf(statusCode, "stream terminated by RST_STREAM with error code: %v", f.ErrCode), nil, false) + + mappingErr, stCode := getMappingErrAndStatusCode(s.ctx, f.ErrCode) + + st := newStatusf(stCode, mappingErr, "stream terminated by RST_STREAM with error code: %v", f.ErrCode) + t.closeStream(s, io.EOF, false, http2.ErrCodeNo, st, nil, false) } func (t *http2Client) handleSettings(f *grpcframe.SettingsFrame, isFirst bool) { @@ -889,7 +905,8 @@ func (t *http2Client) handleGoAway(f *grpcframe.GoAwayFrame) { id := f.LastStreamID if id > 0 && id%2 != 1 { t.mu.Unlock() - t.Close(connectionErrorf(true, nil, "received goaway with non-zero even-numbered numbered stream id: %v", id)) + st := newStatusf(codes.Unavailable, errHandleGoAway, "received goaway with non-zero even-numbered numbered stream id: %v", id) + t.Close(st.Err()) return } // A client can receive multiple GoAways from the server (see @@ -907,7 +924,8 @@ func (t *http2Client) handleGoAway(f *grpcframe.GoAwayFrame) { // If there are multiple GoAways the first one should always have an ID greater than the following ones. if id > t.prevGoAwayID { t.mu.Unlock() - t.Close(connectionErrorf(true, nil, "received goaway with stream id: %v, which exceeds stream id of previous goaway: %v", id, t.prevGoAwayID)) + st := newStatusf(codes.Unavailable, errHandleGoAway, "received goaway with stream id: %v, which exceeds stream id of previous goaway: %v", id, t.prevGoAwayID) + t.Close(st.Err()) return } default: @@ -932,14 +950,16 @@ func (t *http2Client) handleGoAway(f *grpcframe.GoAwayFrame) { if streamID > id && streamID <= upperLimit { // The stream was unprocessed by the server. 
atomic.StoreUint32(&stream.unprocessed, 1) - t.closeStream(stream, errStreamDrain, false, http2.ErrCodeNo, statusGoAway, nil, false) + st := newStatus(codes.Unavailable, errStreamDrain, "the stream is rejected because server is draining the connection") + t.closeStream(stream, st.Err(), false, http2.ErrCodeNo, st, nil, false) } } t.prevGoAwayID = id active := len(t.activeStreams) t.mu.Unlock() if active == 0 { - t.Close(connectionErrorf(true, nil, "received goaway and there are no active streams")) + st := newStatus(codes.Unavailable, errNoActiveStream, "received goaway and there are no active streams") + t.Close(st.Err()) } } @@ -982,7 +1002,8 @@ func (t *http2Client) operateHeaders(frame *grpcframe.MetaHeadersFrame) { if !initialHeader && !endStream { // As specified by gRPC over HTTP2, a HEADERS frame (and associated CONTINUATION frames) can only appear at the start or end of a stream. Therefore, second HEADERS frame must have EOS bit set. - st := status.New(codes.Internal, "a HEADERS frame cannot appear in the middle of a stream") + st := newStatus(codes.Internal, errMiddleHeader, "a HEADERS frame cannot appear in the middle of a stream") + klog.CtxErrorf(s.ctx, "KITEX: http2Client.operateHeaders received HEADERS frame in the middle of a stream"+sendRSTStreamFrameSuffix) t.closeStream(s, st.Err(), true, http2.ErrCodeProtocol, st, nil, false) return } @@ -990,8 +1011,10 @@ func (t *http2Client) operateHeaders(frame *grpcframe.MetaHeadersFrame) { state := &decodeState{} // Initialize isGRPC value to be !initialHeader, since if a gRPC Response-Headers has already been received, then it means that the peer is speaking gRPC and we are in gRPC mode. 
state.data.isGRPC = !initialHeader - if err := state.decodeHeader(frame); err != nil { - t.closeStream(s, err, true, http2.ErrCodeProtocol, status.Convert(err), nil, endStream) + if st := state.decodeHeader(frame); st != nil { + klog.CtxErrorf(s.ctx, "KITEX: http2Client.operateHeaders decode HEADERS frame failed, err: %v, code: %d"+sendRSTStreamFrameSuffix, st.Err(), http2.ErrCodeProtocol) + istatus.InjectMappingErr(st, errDecodeHeader) + t.closeStream(s, st.Err(), true, http2.ErrCodeProtocol, st, nil, endStream) return } @@ -1034,8 +1057,8 @@ func (t *http2Client) reader() { // Check the validity of server preface. frame, err := t.framer.ReadFrame() if err != nil { - err = connectionErrorf(true, err, "error reading from server, remoteAddress=%s, error=%v", t.conn.RemoteAddr(), err) - t.Close(err) // this kicks off resetTransport, so must be last before return + st := newStatusf(codes.Unavailable, errEstablishConnection, "error reading from server, remoteAddress=%s, error=%v", t.conn.RemoteAddr(), err) + t.Close(st.Err()) // this kicks off resetTransport, so must be last before return return } t.conn.SetReadDeadline(time.Time{}) // reset deadline once we get the settings frame (we didn't time out, yay!) 
@@ -1044,8 +1067,8 @@ func (t *http2Client) reader() { } sf, ok := frame.(*grpcframe.SettingsFrame) if !ok { - err = connectionErrorf(true, err, "first frame received is not a setting frame") - t.Close(err) // this kicks off resetTransport, so must be last before return + st := newStatus(codes.Unavailable, errEstablishConnection, "first frame received is not a setting frame") + t.Close(st.Err()) // this kicks off resetTransport, so must be last before return return } t.handleSettings(sf, true) @@ -1073,13 +1096,15 @@ func (t *http2Client) reader() { if err != nil { msg = err.Error() } - t.closeStream(s, status.New(code, msg).Err(), true, http2.ErrCodeProtocol, status.New(code, msg), nil, false) + klog.CtxErrorf(s.ctx, "KITEX: http2Client.reader encountered http2.StreamError: %v, code: %d"+sendRSTStreamFrameSuffix, se, http2.ErrCodeProtocol) + st := newStatus(code, errHTTP2Stream, msg) + t.closeStream(s, st.Err(), true, http2.ErrCodeProtocol, st, nil, false) } continue } else { // Transport error. 
- err = connectionErrorf(true, err, "error reading from server, remoteAddress=%s, error=%v", t.conn.RemoteAddr(), err) - t.Close(err) + st := newStatusf(codes.Unavailable, errHTTP2Connection, "error reading from server, remoteAddress=%s, error=%v", t.conn.RemoteAddr(), err) + t.Close(st.Err()) return } } @@ -1137,7 +1162,8 @@ func (t *http2Client) keepalive() { continue } if outstandingPing && timeoutLeft <= 0 { - t.Close(connectionErrorf(true, nil, "keepalive ping failed to receive ACK within timeout")) + st := newStatus(codes.Unavailable, errKeepAlive, "keepalive ping failed to receive ACK within timeout") + t.Close(st.Err()) return } t.mu.Lock() diff --git a/pkg/remote/trans/nphttp2/grpc/http2_server.go b/pkg/remote/trans/nphttp2/grpc/http2_server.go index c2b84efe13..bdc959a92c 100644 --- a/pkg/remote/trans/nphttp2/grpc/http2_server.go +++ b/pkg/remote/trans/nphttp2/grpc/http2_server.go @@ -34,7 +34,10 @@ import ( "sync/atomic" "time" + "github.com/cloudwego/kitex/pkg/kerrors" "github.com/cloudwego/kitex/pkg/remote/codec/protobuf/encoding" + "github.com/cloudwego/kitex/pkg/remote/trans/nphttp2/status" + "github.com/cloudwego/kitex/pkg/remote/transmeta" "github.com/cloudwego/netpoll" "golang.org/x/net/http2" @@ -46,27 +49,32 @@ import ( "github.com/cloudwego/kitex/pkg/remote/trans/nphttp2/codes" "github.com/cloudwego/kitex/pkg/remote/trans/nphttp2/grpc/grpcframe" "github.com/cloudwego/kitex/pkg/remote/trans/nphttp2/metadata" - "github.com/cloudwego/kitex/pkg/remote/trans/nphttp2/status" "github.com/cloudwego/kitex/pkg/utils" ) var ( // ErrIllegalHeaderWrite indicates that setting header is illegal because of // the stream's state. - ErrIllegalHeaderWrite = errors.New("transport: the stream is done or WriteHeader was already called") + ErrIllegalHeaderWrite = errors.New("transport: WriteHeader was already called") + errStatusIllegalHeaderWrite = newStatus(codes.Internal, errIllegalHeaderWrite, ErrIllegalHeaderWrite.Error()+triggeredByHandlerSideSuffix). 
+ Err() // ErrHeaderListSizeLimitViolation indicates that the header list size is larger // than the limit set by peer. - ErrHeaderListSizeLimitViolation = errors.New("transport: trying to send header list size larger than the limit set by peer") + ErrHeaderListSizeLimitViolation = errors.New("transport: trying to send header list size larger than the limit set by peer") + errStatusHeaderListSizeLimitViolation = newStatus(codes.Internal, kerrors.ErrMetaSizeExceeded, ErrHeaderListSizeLimitViolation.Error()+triggeredByHandlerSideSuffix). + Err() // errors used for cancelling stream. // the code should be codes.Canceled coz it's NOT returned from remote - errConnectionEOF = status.New(codes.Canceled, "transport: connection EOF").Err() - errStreamClosing = status.New(codes.Canceled, "transport: stream is closing").Err() - errMaxStreamsExceeded = status.New(codes.Canceled, "transport: max streams exceeded").Err() - errNotReachable = status.New(codes.Canceled, "transport: server not reachable").Err() - errMaxAgeClosing = status.New(codes.Canceled, "transport: closing server transport due to maximum connection age").Err() - errIdleClosing = status.New(codes.Canceled, "transport: closing server transport due to idleness").Err() + errStatusConnectionEOF = newStatus(codes.Canceled, errHTTP2Connection, "transport: connection EOF"+triggeredByRemoteServiceSuffix). + Err() + errStatusMaxStreamsExceeded = newStatus(codes.Canceled, errMaxStreamExceeded, "transport: max streams exceeded"+triggeredByRemoteServiceSuffix). + Err() + errStatusNotReachable = newStatus(codes.Canceled, errNotReachable, "transport: server not reachable"+triggeredByRemoteServiceSuffix). + Err() + errStatusHandlerReturn = newStatus(codes.Canceled, kerrors.ErrHandlerReturn, "transport: handler return"+triggeredByHandlerSideSuffix). 
+ Err() ) func init() { @@ -177,14 +185,16 @@ func newHTTP2Server(ctx context.Context, conn net.Conn, config *ServerConfig) (_ } if err := framer.WriteSettings(isettings...); err != nil { - return nil, connectionErrorf(false, err, "transport: %v", err) + st := newStatusf(codes.Unavailable, errEstablishConnection, "transport: server failed to write initial settings frame: %v", err) + return nil, st.Err() } // Adjust the connection flow control window if needed. if icwz > defaultWindowSize { if delta := icwz - defaultWindowSize; delta > 0 { if err := framer.WriteWindowUpdate(0, delta); err != nil { - return nil, connectionErrorf(false, err, "transport: %v", err) + st := newStatusf(codes.Unavailable, errEstablishConnection, "transport: server failed to write window update frame: %v", err) + return nil, st.Err() } } } @@ -247,32 +257,37 @@ func newHTTP2Server(ctx context.Context, conn net.Conn, config *ServerConfig) (_ // Check the validity of client preface. preface := make([]byte, len(ClientPreface)) - if _, err := io.ReadFull(t.conn, preface); err != nil { + if _, rErr := io.ReadFull(t.conn, preface); rErr != nil { // In deployments where a gRPC server runs behind a cloud load balancer // which performs regular TCP level health checks, the connection is // closed immediately by the latter. Returning io.EOF here allows the // grpc server implementation to recognize this scenario and suppress // logging to reduce spam. 
- if err == io.EOF { + if rErr == io.EOF { return nil, io.EOF } - return nil, connectionErrorf(false, err, "transport: http2Server.HandleStreams failed to receive the preface from client: %v", err) + err = newStatusf(codes.Unavailable, errEstablishConnection, "transport: server failed to receive the preface from client: %v", rErr).Err() + return nil, err } if !bytes.Equal(preface, ClientPreface) { - return nil, connectionErrorf(false, nil, "transport: http2Server.HandleStreams received bogus greeting from client: %q", preface) + err = newStatusf(codes.Unavailable, errEstablishConnection, "transport: server received bogus greeting from client: %q", preface).Err() + return nil, err } - frame, err := t.framer.ReadFrame() - if err == io.EOF || err == io.ErrUnexpectedEOF { + frame, rErr := t.framer.ReadFrame() + if rErr == io.EOF || rErr == io.ErrUnexpectedEOF { + err = newStatusf(codes.Unavailable, errEstablishConnection, "transport: connection EOF: %v", rErr).Err() return nil, err } - if err != nil { - return nil, connectionErrorf(false, err, "transport: http2Server.HandleStreams failed to read initial settings frame: %v", err) + if rErr != nil { + err = newStatusf(codes.Unavailable, errEstablishConnection, "transport: server failed to read initial settings frame: %v", rErr).Err() + return nil, err } atomic.StoreInt64(&t.lastRead, time.Now().UnixNano()) sf, ok := frame.(*grpcframe.SettingsFrame) if !ok { - return nil, connectionErrorf(false, nil, "transport: http2Server.HandleStreams saw invalid preface type %T from client", frame) + err = newStatusf(codes.Unavailable, errEstablishConnection, "transport: server received invalid frame type %T from client", frame).Err() + return nil, err } t.handleSettings(sf) @@ -292,20 +307,21 @@ func newHTTP2Server(ctx context.Context, conn net.Conn, config *ServerConfig) (_ // operateHeaders takes action on the decoded headers. Returns an error if fatal // error encountered and transport needs to close, otherwise returns nil.
+// users will only be able to perceive the stream if the input handle is executed. func (t *http2Server) operateHeaders(frame *grpcframe.MetaHeadersFrame, handle func(*Stream), traceCtx func(context.Context, string) context.Context) error { streamID := frame.Header().StreamID state := &decodeState{ serverSide: true, } - if err := state.decodeHeader(frame); err != nil { - if se, ok := status.FromError(err); ok { - t.controlBuf.put(&cleanupStream{ - streamID: streamID, - rst: true, - rstCode: statusCodeConvTab[se.Code()], - onWrite: func() {}, - }) - } + if st := state.decodeHeader(frame); st != nil { + rstCode := statusCodeConvTab[st.Code()] + klog.CtxErrorf(t.ctx, "KITEX: http2Server.operateHeaders failed to decode header frame, err=%v, code: %d"+sendRSTStreamFrameSuffix, st.Err(), rstCode) + t.controlBuf.put(&cleanupStream{ + streamID: streamID, + rst: true, + rstCode: rstCode, + onWrite: func() {}, + }) return nil } @@ -330,27 +346,33 @@ func (t *http2Server) operateHeaders(frame *grpcframe.MetaHeadersFrame, handle f } else { s.ctx, cancel = context.WithCancel(t.ctx) } - s.ctx, s.cancel = newContextWithCancelReason(s.ctx, cancel) + s.ctx, s.cancelFunc = newContextWithCancelReason(s.ctx, cancel) // Attach the received metadata to the context. 
if len(state.data.mdata) > 0 { s.ctx = metadata.NewIncomingContext(s.ctx, state.data.mdata) + // retrieve source service from metadata + vals := state.data.mdata[transmeta.HTTPSourceService] + if len(vals) > 0 { + s.sourceService = vals[0] + } } t.mu.Lock() if t.state != reachable { t.mu.Unlock() - s.cancel(errNotReachable) + s.cancel(errStatusNotReachable) return nil } if uint32(len(t.activeStreams)) >= t.maxStreams { t.mu.Unlock() + klog.CtxErrorf(t.ctx, "KITEX: http2Server.operateHeaders failed to create stream, err=%v, code: %d"+sendRSTStreamFrameSuffix, errStatusMaxStreamsExceeded, http2.ErrCodeRefusedStream) t.controlBuf.put(&cleanupStream{ streamID: streamID, rst: true, rstCode: http2.ErrCodeRefusedStream, onWrite: func() {}, }) - s.cancel(errMaxStreamsExceeded) + s.cancel(errStatusMaxStreamsExceeded) return nil } if streamID%2 != 1 || streamID <= t.maxStreamID { @@ -405,11 +427,11 @@ func (t *http2Server) HandleStreams(handle func(*Stream), traceCtx func(context. s := t.activeStreams[se.StreamID] t.mu.Unlock() if s != nil { - // it will be codes.Internal error for GRPC - // TODO: map http2.StreamError to status.Error? - s.cancel(err) - t.closeStream(s, true, se.Code, false) + klog.CtxErrorf(s.ctx, "KITEX: http2Server.HandleStreams encountered http2.StreamError, err=%v, code: %d"+sendRSTStreamFrameSuffix, err, se.Code) + stErr := newStatusf(codes.Canceled, errHTTP2Stream, "transport: ReadFrame encountered http2.StreamError: %v [triggered by %s]", err, s.sourceService).Err() + t.closeStream(s, stErr, true, se.Code, false) } else { + klog.CtxErrorf(t.ctx, "KITEX: http2Server.HandleStreams failed to ReadFrame, err=%v, code: %d"+sendRSTStreamFrameSuffix, err, se.Code) t.controlBuf.put(&cleanupStream{ streamID: se.StreamID, rst: true, @@ -420,18 +442,20 @@ func (t *http2Server) HandleStreams(handle func(*Stream), traceCtx func(context. 
continue } if err == io.EOF || err == io.ErrUnexpectedEOF || errors.Is(err, netpoll.ErrEOF) { - t.closeWithErr(errConnectionEOF) + t.closeWithErr(errStatusConnectionEOF) return } klog.CtxWarnf(t.ctx, "transport: http2Server.HandleStreams failed to read frame: %v", err) - t.closeWithErr(err) + stErr := newStatusf(codes.Canceled, errHTTP2Connection, "transport: ReadFrame encountered err: %v"+triggeredByRemoteServiceSuffix, err).Err() + t.closeWithErr(stErr) return } switch frame := frame.(type) { case *grpcframe.MetaHeadersFrame: if err := t.operateHeaders(frame, handle, traceCtx); err != nil { klog.CtxErrorf(t.ctx, "transport: http2Server.HandleStreams fatal err: %v", err) - t.closeWithErr(err) + stErr := newStatus(codes.Canceled, errOperateHeaders, err.Error()).Err() + t.closeWithErr(stErr) break } case *grpcframe.DataFrame: @@ -551,7 +575,9 @@ func (t *http2Server) handleData(f *grpcframe.DataFrame) { } if size > 0 { if err := s.fc.onData(size); err != nil { - t.closeStream(s, true, http2.ErrCodeFlowControl, false) + klog.CtxErrorf(s.ctx, "KITEX: http2Server.handleData inflow control err: %v, code: %d"+sendRSTStreamFrameSuffix, err, http2.ErrCodeFlowControl) + stErr := newStatusf(codes.Canceled, errStreamFlowControl, "transport: inflow control err: %v [triggered by %s]", err, s.sourceService).Err() + t.closeStream(s, stErr, true, http2.ErrCodeFlowControl, false) return } if f.Header().Flags.Has(http2.FlagDataPadded) { @@ -579,7 +605,10 @@ func (t *http2Server) handleData(f *grpcframe.DataFrame) { func (t *http2Server) handleRSTStream(f *http2.RSTStreamFrame) { // If the stream is not deleted from the transport's active streams map, then do a regular close stream. 
if s, ok := t.getStream(f); ok { - t.closeStream(s, false, 0, false) + mappingErr, stCode := getMappingErrAndStatusCode(s.ctx, f.ErrCode) + stErr := newStatusf(stCode, mappingErr, "transport: RSTStream Frame received with error code: %d [triggered by %s]", f.ErrCode, s.sourceService).Err() + klog.CtxInfof(s.ctx, "transport: http2Server.handleRSTStream received RSTStream Frame with error code: %v", f.ErrCode) + t.closeStream(s, stErr, false, 0, false) return } // If the stream is already deleted from the active streams map, then put a cleanupStream item into controlbuf to delete the stream from loopy writer's established streams map. @@ -589,6 +618,7 @@ func (t *http2Server) handleRSTStream(f *http2.RSTStreamFrame) { rstCode: 0, onWrite: func() {}, }) + // since we do not need to send RstStream Frame, do not add log here } func (t *http2Server) handleSettings(f *grpcframe.SettingsFrame) { @@ -711,8 +741,11 @@ func (t *http2Server) checkForHeaderListSize(it interface{}) bool { // WriteHeader sends the header metadata md back to the client. 
func (t *http2Server) WriteHeader(s *Stream, md metadata.MD) error { - if s.updateHeaderSent() || s.getState() == streamDone { - return ErrIllegalHeaderWrite + if s.updateHeaderSent() { + return errStatusIllegalHeaderWrite + } + if s.getState() == streamDone { + return ContextErr(s.ctx.Err()) } s.hdrMu.Lock() if md.Len() > 0 { @@ -756,8 +789,10 @@ func (t *http2Server) writeHeaderLocked(s *Stream) error { if err != nil { return err } - t.closeStream(s, true, http2.ErrCodeInternal, false) - return ErrHeaderListSizeLimitViolation + klog.CtxErrorf(s.ctx, "KITEX: http2Server.writeHeaderLocked checkForHeaderListSize failed, code: %d"+sendRSTStreamFrameSuffix, http2.ErrCodeInternal) + stErr := errStatusHeaderListSizeLimitViolation + t.closeStream(s, stErr, true, http2.ErrCodeInternal, false) + return stErr } return nil } @@ -819,8 +854,9 @@ func (t *http2Server) WriteStatus(s *Stream, st *status.Status) error { if err != nil { return err } - t.closeStream(s, true, http2.ErrCodeInternal, false) - return ErrHeaderListSizeLimitViolation + stErr := errStatusHeaderListSizeLimitViolation + t.closeStream(s, stErr, true, http2.ErrCodeInternal, false) + return stErr } // Send a RST_STREAM after the trailers if the client has not already half-closed. rst := s.getState() == streamActive @@ -838,13 +874,6 @@ func (t *http2Server) Write(s *Stream, hdr, data []byte, opts *Options) error { } else { // Writing headers checks for this condition. if s.getState() == streamDone { - // TODO(mmukhi, dfawley): Should the server write also return io.EOF? 
- s.cancel(errStreamClosing) - select { - case <-t.done: - return ErrConnClosing - default: - } return ContextErr(s.ctx.Err()) } } @@ -856,11 +885,6 @@ func (t *http2Server) Write(s *Stream, hdr, data []byte, opts *Options) error { df.originD = df.d df.resetPingStrikes = &t.resetPingStrikes if err := s.wq.get(int32(len(hdr) + len(data))); err != nil { - select { - case <-t.done: - return ErrConnClosing - default: - } return ContextErr(s.ctx.Err()) } return t.controlBuf.put(df) @@ -921,7 +945,8 @@ func (t *http2Server) keepalive() { case <-ageTimer.C: // Close the connection after grace period. klog.Infof("transport: closing server transport due to maximum connection age.") - t.closeWithErr(errMaxAgeClosing) + stErr := newStatus(codes.Canceled, errKeepAlive, "transport: closing server transport due to maximum connection age"+triggeredByRemoteServiceSuffix).Err() + t.closeWithErr(stErr) case <-t.done: } return @@ -938,7 +963,8 @@ func (t *http2Server) keepalive() { } if outstandingPing && kpTimeoutLeft <= 0 { klog.Infof("transport: closing server transport due to idleness.") - t.closeWithErr(errIdleClosing) + stErr := newStatus(codes.Canceled, errKeepAlive, "transport: closing server transport due to idleness"+triggeredByRemoteServiceSuffix).Err() + t.closeWithErr(stErr) return } if !outstandingPing { @@ -963,7 +989,7 @@ func (t *http2Server) keepalive() { // TODO(zhaoq): Now the destruction is not blocked on any pending streams. This // could cause some resource issue. Revisit this later. 
func (t *http2Server) Close() error { - return t.closeWithErr(nil) + return t.closeWithErr(errStatusConnClosing) } func (t *http2Server) closeWithErr(reason error) error { @@ -976,7 +1002,7 @@ func (t *http2Server) closeWithErr(reason error) error { streams := t.activeStreams t.activeStreams = nil t.mu.Unlock() - t.controlBuf.finish() + t.controlBuf.finish(reason) close(t.done) err := t.conn.Close() @@ -990,11 +1016,6 @@ func (t *http2Server) closeWithErr(reason error) error { // deleteStream deletes the stream s from transport's active streams. func (t *http2Server) deleteStream(s *Stream, eosReceived bool) { - // In case stream sending and receiving are invoked in separate - // goroutines (e.g., bi-directional streaming), cancel needs to be - // called to interrupt the potential blocking on other goroutines. - s.cancel(nil) // more details about the reason? - t.mu.Lock() if _, ok := t.activeStreams[s.id]; ok { delete(t.activeStreams, s.id) @@ -1012,7 +1033,11 @@ func (t *http2Server) finishStream(s *Stream, rst bool, rstCode http2.ErrCode, h // If the stream was already done, return. return } + s.cancel(errStatusHandlerReturn) + if isCustomRstCodeEnabled() { + rstCode = getRstCode(nil) + } hdr.cleanup = &cleanupStream{ streamID: s.id, rst: rst, @@ -1025,10 +1050,17 @@ func (t *http2Server) finishStream(s *Stream, rst bool, rstCode http2.ErrCode, h } // closeStream clears the footprint of a stream when the stream is not needed any more. -func (t *http2Server) closeStream(s *Stream, rst bool, rstCode http2.ErrCode, eosReceived bool) { +func (t *http2Server) closeStream(s *Stream, err error, rst bool, rstCode http2.ErrCode, eosReceived bool) { + // In case stream sending and receiving are invoked in separate + // goroutines (e.g., bi-directional streaming), cancel needs to be + // called to interrupt the potential blocking on other goroutines. 
+ s.cancel(err) s.swapState(streamDone) t.deleteStream(s, eosReceived) + if rst && isCustomRstCodeEnabled() { + rstCode = getRstCode(err) + } t.controlBuf.put(&cleanupStream{ streamID: s.id, rst: rst, diff --git a/pkg/remote/trans/nphttp2/grpc/http_util.go b/pkg/remote/trans/nphttp2/grpc/http_util.go index ecb6bb7ec1..6ba245e2f4 100644 --- a/pkg/remote/trans/nphttp2/grpc/http_util.go +++ b/pkg/remote/trans/nphttp2/grpc/http_util.go @@ -22,8 +22,11 @@ package grpc import ( "bytes" + "context" "encoding/base64" + "errors" "fmt" + "io" "math" "net/http" "strconv" @@ -141,8 +144,8 @@ type parsedHeaderData struct { // Otherwise (i.e. a content-type string starts without "application/grpc", or does not exist), we // are in HTTP fallback mode, and should handle error specific to HTTP. isGRPC bool - grpcErr error - httpErr error + grpcStatus *status.Status + httpErrStatus *status.Status contentTypeErr string } @@ -289,11 +292,11 @@ func decodeMetadataHeader(k, v string) (string, error) { return v, nil } -func (d *decodeState) decodeHeader(frame *grpcframe.MetaHeadersFrame) error { +func (d *decodeState) decodeHeader(frame *grpcframe.MetaHeadersFrame) *status.Status { // frame.Truncated is set to true when framer detects that the current header // list size hits MaxHeaderListSize limit. 
if frame.Truncated { - return status.New(codes.Internal, "peer header list size exceeded limit").Err() + return status.New(codes.Internal, "peer header list size exceeded limit") } for _, hf := range frame.Fields { @@ -301,8 +304,8 @@ func (d *decodeState) decodeHeader(frame *grpcframe.MetaHeadersFrame) error { } if d.data.isGRPC { - if d.data.grpcErr != nil { - return d.data.grpcErr + if d.data.grpcStatus != nil { + return d.data.grpcStatus } if d.serverSide { return nil @@ -321,8 +324,8 @@ func (d *decodeState) decodeHeader(frame *grpcframe.MetaHeadersFrame) error { } // HTTP fallback mode - if d.data.httpErr != nil { - return d.data.httpErr + if d.data.httpErrStatus != nil { + return d.data.httpErrStatus } var ( @@ -337,7 +340,7 @@ func (d *decodeState) decodeHeader(frame *grpcframe.MetaHeadersFrame) error { } } - return status.New(code, d.constructHTTPErrMsg()).Err() + return status.New(code, d.constructHTTPErrMsg()) } // constructErrMsg constructs error message to be returned in HTTP fallback mode. 
@@ -389,7 +392,7 @@ func (d *decodeState) processHeaderField(f hpack.HeaderField) { case "grpc-status": code, err := strconv.Atoi(f.Value) if err != nil { - d.data.grpcErr = status.Errorf(codes.Internal, "transport: malformed grpc-status: %v", err) + d.data.grpcStatus = status.Newf(codes.Internal, "transport: malformed grpc-status: %v", err) return } d.data.rawStatusCode = &code @@ -398,26 +401,26 @@ func (d *decodeState) processHeaderField(f hpack.HeaderField) { case "biz-status": code, err := strconv.Atoi(f.Value) if err != nil { - d.data.grpcErr = status.Errorf(codes.Internal, "transport: malformed biz-status: %v", err) + d.data.grpcStatus = status.Newf(codes.Internal, "transport: malformed biz-status: %v", err) return } d.data.bizStatusCode = &code case "biz-extra": extra, err := utils.JSONStr2Map(f.Value) if err != nil { - d.data.grpcErr = status.Errorf(codes.Internal, "transport: malformed biz-extra: %v", err) + d.data.grpcStatus = status.Newf(codes.Internal, "transport: malformed biz-extra: %v", err) return } d.data.bizStatusExtra = extra case "grpc-status-details-bin": v, err := decodeBinHeader(f.Value) if err != nil { - d.data.grpcErr = status.Errorf(codes.Internal, "transport: malformed grpc-status-details-bin: %v", err) + d.data.grpcStatus = status.Newf(codes.Internal, "transport: malformed grpc-status-details-bin: %v", err) return } s := &spb.Status{} if err := proto.Unmarshal(v, s); err != nil { - d.data.grpcErr = status.Errorf(codes.Internal, "transport: malformed grpc-status-details-bin: %v", err) + d.data.grpcStatus = status.Newf(codes.Internal, "transport: malformed grpc-status-details-bin: %v", err) return } d.data.statusGen = status.FromProto(s) @@ -425,21 +428,21 @@ func (d *decodeState) processHeaderField(f hpack.HeaderField) { d.data.timeoutSet = true var err error if d.data.timeout, err = decodeTimeout(f.Value); err != nil { - d.data.grpcErr = status.Errorf(codes.Internal, "transport: malformed time-out: %v", err) + d.data.grpcStatus = 
status.Newf(codes.Internal, "transport: malformed time-out: %v", err) } case ":path": d.data.method = f.Value case ":status": code, err := strconv.Atoi(f.Value) if err != nil { - d.data.httpErr = status.Errorf(codes.Internal, "transport: malformed http-status: %v", err) + d.data.httpErrStatus = status.Newf(codes.Internal, "transport: malformed http-status: %v", err) return } d.data.httpStatus = &code case "grpc-tags-bin": v, err := decodeBinHeader(f.Value) if err != nil { - d.data.grpcErr = status.Errorf(codes.Internal, "transport: malformed grpc-tags-bin: %v", err) + d.data.grpcStatus = status.Newf(codes.Internal, "transport: malformed grpc-tags-bin: %v", err) return } d.data.statsTags = v @@ -447,7 +450,7 @@ func (d *decodeState) processHeaderField(f hpack.HeaderField) { case "grpc-trace-bin": v, err := decodeBinHeader(f.Value) if err != nil { - d.data.grpcErr = status.Errorf(codes.Internal, "transport: malformed grpc-trace-bin: %v", err) + d.data.grpcStatus = status.Newf(codes.Internal, "transport: malformed grpc-trace-bin: %v", err) return } d.data.statsTrace = v @@ -642,3 +645,135 @@ func decodeGrpcMessageUnchecked(msg string) string { } return buf.String() } + +var ( + customRstCodeEnabled = false + rstCode2MappingErrMap = map[http2.ErrCode]error{} + mappingErr2RstCodeMap = map[error]http2.ErrCode{} +) + +// SetCustomRstCodeEnabled enables/disables the ErrType and custom RstCode mapping. +// it is off by default. +func SetCustomRstCodeEnabled(flag bool) { + customRstCodeEnabled = flag +} + +// RegisterCustomRstCode registers a mapping between an ErrType and a custom RstCode. +// The mapping only works when SetCustomRstCodeEnabled(true) is invoked. +// e.g. 
RegisterCustomRstCode(ErrTypeGracefulShutdown, 1000) +func RegisterCustomRstCode(errType ErrType, rstCode uint32) { + mappingErr, ok := errType2ErrMap[errType] + if !ok { + return + } + h2Code := http2.ErrCode(rstCode) + rstCode2MappingErrMap[h2Code] = mappingErr + mappingErr2RstCodeMap[mappingErr] = h2Code +} + +func isCustomRstCodeEnabled() bool { + return customRstCodeEnabled +} + +func cleanupCustomRstCodeMapping() { + rstCode2MappingErrMap = map[http2.ErrCode]error{} + mappingErr2RstCodeMap = map[error]http2.ErrCode{} +} + +func getMappingErrAndStatusCode(ctx context.Context, rstCode http2.ErrCode) (error, codes.Code) { + mappingErr := getMappingErr(rstCode) + stCode := retrieveStatusCode(ctx, rstCode, mappingErr) + return mappingErr, stCode +} + +func getMappingErr(rstCode http2.ErrCode) error { + err, ok := rstCode2MappingErrMap[rstCode] + if !isCustomRstCodeEnabled() || !ok { + return errRecvRstStream + } + return err +} + +func retrieveStatusCode(ctx context.Context, errCode http2.ErrCode, mappingErr error) codes.Code { + stCode, ok := http2ErrConvTab[errCode] + if !ok { + if !isCustomRstCodeEnabled() { + klog.Warnf("transport: http2Client.handleRSTStream found no mapped gRPC status for the received nhttp2 error %v", errCode) + stCode = codes.Unknown + } else { + stCode = codes.Canceled + if errors.Is(mappingErr, kerrors.ErrGracefulShutdown) { + stCode = codes.Unavailable + } + if errors.Is(mappingErr, kerrors.ErrStreamTimeout) { + stCode = codes.DeadlineExceeded + } + } + } + if stCode == codes.Canceled { + if d, ok := ctx.Deadline(); ok && !d.After(time.Now()) { + // Our deadline was already exceeded, and that was likely the cause + // of this cancelation. Alter the status code accordingly. 
+ stCode = codes.DeadlineExceeded + } + } + return stCode +} + +func getRstCode(err error) (rstCode http2.ErrCode) { + if err == nil || err == io.EOF { + return http2.ErrCodeNo + } + rstCode = http2.ErrCodeCancel + statusErr, ok := err.(*status.Error) + if !ok { + return + } + mappingErr := statusErr.GetMappingErr() + if mappingErr == nil { + return + } + code, ok := mappingErr2RstCodeMap[mappingErr] + if !ok { + return + } + return code +} + +var ( + customStatusCodeEnabled = false + mappingErr2StatusCodeMap = map[error]codes.Code{} +) + +// SetCustomStatusCodeEnabled enables/disables the ErrType and custom Status code mapping. +// it is off by default. +func SetCustomStatusCodeEnabled(flag bool) { + customStatusCodeEnabled = flag +} + +// RegisterCustomStatusCode registers a mapping between an ErrType and a custom Status code. +// The mapping only works when SetCustomStatusCodeEnabled(true) is invoked. +// e.g. RegisterCustomStatusCode(ErrTypeGracefulShutdown, codes.Unavailable) +func RegisterCustomStatusCode(errType ErrType, code codes.Code) { + mappingErr, ok := errType2ErrMap[errType] + if !ok { + return + } + mappingErr2StatusCodeMap[mappingErr] = code +} + +func isCustomStatusCodeEnabled() bool { + return customStatusCodeEnabled +} + +func cleanupCustomStatusCodeMapping() { + mappingErr2StatusCodeMap = map[error]codes.Code{} +} + +func getStatusCode(def codes.Code, mappingErr error) codes.Code { + code, ok := mappingErr2StatusCodeMap[mappingErr] + if isCustomStatusCodeEnabled() && ok { + return code + } + return def +} diff --git a/pkg/remote/trans/nphttp2/grpc/http_util_test.go b/pkg/remote/trans/nphttp2/grpc/http_util_test.go index 9cd1592413..576c14ec43 100644 --- a/pkg/remote/trans/nphttp2/grpc/http_util_test.go +++ b/pkg/remote/trans/nphttp2/grpc/http_util_test.go @@ -19,10 +19,15 @@ package grpc import ( "bytes" "context" + "io" "testing" "time" + "golang.org/x/net/http2" + "github.com/cloudwego/kitex/internal/test" + 
"github.com/cloudwego/kitex/pkg/kerrors" + "github.com/cloudwego/kitex/pkg/remote/trans/nphttp2/codes" ) func TestEncoding(t *testing.T) { @@ -212,3 +217,277 @@ func TestConnectionError(t *testing.T) { ori = connectionError.Origin() test.Assert(t, ori == context.Canceled) } + +func Test_getMappingErrAndStatusCode(t *testing.T) { + testcases := []struct { + desc string + input []http2.ErrCode + want []struct { + err error + stCode codes.Code + } + setup func(t *testing.T) + clean func(t *testing.T) + }{ + { + desc: "normal RstCode", + input: []http2.ErrCode{ + http2.ErrCodeNo, + http2.ErrCodeCancel, + }, + want: []struct { + err error + stCode codes.Code + }{ + { + err: errRecvRstStream, + stCode: codes.Internal, + }, + { + err: errRecvRstStream, + stCode: codes.Canceled, + }, + }, + }, + { + desc: "custom RstCode", + setup: customSetup, + input: []http2.ErrCode{ + http2.ErrCode(1000), + http2.ErrCode(1001), + http2.ErrCode(1002), + http2.ErrCode(1003), + http2.ErrCode(1004), + http2.ErrCode(1005), + http2.ErrCode(1006), + http2.ErrCode(1007), + http2.ErrCode(1008), + http2.ErrCode(1009), + http2.ErrCode(1010), + }, + want: []struct { + err error + stCode codes.Code + }{ + { + err: kerrors.ErrGracefulShutdown, + stCode: codes.Unavailable, + }, + { + err: kerrors.ErrBizCanceled, + stCode: codes.Canceled, + }, + { + err: kerrors.ErrHandlerReturn, + stCode: codes.Canceled, + }, + { + err: kerrors.ErrStreamingCanceled, + stCode: codes.Canceled, + }, + { + err: kerrors.ErrStreamTimeout, + stCode: codes.DeadlineExceeded, + }, + { + err: errMiddleHeader, + stCode: codes.Canceled, + }, + { + err: errDecodeHeader, + stCode: codes.Canceled, + }, + { + err: errHTTP2Stream, + stCode: codes.Canceled, + }, + { + err: errClosedWithoutTrailer, + stCode: codes.Canceled, + }, + { + err: errStreamFlowControl, + stCode: codes.Canceled, + }, + { + err: kerrors.ErrMetaSizeExceeded, + stCode: codes.Canceled, + }, + }, + clean: customClean, + }, + } + + for _, tc := range testcases { + 
t.Run(tc.desc, func(t *testing.T) { + if tc.setup != nil { + tc.setup(t) + } + if tc.clean != nil { + defer tc.clean(t) + } + for i, rstCode := range tc.input { + err, stCode := getMappingErrAndStatusCode(context.Background(), rstCode) + test.Assert(t, err == tc.want[i].err, err) + res := stCode == tc.want[i].stCode + if !res { + t.Logf("stCode: %d, err: %v", stCode, err) + } + test.Assert(t, stCode == tc.want[i].stCode, stCode) + } + }) + } +} + +func Test_getRstCode(t *testing.T) { + testcases := []struct { + desc string + input []error + want []http2.ErrCode + setup func(*testing.T) + clean func(*testing.T) + }{ + { + desc: "normal RstCode", + input: []error{ + nil, io.EOF, + newStatus(codes.Internal, kerrors.ErrGracefulShutdown, "test").Err(), + }, + want: []http2.ErrCode{ + http2.ErrCodeNo, http2.ErrCodeNo, + http2.ErrCodeCancel, + }, + }, + { + desc: "custom RstCode", + input: []error{ + nil, io.EOF, + newStatus(codes.Internal, kerrors.ErrGracefulShutdown, "test").Err(), + }, + want: []http2.ErrCode{ + http2.ErrCodeNo, http2.ErrCodeNo, + http2.ErrCode(1000), + }, + setup: customSetup, + clean: customClean, + }, + } + + for _, tc := range testcases { + t.Run(tc.desc, func(t *testing.T) { + if tc.setup != nil { + tc.setup(t) + } + if tc.clean != nil { + defer tc.clean(t) + } + for i, err := range tc.input { + rstCode := getRstCode(err) + test.Assert(t, rstCode == tc.want[i], rstCode) + } + }) + } +} + +func Test_getStatusCode(t *testing.T) { + testcases := []struct { + desc string + def codes.Code + input []error + want []codes.Code + setup func(t *testing.T) + clean func(t *testing.T) + }{ + { + desc: "without custom status code", + def: codes.Internal, + input: []error{ + errHTTP2Stream, + }, + want: []codes.Code{ + codes.Internal, + }, + }, + { + desc: "custom status code", + def: codes.Internal, + input: []error{ + kerrors.ErrGracefulShutdown, kerrors.ErrBizCanceled, kerrors.ErrHandlerReturn, kerrors.ErrStreamingCanceled, + kerrors.ErrStreamTimeout, 
kerrors.ErrMetaSizeExceeded, + errHTTP2Stream, errClosedWithoutTrailer, errMiddleHeader, errDecodeHeader, errRecvRstStream, errStreamDrain, errStreamFlowControl, errIllegalHeaderWrite, errStreamIsDone, errMaxStreamExceeded, + errHTTP2Connection, errEstablishConnection, errHandleGoAway, errKeepAlive, errOperateHeaders, errNoActiveStream, errControlBufFinished, errNotReachable, errConnectionIsClosing, + }, + want: []codes.Code{ + codes.Code(10000), codes.Code(10001), codes.Code(10002), codes.Code(10003), + codes.Code(10004), codes.Code(10005), + codes.Code(10006), codes.Code(10007), codes.Code(10008), codes.Code(10009), codes.Code(10010), codes.Code(10011), codes.Code(10012), codes.Code(10013), codes.Code(10014), codes.Code(10015), + codes.Code(10016), codes.Code(10017), codes.Code(10018), codes.Code(10019), codes.Code(10020), codes.Code(10021), codes.Code(10022), codes.Code(10023), codes.Code(10024), + }, + setup: customSetup, + clean: customClean, + }, + } + + for _, tc := range testcases { + t.Run(tc.desc, func(t *testing.T) { + if tc.setup != nil { + tc.setup(t) + } + if tc.clean != nil { + defer tc.clean(t) + } + for i, err := range tc.input { + code := getStatusCode(tc.def, err) + test.Assert(t, code == tc.want[i], err) + } + }) + } +} + +func customSetup(t *testing.T) { + SetCustomRstCodeEnabled(true) + RegisterCustomRstCode(ErrTypeGracefulShutdown, 1000) + RegisterCustomRstCode(ErrTypeBizCanceled, 1001) + RegisterCustomRstCode(ErrTypeHandlerReturn, 1002) + RegisterCustomRstCode(ErrTypeStreamingCanceled, 1003) + RegisterCustomRstCode(ErrTypeStreamTimeout, 1004) + RegisterCustomRstCode(ErrTypeMiddleHeader, 1005) + RegisterCustomRstCode(ErrTypeDecodeHeader, 1006) + RegisterCustomRstCode(ErrTypeHTTP2Stream, 1007) + RegisterCustomRstCode(ErrTypeClosedWithoutTrailer, 1008) + RegisterCustomRstCode(ErrTypeStreamFlowControl, 1009) + RegisterCustomRstCode(ErrTypeMetaSizeExceeded, 1010) + SetCustomStatusCodeEnabled(true) + 
RegisterCustomStatusCode(ErrTypeGracefulShutdown, codes.Code(10000)) + RegisterCustomStatusCode(ErrTypeBizCanceled, codes.Code(10001)) + RegisterCustomStatusCode(ErrTypeHandlerReturn, codes.Code(10002)) + RegisterCustomStatusCode(ErrTypeStreamingCanceled, codes.Code(10003)) + RegisterCustomStatusCode(ErrTypeStreamTimeout, codes.Code(10004)) + RegisterCustomStatusCode(ErrTypeMetaSizeExceeded, codes.Code(10005)) + RegisterCustomStatusCode(ErrTypeHTTP2Stream, codes.Code(10006)) + RegisterCustomStatusCode(ErrTypeClosedWithoutTrailer, codes.Code(10007)) + RegisterCustomStatusCode(ErrTypeMiddleHeader, codes.Code(10008)) + RegisterCustomStatusCode(ErrTypeDecodeHeader, codes.Code(10009)) + RegisterCustomStatusCode(ErrTypeRecvRstStream, codes.Code(10010)) + RegisterCustomStatusCode(ErrTypeStreamDrain, codes.Code(10011)) + RegisterCustomStatusCode(ErrTypeStreamFlowControl, codes.Code(10012)) + RegisterCustomStatusCode(ErrTypeIllegalHeaderWrite, codes.Code(10013)) + RegisterCustomStatusCode(ErrTypeStreamIsDone, codes.Code(10014)) + RegisterCustomStatusCode(ErrTypeMaxStreamExceeded, codes.Code(10015)) + RegisterCustomStatusCode(ErrTypeHTTP2Connection, codes.Code(10016)) + RegisterCustomStatusCode(ErrTypeEstablishConnection, codes.Code(10017)) + RegisterCustomStatusCode(ErrTypeHandleGoAway, codes.Code(10018)) + RegisterCustomStatusCode(ErrTypeKeepAlive, codes.Code(10019)) + RegisterCustomStatusCode(ErrTypeOperateHeaders, codes.Code(10020)) + RegisterCustomStatusCode(ErrTypeNoActiveStream, codes.Code(10021)) + RegisterCustomStatusCode(ErrTypeControlBufFinished, codes.Code(10022)) + RegisterCustomStatusCode(ErrTypeNotReachable, codes.Code(10023)) + RegisterCustomStatusCode(ErrTypeConnectionIsClosing, codes.Code(10024)) +} + +func customClean(t *testing.T) { + SetCustomRstCodeEnabled(false) + cleanupCustomRstCodeMapping() + SetCustomStatusCodeEnabled(false) + cleanupCustomStatusCodeMapping() +} diff --git a/pkg/remote/trans/nphttp2/grpc/transport.go 
b/pkg/remote/trans/nphttp2/grpc/transport.go index e2b5f0ef1f..b08fa2f70d 100644 --- a/pkg/remote/trans/nphttp2/grpc/transport.go +++ b/pkg/remote/trans/nphttp2/grpc/transport.go @@ -34,7 +34,9 @@ import ( "sync" "sync/atomic" + istatus "github.com/cloudwego/kitex/internal/remote/trans/grpc/status" "github.com/cloudwego/kitex/pkg/kerrors" + "github.com/cloudwego/kitex/pkg/klog" "github.com/cloudwego/kitex/pkg/remote/trans/nphttp2/codes" "github.com/cloudwego/kitex/pkg/remote/trans/nphttp2/metadata" "github.com/cloudwego/kitex/pkg/remote/trans/nphttp2/status" @@ -198,7 +200,7 @@ func (r *recvBufferReader) readClient(p []byte) (n int, err error) { // TODO: delaying ctx error seems like a unnecessary side effect. What // we really want is to mark the stream as done, and return ctx error // faster. - r.closeStream(ContextErr(r.ctx.Err())) + r.closeStream(clientContextErr(r.ctx.Err())) m := <-r.recv.get() return r.readAdditional(m, p) case m := <-r.recv.get(): @@ -236,7 +238,7 @@ type Stream struct { st ServerTransport // nil for client side Stream ct *http2Client // nil for server side Stream ctx context.Context // the associated context of the stream - cancel cancelWithReason // always nil for client side Stream + cancelFunc cancelWithReason // always nil for client side Stream done chan struct{} // closed at the end of stream to unblock writers. On the client side. ctxDone <-chan struct{} // same as done chan but for server side. Cache of ctx.Done() (for performance) method string // the associated RPC method of the stream @@ -276,7 +278,7 @@ type Stream struct { // On client-side it is the status error received from the server. // On server-side it is unused. - status *status.Status + status *istatus.Status bizStatusErr kerrors.BizStatusErrorIface bytesReceived uint32 // indicates whether any bytes have been received on this stream @@ -285,6 +287,11 @@ type Stream struct { // contentSubtype is the content-subtype for requests. 
// this must be lowercase or the behavior is undefined. contentSubtype string + + // closeStreamErr is used to store the error when stream is closed + closeStreamErr atomic.Value + // sourceService is the source service name of this stream + sourceService string } // isHeaderSent is only valid on the server-side. @@ -479,6 +486,28 @@ func (s *Stream) Read(p []byte) (n int, err error) { return io.ReadFull(s.trReader, p) } +func (s *Stream) getCloseStreamErr() error { + rawErr := s.closeStreamErr.Load() + if rawErr != nil { + return rawErr.(error) + } + return errStatusStreamDone +} + +// err should be of the same type +func (s *Stream) cancel(err error) { + if err == nil { + s.cancelFunc(nil) + return + } + if _, ok := err.(*istatus.Error); !ok { + klog.CtxWarnf(s.ctx, "stream canceled with non status.Error: %v", err) + err = newStatus(codes.Canceled, kerrors.ErrStreamingCanceled, err.Error()).Err() + } + // all errors propagated by cancelFunc must be of type *status.Error or nil + s.cancelFunc(err) +} + // StreamWrite only used for unit test func StreamWrite(s *Stream, buffer *bytes.Buffer) { s.write(recvMsg{buffer: buffer}) @@ -496,10 +525,8 @@ func CreateStream(ctx context.Context, id uint32, requestRead func(i int), metho }, windowHandler: func(i int) {}, } - stream := &Stream{ id: id, - ctx: ctx, method: method, buf: recvBuffer, trReader: trReader, @@ -508,6 +535,9 @@ func CreateStream(ctx context.Context, id uint32, requestRead func(i int), metho hdrMu: sync.Mutex{}, } + ctx, cancel := context.WithCancel(ctx) + stream.ctx, stream.cancelFunc = newContextWithCancelReason(ctx, cancel) + return stream } @@ -762,20 +792,19 @@ func (e ConnectionError) Origin() error { var ( // ErrConnClosing indicates that the transport is closing. 
- ErrConnClosing = connectionErrorf(true, nil, "transport is closing") + ErrConnClosing = connectionErrorf(true, nil, "transport is closing") + errStatusConnClosing = newStatus(codes.Unavailable, errConnectionIsClosing, "transport is closing").Err() + errStatusControlBufFinished = newStatus(codes.Unavailable, errControlBufFinished, "controlbuf finished").Err() // errStreamDone is returned from write at the client side to indicate application // layer of an error. - errStreamDone = errors.New("the stream is done") + errStreamDone = errors.New("the stream is done") + errStatusStreamDone = newStatus(codes.Internal, errStreamIsDone, errStreamDone.Error()).Err() - // errStreamDrain indicates that the stream is rejected because the + // errStatusStreamDrain indicates that the stream is rejected because the // connection is draining. This could be caused by goaway or balancer // removing the address. - errStreamDrain = status.New(codes.Unavailable, "the connection is draining").Err() - - // StatusGoAway indicates that the server sent a GOAWAY that included this - // stream's ID in unprocessed RPCs. - statusGoAway = status.New(codes.Unavailable, "the stream is rejected because server is draining the connection") + errStatusStreamDrain = newStatus(codes.Unavailable, errStreamDrain, "the connection is draining").Err() ) // GoAwayReason contains the reason for the GoAway frame received. 
@@ -796,15 +825,39 @@ const ( func ContextErr(err error) error { switch err { case context.DeadlineExceeded: - return status.New(codes.DeadlineExceeded, err.Error()).Err() + return newStatus(codes.DeadlineExceeded, kerrors.ErrStreamTimeout, err.Error()).Err() case context.Canceled: - return status.New(codes.Canceled, err.Error()).Err() + return newStatus(codes.Canceled, kerrors.ErrBizCanceled, err.Error()).Err() } - statusErr, ok := err.(*status.Error) + statusErr, ok := err.(*istatus.Error) if ok { // only returned by contextWithCancelReason return statusErr } - return status.Errorf(codes.Internal, "Unexpected error from context packet: %v", err) + return newStatusf(codes.Internal, kerrors.ErrStreamingCanceled, "Unexpected error from context packet: %v", err).Err() +} + +// clientContextErr converts the error from context package into a status error +// when the error is passed through streams by cancel. +func clientContextErr(err error) error { + stErr := ContextErr(err).(*istatus.Error) + switch { + // errors defined here could pass through streams by cancel + // e.g. A -> B -> C + // stream between A and B closed by Graceful Shutdown, BC.Recv() could get kerrors.ErrGracefulShutdown + case errors.Is(stErr, kerrors.ErrStreamTimeout): + case errors.Is(stErr, kerrors.ErrBizCanceled): + case errors.Is(stErr, kerrors.ErrGracefulShutdown): + case errors.Is(stErr, kerrors.ErrHandlerReturn): + default: + // Other errs are treated as kerrors.ErrStreamingCanceled + // when passed through streams by cancel. + // Then users could use errors.Is(err, kerrors.ErrStreamingCanceled) + // to check if an exception in the upstream stream caused cancel to be delivered + st := stErr.GRPCStatus() + istatus.InjectMappingErr(st, kerrors.ErrStreamingCanceled) + return st.Err() + } + return stErr } // IsStreamDoneErr returns true if the error indicates that the stream is done. @@ -840,3 +893,19 @@ func tlsAppendH2ToALPNProtocols(ps []string) []string { ret = append(ret, ps...) 
return append(ret, alpnProtoStrH2) } + +var ( + sendRSTStreamFrameSuffix = " [send RSTStream Frame]" + triggeredByRemoteServiceSuffix = " [triggered by remote service]" + triggeredByHandlerSideSuffix = " [triggered by handler side]" +) + +func newStatus(def codes.Code, mappingErr error, msg string) *istatus.Status { + code := getStatusCode(def, mappingErr) + return istatus.NewWithMappingErr(code, mappingErr, msg) +} + +func newStatusf(def codes.Code, mappingErr error, format string, a ...interface{}) *istatus.Status { + code := getStatusCode(def, mappingErr) + return istatus.NewfWithMappingErr(code, mappingErr, format, a...) +} diff --git a/pkg/remote/trans/nphttp2/grpc/transport_test.go b/pkg/remote/trans/nphttp2/grpc/transport_test.go index ec98259a52..e101883ea0 100644 --- a/pkg/remote/trans/nphttp2/grpc/transport_test.go +++ b/pkg/remote/trans/nphttp2/grpc/transport_test.go @@ -30,6 +30,7 @@ import ( "io" "math" "net" + "reflect" "runtime" "strconv" "strings" @@ -41,7 +42,9 @@ import ( "golang.org/x/net/http2" "golang.org/x/net/http2/hpack" + istatus "github.com/cloudwego/kitex/internal/remote/trans/grpc/status" "github.com/cloudwego/kitex/internal/test" + "github.com/cloudwego/kitex/pkg/kerrors" "github.com/cloudwego/kitex/pkg/remote/trans/nphttp2/codes" "github.com/cloudwego/kitex/pkg/remote/trans/nphttp2/grpc/grpcframe" "github.com/cloudwego/kitex/pkg/remote/trans/nphttp2/grpc/testutils" @@ -121,13 +124,13 @@ func (h *testStreamHandler) handleStream(t *testing.T, s *Stream) { } if !bytes.Equal(p, req) { t.Errorf("handleStream got %v, want %v", p, req) - h.t.WriteStatus(s, status.New(codes.Internal, "panic")) + h.t.WriteStatus(s, istatus.New(codes.Internal, "panic")) return } // send a response back to the client. h.t.Write(s, nil, resp, &Options{}) // send the trailer to end the stream. 
- h.t.WriteStatus(s, status.New(codes.OK, "")) + h.t.WriteStatus(s, istatus.New(codes.OK, "")) } func (h *testStreamHandler) handleStreamPingPong(t *testing.T, s *Stream) { @@ -135,18 +138,18 @@ func (h *testStreamHandler) handleStreamPingPong(t *testing.T, s *Stream) { for { if _, err := s.Read(header); err != nil { if err == io.EOF { - h.t.WriteStatus(s, status.New(codes.OK, "")) + h.t.WriteStatus(s, istatus.New(codes.OK, "")) return } t.Errorf("Error on server while reading data header: %v", err) - h.t.WriteStatus(s, status.New(codes.Internal, "panic")) + h.t.WriteStatus(s, istatus.New(codes.Internal, "panic")) return } sz := binary.BigEndian.Uint32(header[1:]) msg := make([]byte, int(sz)) if _, err := s.Read(msg); err != nil { t.Errorf("Error on server while reading message: %v", err) - h.t.WriteStatus(s, status.New(codes.Internal, "panic")) + h.t.WriteStatus(s, istatus.New(codes.Internal, "panic")) return } buf := make([]byte, sz+5) @@ -161,7 +164,7 @@ func (h *testStreamHandler) handleStreamMisbehave(t *testing.T, s *Stream) { conn, ok := s.st.(*http2Server) if !ok { t.Errorf("Failed to convert %v to *http2Server", s.st) - h.t.WriteStatus(s, status.New(codes.Internal, "")) + h.t.WriteStatus(s, istatus.New(codes.Internal, "")) return } var sent int @@ -286,7 +289,7 @@ func (h *testStreamHandler) handleStreamDelayRead(t *testing.T, s *Stream) { return } // send the trailer to end the stream. 
- if err := h.t.WriteStatus(s, status.New(codes.OK, "")); err != nil { + if err := h.t.WriteStatus(s, istatus.New(codes.OK, "")); err != nil { t.Errorf("server WriteStatus got %v, want ", err) return } @@ -526,8 +529,8 @@ func TestInflightStreamClosing(t *testing.T) { } donec := make(chan struct{}) - // serr := &status.Error{e: s.Proto()} - serr := status.Err(codes.Internal, "client connection is closing") + // serr := &istatus.Error{e: s.Proto()} + serr := istatus.Err(codes.Internal, "client connection is closing") go func() { defer close(donec) if _, err := stream.Read(make([]byte, defaultWindowSize)); err != serr { @@ -822,8 +825,8 @@ func TestLargeMessageWithDelayRead(t *testing.T) { // return // } // ct.write(str, nil, nil, &Options{Last: true}) -// if _, err := str.Read(make([]byte, 8)); err != errStreamDrain && err != ErrConnClosing { -// t.Errorf("_.Read(_) = _, %v, want _, %v or %v", err, errStreamDrain, ErrConnClosing) +// if _, err := str.Read(make([]byte, 8)); err != errStatusStreamDrain && err != ErrConnClosing { +// t.Errorf("_.Read(_) = _, %v, want _, %v or %v", err, errStatusStreamDrain, ErrConnClosing) // } // }() // } @@ -858,11 +861,9 @@ func TestLargeMessageSuspension(t *testing.T) { msg := make([]byte, initialWindowSize*8) ct.Write(s, nil, msg, &Options{}) err = ct.Write(s, nil, msg, &Options{Last: true}) - if err != errStreamDone { - t.Fatalf("write got %v, want io.EOF", err) - } - expectedErr := status.Err(codes.DeadlineExceeded, context.DeadlineExceeded.Error()) - if _, err := s.Read(make([]byte, 8)); err.Error() != expectedErr.Error() { + test.Assert(t, errors.Is(err, ContextErr(ctx.Err())), err) + expectedErr := newStatus(codes.DeadlineExceeded, kerrors.ErrStreamTimeout, context.DeadlineExceeded.Error()) + if _, err := s.Read(make([]byte, 8)); reflect.DeepEqual(err, expectedErr) { t.Fatalf("Read got %v of type %T, want %v", err, err, expectedErr) } ct.Close(errSelfCloseForTest) @@ -892,7 +893,7 @@ func TestMaxStreams(t *testing.T) { pctx, 
cancel := context.WithCancel(context.Background()) defer cancel() timer := time.NewTimer(time.Second * 10) - expectedErr := status.Err(codes.DeadlineExceeded, context.DeadlineExceeded.Error()) + expectedErr := newStatus(codes.DeadlineExceeded, kerrors.ErrStreamTimeout, context.DeadlineExceeded.Error()) for { select { case <-timer.C: @@ -906,7 +907,7 @@ func TestMaxStreams(t *testing.T) { if str, err := ct.NewStream(ctx, callHdr); err == nil { slist = append(slist, str) continue - } else if err.Error() != expectedErr.Error() { + } else if reflect.DeepEqual(err, expectedErr) { t.Fatalf("ct.NewStream(_,_) = _, %v, want _, %v", err, expectedErr) } timer.Stop() @@ -994,8 +995,9 @@ func TestServerContextCanceledOnClosedConnection(t *testing.T) { ct.Close(errSelfCloseForTest) select { case <-ss.Context().Done(): - if ss.Context().Err() != errConnectionEOF { - t.Fatalf("ss.Context().Err() got %v, want %v", ss.Context().Err(), errConnectionEOF) + cErr := ss.Context().Err() + if !errors.Is(cErr, errHTTP2Connection) { + t.Fatalf("ss.Context().Err() got %v, want %v", ss.Context().Err(), errStatusConnectionEOF) } case <-time.After(3 * time.Second): t.Fatalf("%s", "Failed to cancel the context of the sever side stream.") @@ -1430,7 +1432,7 @@ func TestServerWithMisbehavedClient(t *testing.T) { // } //} -var encodingTestStatus = status.New(codes.Internal, "\n") +var encodingTestStatus = istatus.New(codes.Internal, "\n") func TestEncodingRequiredStatus(t *testing.T) { server, ct := setUp(t, 0, math.MaxUint32, encodingRequiredStatus) @@ -1531,8 +1533,8 @@ func TestContextErr(t *testing.T) { // outputs errOut error }{ - {context.DeadlineExceeded, status.Err(codes.DeadlineExceeded, context.DeadlineExceeded.Error())}, - {context.Canceled, status.Err(codes.Canceled, context.Canceled.Error())}, + {context.DeadlineExceeded, newStatus(codes.DeadlineExceeded, kerrors.ErrStreamTimeout, context.DeadlineExceeded.Error()).Err()}, + {context.Canceled, newStatus(codes.Canceled, 
kerrors.ErrBizCanceled, context.Canceled.Error()).Err()}, } { err := ContextErr(test.errIn) if err.Error() != test.errOut.Error() { @@ -2050,3 +2052,49 @@ func TestTlsAppendH2ToALPNProtocols(t *testing.T) { appended = tlsAppendH2ToALPNProtocols(appended) test.Assert(t, len(appended) == 1) } + +func Test_clientContextErr(t *testing.T) { + testcases := []struct { + desc string + input error + target error + }{ + { + desc: "user invokes cancel()", + input: context.Canceled, + target: kerrors.ErrBizCanceled, + }, + { + desc: "ctx timeout", + input: context.DeadlineExceeded, + target: kerrors.ErrStreamTimeout, + }, + { + desc: "kerrors.ErrGracefulShutdown pass through", + input: newStatus(codes.Internal, kerrors.ErrGracefulShutdown, "pass through").Err(), + target: kerrors.ErrGracefulShutdown, + }, + { + desc: "kerrors.ErrBizCanceled pass through", + input: newStatus(codes.Internal, kerrors.ErrBizCanceled, "pass through").Err(), + target: kerrors.ErrBizCanceled, + }, + { + desc: "kerrors.ErrHandlerReturn pass through", + input: newStatus(codes.Internal, kerrors.ErrHandlerReturn, "pass through").Err(), + target: kerrors.ErrHandlerReturn, + }, + { + desc: "non-pass through", + input: newStatus(codes.Internal, errStreamFlowControl, "non-pass through").Err(), + target: kerrors.ErrStreamingCanceled, + }, + } + + for _, tc := range testcases { + t.Run(tc.desc, func(t *testing.T) { + err := clientContextErr(tc.input) + test.Assert(t, errors.Is(err, tc.target), err) + }) + } +} diff --git a/pkg/remote/trans/nphttp2/status/status.go b/pkg/remote/trans/nphttp2/status/status.go index 130d4425cd..f6fe7cbe27 100644 --- a/pkg/remote/trans/nphttp2/status/status.go +++ b/pkg/remote/trans/nphttp2/status/status.go @@ -30,30 +30,23 @@ package status import ( - "context" - "errors" "fmt" spb "google.golang.org/genproto/googleapis/rpc/status" - "google.golang.org/protobuf/proto" - "google.golang.org/protobuf/types/known/anypb" + istatus 
"github.com/cloudwego/kitex/internal/remote/trans/grpc/status" "github.com/cloudwego/kitex/pkg/remote/trans/nphttp2/codes" ) -type Iface interface { - GRPCStatus() *Status -} +type Iface = istatus.Iface // Status represents an RPC status code, message, and details. It is immutable // and should be created with New, Newf, or FromProto. -type Status struct { - s *spb.Status -} +type Status = istatus.Status // New returns a Status representing c and msg. func New(c codes.Code, msg string) *Status { - return &Status{s: &spb.Status{Code: int32(c), Message: msg}} + return istatus.New(c, msg) } // Newf returns New(c, fmt.Sprintf(format, a...)). @@ -68,7 +61,7 @@ func ErrorProto(s *spb.Status) error { // FromProto returns a Status representing s. func FromProto(s *spb.Status) *Status { - return &Status{s: proto.Clone(s).(*spb.Status)} + return istatus.FromProto(s) } // Err returns an error representing c and msg. If c is OK, returns nil. @@ -81,120 +74,15 @@ func Errorf(c codes.Code, format string, a ...interface{}) error { return Err(c, fmt.Sprintf(format, a...)) } -// Code returns the status code contained in s. -func (s *Status) Code() codes.Code { - if s == nil || s.s == nil { - return codes.OK - } - return codes.Code(s.s.Code) -} - -// Message returns the message contained in s. -func (s *Status) Message() string { - if s == nil || s.s == nil { - return "" - } - return s.s.Message -} - -// AppendMessage append extra msg for Status -func (s *Status) AppendMessage(extraMsg string) *Status { - if s == nil || s.s == nil || extraMsg == "" { - return s - } - s.s.Message = fmt.Sprintf("%s %s", s.s.Message, extraMsg) - return s -} - -// Proto returns s's status as an spb.Status proto message. -func (s *Status) Proto() *spb.Status { - if s == nil { - return nil - } - return proto.Clone(s.s).(*spb.Status) -} - -// Err returns an immutable error representing s; returns nil if s.Code() is OK. 
-func (s *Status) Err() error { - if s.Code() == codes.OK { - return nil - } - return &Error{e: s.Proto()} -} - -// WithDetails returns a new status with the provided details messages appended to the status. -// If any errors are encountered, it returns nil and the first error encountered. -func (s *Status) WithDetails(details ...proto.Message) (*Status, error) { - if s.Code() == codes.OK { - return nil, errors.New("no error details for status with code OK") - } - // s.Code() != OK implies that s.Proto() != nil. - p := s.Proto() - for _, detail := range details { - any, err := anypb.New(detail) - if err != nil { - return nil, err - } - p.Details = append(p.Details, any) - } - return &Status{s: p}, nil -} - -// Details returns a slice of details messages attached to the status. -// If a detail cannot be decoded, the error is returned in place of the detail. -func (s *Status) Details() []interface{} { - if s == nil || s.s == nil { - return nil - } - details := make([]interface{}, 0, len(s.s.Details)) - for _, any := range s.s.Details { - detail, err := any.UnmarshalNew() - if err != nil { - details = append(details, err) - continue - } - details = append(details, detail) - } - return details -} - // Error wraps a pointer of a status proto. It implements error and Status, // and a nil *Error should never be returned by this package. -type Error struct { - e *spb.Status -} - -func (e *Error) Error() string { - return fmt.Sprintf("rpc error: code = %d desc = %s", codes.Code(e.e.GetCode()), e.e.GetMessage()) -} - -// GRPCStatus returns the Status represented by se. -func (e *Error) GRPCStatus() *Status { - return FromProto(e.e) -} - -// Is implements future error.Is functionality. -// A Error is equivalent if the code and message are identical. 
-func (e *Error) Is(target error) bool { - tse, ok := target.(*Error) - if !ok { - return false - } - return proto.Equal(e.e, tse.e) -} +type Error = istatus.Error // FromError returns a Status representing err if it was produced from this // package or has a method `GRPCStatus() *Status`. Otherwise, ok is false and a // Status is returned with codes.Unknown and the original error message. func FromError(err error) (s *Status, ok bool) { - if err == nil { - return nil, true - } - var se Iface - if errors.As(err, &se) { - return se.GRPCStatus(), true - } - return New(codes.Unknown, err.Error()), false + return istatus.FromError(err) } // Convert is a convenience function which removes the need to handle the @@ -207,29 +95,12 @@ func Convert(err error) *Status { // Code returns the Code of the error if it is a Status error, codes.OK if err // is nil, or codes.Unknown otherwise. func Code(err error) codes.Code { - // Don't use FromError to avoid allocation of OK status. - if err == nil { - return codes.OK - } - var se Iface - if errors.As(err, &se) { - return se.GRPCStatus().Code() - } - return codes.Unknown + return istatus.Code(err) } // FromContextError converts a context error into a Status. It returns a // Status with codes.OK if err is nil, or a Status with codes.Unknown if err is // non-nil and not a context error. func FromContextError(err error) *Status { - switch err { - case nil: - return nil - case context.DeadlineExceeded: - return New(codes.DeadlineExceeded, err.Error()) - case context.Canceled: - return New(codes.Canceled, err.Error()) - default: - return New(codes.Unknown, err.Error()) - } + return istatus.FromContextError(err) } diff --git a/pkg/streamx/provider/grpc/gerrors/gerrors.go b/pkg/streamx/provider/grpc/gerrors/gerrors.go new file mode 100644 index 0000000000..07928e8a7b --- /dev/null +++ b/pkg/streamx/provider/grpc/gerrors/gerrors.go @@ -0,0 +1,43 @@ +/* + * + * Copyright 2024 gRPC authors. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * This file may have been modified by CloudWeGo authors. All CloudWeGo + * Modifications are Copyright 2021 CloudWeGo Authors. + */ + +package gerrors + +import ( + "github.com/cloudwego/kitex/pkg/remote/trans/nphttp2/grpc" +) + +// SetCustomRstCodeEnabled enables/disables the ErrType and custom RstCode mapping. +// it is off by default. +var SetCustomRstCodeEnabled = grpc.SetCustomRstCodeEnabled + +// RegisterCustomRstCode registers a mapping between an ErrType and a custom RstCode. +// The mapping only works when SetCustomRstCodeEnabled(true) is invoked. +// e.g. RegisterCustomRstCode(ErrTypeGracefulShutdown, 1000) +var RegisterCustomRstCode = grpc.RegisterCustomRstCode + +// SetCustomStatusCodeEnabled enables/disables the ErrType and custom Status code mapping. +// it is off by default. +var SetCustomStatusCodeEnabled = grpc.SetCustomStatusCodeEnabled + +// RegisterCustomStatusCode registers a mapping between an ErrType and a custom Status code. +// The mapping only works when SetCustomStatusCodeEnabled(true) is invoked. +// e.g. RegisterCustomStatusCode(ErrTypeGracefulShutdown, codes.Unavailable) +var RegisterCustomStatusCode = grpc.RegisterCustomStatusCode