forked from trpc-group/trpc-go
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathrouter.go
359 lines (316 loc) · 10.2 KB
/
router.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
//
//
// Tencent is pleased to support the open source community by making tRPC available.
//
// Copyright (C) 2023 THL A29 Limited, a Tencent company.
// All rights reserved.
//
// If you have downloaded a copy of the tRPC source code from Tencent,
// please note that tRPC source code is licensed under the Apache 2.0 License,
// A copy of the Apache 2.0 License is included in this file.
//
//
package restful
import (
"context"
"fmt"
"io"
"net/http"
"strconv"
"strings"
"sync"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/emptypb"
"trpc.group/trpc-go/trpc-go/codec"
"trpc.group/trpc-go/trpc-go/errs"
"trpc.group/trpc-go/trpc-go/filter"
"trpc.group/trpc-go/trpc-go/internal/dat"
)
// Router is the RESTful router. It holds its options and the transcoders
// registered through AddImplBinding, keyed by HTTP method.
type Router struct {
	opts        *Options
	transcoders map[string][]*transcoder
}

// NewRouter creates a Router.
//
// Every handler and header matcher starts at its package default (Default*).
// The supplied options are then applied in order, and rebuildHeaderMatcher is
// called on the resulting Options before the Router is returned.
func NewRouter(opts ...Option) *Router {
	options := &Options{
		ErrorHandler:          DefaultErrorHandler,
		HeaderMatcher:         DefaultHeaderMatcher,
		ResponseHandler:       DefaultResponseHandler,
		FastHTTPErrHandler:    DefaultFastHTTPErrorHandler,
		FastHTTPHeaderMatcher: DefaultFastHTTPHeaderMatcher,
		FastHTTPRespHandler:   DefaultFastHTTPRespHandler,
	}
	for _, apply := range opts {
		apply(options)
	}
	options.rebuildHeaderMatcher()

	return &Router{
		opts:        options,
		transcoders: map[string][]*transcoder{},
	}
}
var (
	// routers maps a tRPC service name to the HTTP handler registered for it.
	routers = make(map[string]http.Handler)
	// routerLock guards every access to routers.
	routerLock sync.RWMutex
)

// RegisterRouter registers router as the handler for the tRPC service name,
// replacing any handler previously registered under that name.
func RegisterRouter(name string, router http.Handler) {
	routerLock.Lock()
	defer routerLock.Unlock()
	routers[name] = router
}

// GetRouter returns the handler registered for the tRPC service name, or nil
// when nothing has been registered under that name.
func GetRouter(name string) http.Handler {
	routerLock.RLock()
	defer routerLock.RUnlock()
	return routers[name]
}
type (
	// ProtoMessage is the message type used by restful bindings. It is a
	// defined interface type with the same method set as proto.Message (not a
	// Go type alias), so any proto.Message value is assignable to it.
	ProtoMessage proto.Message

	// Initializer returns a new ProtoMessage. Binding.Input and Binding.Output
	// use it to create request and response messages.
	Initializer func() ProtoMessage

	// BodyLocator locates the fields of a proto message that are populated
	// according to the HttpRule body. Body returns that body selector.
	BodyLocator interface {
		Body() string
		Locate(ProtoMessage) interface{}
	}

	// ResponseBodyLocator locates the fields of a proto message that are
	// marshaled according to the HttpRule response_body. ResponseBody returns
	// that selector.
	ResponseBodyLocator interface {
		ResponseBody() string
		Locate(ProtoMessage) interface{}
	}

	// HandleFunc is the handle function of a tRPC method. svc presumably
	// receives the service implementation given to AddImplBinding (TODO
	// confirm against the transcoder), and reqBody the decoded request.
	HandleFunc func(svc interface{}, ctx context.Context, reqBody interface{}) (interface{}, error)

	// ExtractFilterFunc extracts the filter chain of a tRPC service.
	ExtractFilterFunc func() filter.ServerChain
)
// Binding is the binding of a tRPC method and an HttpRule.
// A Binding is turned into a transcoder by Router.AddImplBinding.
type Binding struct {
	// Name is the tRPC method name. The router passes it to the
	// HeaderMatcher as methodName.
	Name string
	// Input creates the request message.
	Input Initializer
	// Output creates the response message. When nil, the router falls back
	// to an initializer returning *emptypb.Empty.
	Output Initializer
	// Filter is the tRPC method handle function invoked by the transcoder.
	Filter HandleFunc
	// HTTPMethod is the HTTP method this binding serves. Transcoders are
	// grouped by it and selected by comparing it with the request's method.
	HTTPMethod string
	// Pattern is matched against the request URL path. The field paths it
	// captures are also recorded in the transcoder's trie.
	Pattern *Pattern
	// Body is the HttpRule body locator. A body selector other than "" or
	// "*" is split on "." and recorded as a field path as well.
	Body BodyLocator
	// ResponseBody is the HttpRule response_body locator.
	ResponseBody ResponseBodyLocator
}
// AddImplBinding creates a transcoder for binding, backed by serviceImpl, and
// appends it to the transcoders registered for binding.HTTPMethod.
// It returns a wrapped error, and registers nothing, if the transcoder cannot
// be built.
func (r *Router) AddImplBinding(binding *Binding, serviceImpl interface{}) error {
	tr, err := r.newTranscoder(binding, serviceImpl)
	if err != nil {
		return fmt.Errorf("new transcoder during add impl binding: %w", err)
	}

	method := binding.HTTPMethod
	r.transcoders[method] = append(r.transcoders[method], tr)
	return nil
}
// newTranscoder builds the transcoder that serves binding for serviceImpl.
//
// A nil binding.Output defaults to an initializer returning &emptypb.Empty{}.
// The default is applied to the transcoder only. The caller's Binding is never
// written to, so a Binding that is registered more than once (possibly
// concurrently) is not mutated as a side effect of registration.
//
// The field paths bound by the HttpRule are collected into a double-array trie
// (tr.dat), which the transcoder uses to filter those fields. These are the
// paths captured by the URL pattern plus the body selector when it names a
// field (neither "" nor "*"). tr.dat stays nil when there are no such paths.
// An error is returned if the trie cannot be built.
func (r *Router) newTranscoder(binding *Binding, serviceImpl interface{}) (*transcoder, error) {
	output := binding.Output
	if output == nil {
		output = func() ProtoMessage { return &emptypb.Empty{} }
	}

	tr := &transcoder{
		name:                 binding.Name,
		input:                binding.Input,
		output:               output,
		handler:              binding.Filter,
		httpMethod:           binding.HTTPMethod,
		pat:                  binding.Pattern,
		body:                 binding.Body,
		respBody:             binding.ResponseBody,
		router:               r,
		discardUnknownParams: r.opts.DiscardUnknownParams,
		serviceImpl:          serviceImpl,
	}

	// Copy the pattern's field paths into a fresh slice so that appending the
	// body path below can never write into the pattern's own backing array.
	var fps [][]string
	fps = append(fps, binding.Pattern.FieldPaths()...)
	if binding.Body != nil {
		if fromBody := binding.Body.Body(); fromBody != "" && fromBody != "*" {
			fps = append(fps, strings.Split(fromBody, "."))
		}
	}

	if len(fps) > 0 {
		doubleArrayTrie, err := dat.Build(fps)
		if err != nil {
			return nil, fmt.Errorf("failed to build dat: %w", err)
		}
		tr.dat = doubleArrayTrie
	}
	return tr, nil
}
// ctxForCompatibility is an optional hook that derives the request context.
// It exists only for compatibility with thttp and is nil until
// SetCtxForCompatibility installs one.
var ctxForCompatibility func(context.Context, http.ResponseWriter, *http.Request) context.Context

// SetCtxForCompatibility installs f as the context hook. It exists only for
// compatibility with thttp. Passing nil removes a previously installed hook.
func SetCtxForCompatibility(f func(ctx context.Context, w http.ResponseWriter, r *http.Request) context.Context) {
	ctxForCompatibility = f
}
// HeaderMatcher maps the headers of an incoming HTTP request onto the tRPC stub
// context. serviceName and methodName identify the tRPC target. It returns the
// context to use for the call, or an error that rejects the request.
type HeaderMatcher func(ctx context.Context, w http.ResponseWriter, r *http.Request,
	serviceName, methodName string) (context.Context, error)
// DefaultHeaderMatcher is the default HeaderMatcher. It ignores the HTTP
// request and response writer. It puts a fresh codec.Msg into ctx targeting
// serviceName/methodName and never returns an error.
//
// A custom HeaderMatcher should set up the message the same way
// withNewMessage does.
var DefaultHeaderMatcher = func(
	ctx context.Context,
	_ http.ResponseWriter,
	_ *http.Request,
	serviceName, methodName string,
) (context.Context, error) {
	msgCtx := withNewMessage(ctx, serviceName, methodName)
	return msgCtx, nil
}
// withNewMessage creates a new codec.Msg and stores it in ctx. It sets the
// message's server RPC name to methodName, its callee service name to
// serviceName, and its serialization type to protobuf. It returns the derived
// context that carries the message.
func withNewMessage(ctx context.Context, serviceName, methodName string) context.Context {
	msgCtx, m := codec.WithNewMessage(ctx)
	// The setters are kept in this order; they may interact inside codec.
	m.WithServerRPCName(methodName)
	m.WithCalleeServiceName(serviceName)
	m.WithSerializationType(codec.SerializationTypePB)
	return msgCtx
}
// CustomResponseHandler writes a successful tRPC response to the HTTP client.
// resp is the response message and body its already-serialized bytes. A
// returned error is reported by the router as an encode failure.
type CustomResponseHandler func(ctx context.Context, w http.ResponseWriter, req *http.Request,
	resp proto.Message, body []byte) error
// httpStatusKey is the server metadata key under which the success status code
// is stored.
var httpStatusKey = "t-http-status"

// SetStatusCodeOnSucceed records the HTTP status code to use when the tRPC
// method succeeds. The code should be 2XX. The code is stored as decimal text
// in the server metadata of the message in ctx.
//
// Do not use this to set the status code on error; use WithStatusCode in
// restful/errors.go instead.
func SetStatusCodeOnSucceed(ctx context.Context, code int) {
	msg := codec.Message(ctx)
	md := msg.ServerMetaData()
	if md == nil {
		md = make(codec.MetaData)
	}
	md[httpStatusKey] = []byte(strconv.Itoa(code))
	msg.WithServerMetaData(md)
}

// GetStatusCodeOnSucceed returns the success status code recorded by
// SetStatusCodeOnSucceed, which must be called first in the tRPC method.
// It falls back to http.StatusOK when no code was recorded or the stored value
// is not a valid integer.
func GetStatusCodeOnSucceed(ctx context.Context) int {
	// Indexing a nil metadata map is safe and simply reports "not found".
	raw, found := codec.Message(ctx).ServerMetaData()[httpStatusKey]
	if !found {
		return http.StatusOK
	}
	code, err := strconv.Atoi(bytes2str(raw))
	if err != nil {
		return http.StatusOK
	}
	return code
}
// DefaultResponseHandler is the default CustomResponseHandler.
//
// It writes body to w with:
//   - the status code recorded by SetStatusCodeOnSucceed (http.StatusOK if none);
//   - the Content-Type of the serializer negotiated from the request headers;
//   - a compressed body and matching Content-Encoding when a compressor is
//     negotiated from the request headers.
//
// 204 No Content and 304 Not Modified responses carry no body. For them the
// compressor is not set up at all, so they get no Content-Encoding header and
// no compressor output is emitted into a bodiless response.
//
// It returns an error only if the compressor cannot be created. That happens
// before anything is written to w.
var DefaultResponseHandler = func(
	ctx context.Context,
	w http.ResponseWriter,
	r *http.Request,
	resp proto.Message,
	body []byte,
) error {
	statusCode := GetStatusCodeOnSucceed(ctx)
	hasBody := statusCode != http.StatusNoContent && statusCode != http.StatusNotModified

	// compress (only when a body will actually be written)
	var writer io.Writer = w
	if hasBody {
		_, c := compressorForTranscoding(r.Header[headerContentEncoding],
			r.Header[headerAcceptEncoding])
		if c != nil {
			writeCloser, err := c.Compress(w)
			if err != nil {
				return fmt.Errorf("failed to compress resp body: %w", err)
			}
			// Deferred so Close runs after the body has been written through it.
			defer writeCloser.Close()
			w.Header().Set(headerContentEncoding, c.ContentEncoding())
			writer = writeCloser
		}
	}

	// set response content-type
	_, s := serializerForTranscoding(r.Header[headerContentType],
		r.Header[headerAccept])
	w.Header().Set(headerContentType, s.ContentType())

	// set status code; the status line and headers are committed from here on
	w.WriteHeader(statusCode)

	// response body
	if hasBody {
		// The write error is deliberately not returned. The headers are already
		// sent, and the router reacts to a returned error by invoking its
		// ErrorHandler, which presumably tries to write its own status/body and
		// can no longer produce a clean error response.
		_, _ = writer.Write(body)
	}
	return nil
}
// putBackCtxMessage returns the codec.Msg stored in ctx, if there is one, to
// the message pool via codec.PutBackMessage. Contexts without a codec.Msg
// under codec.ContextKeyMessage are left alone.
func putBackCtxMessage(ctx context.Context) {
	msg, found := ctx.Value(codec.ContextKeyMessage).(codec.Msg)
	if !found {
		return
	}
	codec.PutBackMessage(msg)
}
// ServeHTTP implements http.Handler.
//
// It tries the transcoders registered for req.Method in registration order.
// The request goes to the first one whose pattern matches req.URL.Path. If none
// matches, the ErrorHandler receives an errs.RetServerNoFunc error.
//
// The request context is passed through the thttp compatibility hook when one
// has been installed with SetCtxForCompatibility. Without a hook the request's
// own context is used instead of calling a nil function, so a Router that is
// served on its own does not panic.
// TODO: better routing handling.
func (r *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) {
	ctx := req.Context()
	if hook := ctxForCompatibility; hook != nil {
		ctx = hook(ctx, w, req)
	}
	for _, tr := range r.transcoders[req.Method] {
		fieldValues, err := tr.pat.Match(req.URL.Path)
		if err == nil {
			r.handle(ctx, w, req, tr, fieldValues)
			return
		}
	}
	r.opts.ErrorHandler(ctx, w, req, errs.New(errs.RetServerNoFunc, "failed to match any pattern"))
}
// handle serves req with the already-matched transcoder tr. fieldValues holds
// the values captured from the URL path by tr's pattern.
//
// Steps, in order:
//  1. build the call context via HeaderMatcher; the codec.Msg it stores is
//     returned to the pool when handle exits;
//  2. bound the context by the effective timeout;
//  3. negotiate compressors/serializers from the request headers;
//  4. transcode;
//  5. write the result through ResponseHandler.
//
// Failures go to ErrorHandler. Header-matching errors are wrapped as
// RetServerDecodeFail, transcode errors are passed through unchanged, and
// response-handler errors are wrapped as RetServerEncodeFail.
func (r *Router) handle(
	ctx context.Context,
	w http.ResponseWriter,
	req *http.Request,
	tr *transcoder,
	fieldValues map[string]string,
) {
	opts := r.opts

	matchedCtx, err := opts.HeaderMatcher(ctx, w, req, opts.ServiceName, tr.name)
	if err != nil {
		// The matcher produced no usable context, so report against the incoming one.
		opts.ErrorHandler(ctx, w, req, errs.New(errs.RetServerDecodeFail, err.Error()))
		return
	}
	ctx = matchedCtx
	defer putBackCtxMessage(matchedCtx)

	// Effective timeout: start from the router's Timeout and tighten it with the
	// request's own timeout when that is positive and either smaller or the
	// router has none (0). A non-positive result means no deadline is applied.
	timeout := opts.Timeout
	if reqTimeout := codec.Message(ctx).RequestTimeout(); reqTimeout > 0 && (timeout == 0 || reqTimeout < timeout) {
		timeout = reqTimeout
	}
	if timeout > 0 {
		var cancel context.CancelFunc
		ctx, cancel = context.WithTimeout(ctx, timeout)
		defer cancel()
	}

	// inbound/outbound Compressor and Serializer
	reqComp, respComp := compressorForTranscoding(req.Header[headerContentEncoding],
		req.Header[headerAcceptEncoding])
	reqSer, respSer := serializerForTranscoding(req.Header[headerContentType],
		req.Header[headerAccept])

	// The params object comes from a pool and is handed back once handling ends.
	params, _ := paramsPool.Get().(*transcodeParams)
	params.reqCompressor, params.respCompressor = reqComp, respComp
	params.reqSerializer, params.respSerializer = reqSer, respSer
	params.body = req.Body
	params.fieldValues = fieldValues
	params.form = req.URL.Query()
	defer putBackParams(params)

	resp, body, err := tr.transcode(ctx, params)
	if err != nil {
		opts.ErrorHandler(ctx, w, req, err)
		return
	}

	if err = opts.ResponseHandler(ctx, w, req, resp, body); err != nil {
		opts.ErrorHandler(ctx, w, req, errs.New(errs.RetServerEncodeFail, err.Error()))
	}
}