diff --git a/context.go b/context.go index 713a425..4097d87 100644 --- a/context.go +++ b/context.go @@ -32,8 +32,17 @@ type Context struct { engine *Engine } -func (p *Context) FormInt(name string, defval int) int { - ret := p.FormValue(name) +func (p *Context) setParam(name, val string) { + p.ParseForm() + p.Form.Set(name, val) +} + +func (p *Context) Param(name string) string { + return p.FormValue(name) +} + +func (p *Context) ParamInt(name string, defval int) int { + ret := p.Param(name) if ret != "" { if v, err := strconv.Atoi(ret); err == nil { return v diff --git a/demo/hello/hello.go b/demo/hello/hello.go new file mode 100644 index 0000000..4647683 --- /dev/null +++ b/demo/hello/hello.go @@ -0,0 +1,18 @@ +package main + +import ( + "github.com/goplus/yap" +) + +func main() { + y := yap.New() + y.GET("/p/:id", func(ctx *yap.Context) { + ctx.JSON(200, yap.H{ + "id": ctx.Param("id"), + }) + }) + y.Handle("/", func(ctx *yap.Context) { + ctx.TEXT(200, "text/html", `
Hello, Yap!`) + }) + y.Run(":8080") +} diff --git a/internal/url/path.go b/internal/url/path.go new file mode 100644 index 0000000..0d79c1f --- /dev/null +++ b/internal/url/path.go @@ -0,0 +1,161 @@ +/* + * Copyright (c) 2023 The GoPlus Authors (goplus.org). All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package url + +// CleanPath is the URL version of path.Clean, it returns a canonical URL path +// for p, eliminating . and .. elements. +// +// The following rules are applied iteratively until no further processing can +// be done: +// 1. Replace multiple slashes with a single slash. +// 2. Eliminate each . path name element (the current directory). +// 3. Eliminate each inner .. path name element (the parent directory) +// along with the non-.. element that precedes it. +// 4. Eliminate .. elements that begin a rooted path: +// that is, replace "/.." by "/" at the beginning of a path. +// +// If the result of this process is an empty string, "/" is returned +func CleanPath(p string) string { + const stackBufSize = 128 + + // Turn empty string into "/" + if p == "" { + return "/" + } + + // Reasonably sized buffer on stack to avoid allocations in the common case. + // If a larger buffer is required, it gets allocated dynamically. + buf := make([]byte, 0, stackBufSize) + + n := len(p) + + // Invariants: + // reading from path; r is index of next byte to process. + // writing to buf; w is index of next byte to write. + + // path must start with '/' + r := 1 + w := 1 + + if p[0] != '/' { + r = 0 + + if n+1 > stackBufSize { + buf = make([]byte, n+1) + } else { + buf = buf[:n+1] + } + buf[0] = '/' + } + + trailing := n > 1 && p[n-1] == '/' + + // A bit more clunky without a 'lazybuf' like the path package, but the loop + // gets completely inlined (bufApp calls). + // So in contrast to the path package this loop has no expensive function + // calls (except make, if needed). + + for r < n { + switch { + case p[r] == '/': + // empty path element, trailing slash is added after the end + r++ + + case p[r] == '.' && r+1 == n: + trailing = true + r++ + + case p[r] == '.' && p[r+1] == '/': + // . element + r += 2 + + case p[r] == '.' && p[r+1] == '.' && (r+2 == n || p[r+2] == '/'): + // .. element: remove to last / + r += 3 + + if w > 1 { + // can backtrack + w-- + + if len(buf) == 0 { + for w > 1 && p[w] != '/' { + w-- + } + } else { + for w > 1 && buf[w] != '/' { + w-- + } + } + } + + default: + // Real path element. + // Add slash if needed + if w > 1 { + bufApp(&buf, p, w, '/') + w++ + } + + // Copy element + for r < n && p[r] != '/' { + bufApp(&buf, p, w, p[r]) + w++ + r++ + } + } + } + + // Re-append trailing slash + if trailing && w > 1 { + bufApp(&buf, p, w, '/') + w++ + } + + // If the original string was not modified (or only shortened at the end), + // return the respective substring of the original string. + // Otherwise return a new string from the buffer. + if len(buf) == 0 { + return p[:w] + } + return string(buf[:w]) +} + +// Internal helper to lazily create a buffer if necessary. +// Calls to this function get inlined. +func bufApp(buf *[]byte, s string, w int, c byte) { + b := *buf + if len(b) == 0 { + // No modification of the original string so far. + // If the next character is the same as in the original string, we do + // not yet have to allocate a buffer. + if s[w] == c { + return + } + + // Otherwise use either the stack buffer, if it is large enough, or + // allocate a new buffer on the heap, and copy all previous characters. + if l := len(s); l > cap(b) { + *buf = make([]byte, len(s)) + } else { + *buf = (*buf)[:l] + } + b = *buf + + copy(b, s[:w]) + } + b[w] = c +} diff --git a/router.go b/router.go new file mode 100644 index 0000000..89605d0 --- /dev/null +++ b/router.go @@ -0,0 +1,288 @@ +/* + * Copyright (c) 2023 The GoPlus Authors (goplus.org). All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package yap + +import ( + "net/http" + "strings" + + "github.com/goplus/yap/internal/url" +) + +// router is a http rounter which can be used to dispatch requests to different +// handler functions via configurable routes +type router struct { + trees map[string]*node + + // An optional http.Handler that is called on automatic OPTIONS requests. + // The handler is only called if HandleOPTIONS is true and no OPTIONS + // handler for the specific path was set. + // The "Allowed" header is set before calling the handler. + GlobalOPTIONS http.Handler + + // Cached value of global (*) allowed methods + globalAllowed string + + // Configurable http.Handler which is called when a request + // cannot be routed and HandleMethodNotAllowed is true. + // If it is not set, http.Error with http.StatusMethodNotAllowed is used. + // The "Allow" header with allowed request methods is set before the handler + // is called. + MethodNotAllowed http.Handler + + // Function to handle panics recovered from http handlers. + // It should be used to generate a error page and return the http error code + // 500 (Internal Server Error). + // The handler can be used to keep your server from crashing because of + // unrecovered panics. + PanicHandler func(http.ResponseWriter, *http.Request, interface{}) + + // Enables automatic redirection if the current route can't be matched but a + // handler for the path with (without) the trailing slash exists. + // For example if /foo/ is requested but a route only exists for /foo, the + // client is redirected to /foo with http status code 301 for GET requests + // and 308 for all other request methods. + RedirectTrailingSlash bool + + // If enabled, the router tries to fix the current request path, if no + // handle is registered for it. + // First superfluous path elements like ../ or // are removed. + // Afterwards the router does a case-insensitive lookup of the cleaned path. + // If a handle can be found for this route, the router makes a redirection + // to the corrected path with status code 301 for GET requests and 308 for + // all other request methods. + // For example /FOO and /..//Foo could be redirected to /foo. + // RedirectTrailingSlash is independent of this option. + RedirectFixedPath bool + + // If enabled, the router checks if another method is allowed for the + // current route, if the current request can not be routed. + // If this is the case, the request is answered with 'Method Not Allowed' + // and HTTP status code 405. + // If no other Method is allowed, the request is delegated to the NotFound + // handler. + HandleMethodNotAllowed bool + + // If enabled, the router automatically replies to OPTIONS requests. + // Custom OPTIONS handlers take priority over automatic replies. + HandleOPTIONS bool +} + +func (r *router) init() { + r.RedirectTrailingSlash = true + r.RedirectFixedPath = true + r.HandleMethodNotAllowed = true + r.HandleOPTIONS = true +} + +// GET is a shortcut for router.Route(http.MethodGet, path, handle) +func (r *router) GET(path string, handle func(ctx *Context)) { + r.Route(http.MethodGet, path, handle) +} + +// HEAD is a shortcut for router.Route(http.MethodHead, path, handle) +func (r *router) HEAD(path string, handle func(ctx *Context)) { + r.Route(http.MethodHead, path, handle) +} + +// OPTIONS is a shortcut for router.Route(http.MethodOptions, path, handle) +func (r *router) OPTIONS(path string, handle func(ctx *Context)) { + r.Route(http.MethodOptions, path, handle) +} + +// POST is a shortcut for router.Route(http.MethodPost, path, handle) +func (r *router) POST(path string, handle func(ctx *Context)) { + r.Route(http.MethodPost, path, handle) +} + +// PUT is a shortcut for router.Route(http.MethodPut, path, handle) +func (r *router) PUT(path string, handle func(ctx *Context)) { + r.Route(http.MethodPut, path, handle) +} + +// PATCH is a shortcut for router.Route(http.MethodPatch, path, handle) +func (r *router) PATCH(path string, handle func(ctx *Context)) { + r.Route(http.MethodPatch, path, handle) +} + +// DELETE is a shortcut for router.Route(http.MethodDelete, path, handle) +func (r *router) DELETE(path string, handle func(ctx *Context)) { + r.Route(http.MethodDelete, path, handle) +} + +// Route registers a new request handle with the given path and method. +// +// For GET, POST, PUT, PATCH and DELETE requests the respective shortcut +// functions can be used. +// +// This function is intended for bulk loading and to allow the usage of less +// frequently used, non-standardized or custom methods (e.g. for internal +// communication with a proxy). +func (r *router) Route(method, path string, handle func(ctx *Context)) { + if method == "" { + panic("method must not be empty") + } + if len(path) < 1 || path[0] != '/' { + panic("path must begin with '/' in path '" + path + "'") + } + if handle == nil { + panic("handle must not be nil") + } + + if r.trees == nil { + r.trees = make(map[string]*node) + } + + root := r.trees[method] + if root == nil { + root = new(node) + r.trees[method] = root + + r.globalAllowed = r.allowed("*", "") + } + + root.addRoute(path, handle) +} + +func (r *router) recv(w http.ResponseWriter, req *http.Request) { + if rcv := recover(); rcv != nil { + r.PanicHandler(w, req, rcv) + } +} + +func (r *router) allowed(path, reqMethod string) (allow string) { + allowed := make([]string, 0, 9) + + if path == "*" { // server-wide + // empty method is used for internal calls to refresh the cache + if reqMethod == "" { + for method := range r.trees { + if method == http.MethodOptions { + continue + } + // Route request method to list of allowed methods + allowed = append(allowed, method) + } + } else { + return r.globalAllowed + } + } else { // specific path + for method := range r.trees { + // Skip the requested method - we already tried this one + if method == reqMethod || method == http.MethodOptions { + continue + } + + handle, _ := r.trees[method].getValue(path, nil) + if handle != nil { + // Route request method to list of allowed methods + allowed = append(allowed, method) + } + } + } + + if len(allowed) > 0 { + // Route request method to list of allowed methods + allowed = append(allowed, http.MethodOptions) + + // Sort allowed methods. + // sort.Strings(allowed) unfortunately causes unnecessary allocations + // due to allowed being moved to the heap and interface conversion + for i, l := 1, len(allowed); i < l; i++ { + for j := i; j > 0 && allowed[j] < allowed[j-1]; j-- { + allowed[j], allowed[j-1] = allowed[j-1], allowed[j] + } + } + + // return as comma separated list + return strings.Join(allowed, ", ") + } + + return allow +} + +func (r *router) serveHTTP(w http.ResponseWriter, req *http.Request, e *Engine) { + if r.PanicHandler != nil { + defer r.recv(w, req) + } + + path := req.URL.Path + root := r.trees[req.Method] + if root != nil { + ctx := e.NewContext(w, req) + if handle, tsr := root.getValue(path, ctx); handle != nil { + handle(ctx) + return + } else if req.Method != http.MethodConnect && path != "/" { + // Moved Permanently, request with GET method + code := http.StatusMovedPermanently + if req.Method != http.MethodGet { + // Permanent Redirect, request with same method + code = http.StatusPermanentRedirect + } + + if tsr && r.RedirectTrailingSlash { + if len(path) > 1 && path[len(path)-1] == '/' { + req.URL.Path = path[:len(path)-1] + } else { + req.URL.Path = path + "/" + } + http.Redirect(w, req, req.URL.String(), code) + return + } + + // Try to fix the request path + if r.RedirectFixedPath { + fixedPath, found := root.findCaseInsensitivePath( + url.CleanPath(path), + r.RedirectTrailingSlash, + ) + if found { + req.URL.Path = fixedPath + http.Redirect(w, req, req.URL.String(), code) + return + } + } + } + } + + if req.Method == http.MethodOptions && r.HandleOPTIONS { + // Route OPTIONS requests + if allow := r.allowed(path, http.MethodOptions); allow != "" { + w.Header().Set("Allow", allow) + if r.GlobalOPTIONS != nil { + r.GlobalOPTIONS.ServeHTTP(w, req) + } + return + } + } else if r.HandleMethodNotAllowed { // Route 405 + if allow := r.allowed(path, req.Method); allow != "" { + w.Header().Set("Allow", allow) + if r.MethodNotAllowed != nil { + r.MethodNotAllowed.ServeHTTP(w, req) + } else { + http.Error(w, + http.StatusText(http.StatusMethodNotAllowed), + http.StatusMethodNotAllowed, + ) + } + return + } + } + + e.Mux.ServeHTTP(w, req) +} diff --git a/tree.go b/tree.go new file mode 100644 index 0000000..508a619 --- /dev/null +++ b/tree.go @@ -0,0 +1,667 @@ +/* + * Copyright (c) 2023 The GoPlus Authors (goplus.org). All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package yap + +import ( + "strings" + "unicode" + "unicode/utf8" +) + +func min(a, b int) int { + if a <= b { + return a + } + return b +} + +func longestCommonPrefix(a, b string) int { + i := 0 + max := min(len(a), len(b)) + for i < max && a[i] == b[i] { + i++ + } + return i +} + +// Search for a wildcard segment and check the name for invalid characters. +// Returns -1 as index, if no wildcard was found. +func findWildcard(path string) (wilcard string, i int, valid bool) { + // Find start + for start, c := range []byte(path) { + // A wildcard starts with ':' (param) or '*' (catch-all) + if c != ':' && c != '*' { + continue + } + + // Find end and check for invalid characters + valid = true + for end, c := range []byte(path[start+1:]) { + switch c { + case '/': + return path[start : start+1+end], start, valid + case ':', '*': + valid = false + } + } + return path[start:], start, valid + } + return "", -1, false +} + +type nodeType uint8 + +const ( + static nodeType = iota // default + root + param + catchAll +) + +type node struct { + path string + indices string + wildChild bool + nType nodeType + priority uint32 + children []*node + handle func(ctx *Context) +} + +// Increments priority of the given child and reorders if necessary +func (n *node) incrementChildPrio(pos int) int { + cs := n.children + cs[pos].priority++ + prio := cs[pos].priority + + // Adjust position (move to front) + newPos := pos + for ; newPos > 0 && cs[newPos-1].priority < prio; newPos-- { + // Swap node positions + cs[newPos-1], cs[newPos] = cs[newPos], cs[newPos-1] + } + + // Build new index char string + if newPos != pos { + n.indices = n.indices[:newPos] + // Unchanged prefix, might be empty + n.indices[pos:pos+1] + // The index char we move + n.indices[newPos:pos] + n.indices[pos+1:] // Rest without char at 'pos' + } + + return newPos +} + +// addRoute adds a node with the given handle to the path. +// Not concurrency-safe! +func (n *node) addRoute(path string, handle func(ctx *Context)) { + fullPath := path + n.priority++ + + // Empty tree + if n.path == "" && n.indices == "" { + n.insertChild(path, fullPath, handle) + n.nType = root + return + } + +walk: + for { + // Find the longest common prefix. + // This also implies that the common prefix contains no ':' or '*' + // since the existing key can't contain those chars. + i := longestCommonPrefix(path, n.path) + + // Split edge + if i < len(n.path) { + child := node{ + path: n.path[i:], + wildChild: n.wildChild, + nType: static, + indices: n.indices, + children: n.children, + handle: n.handle, + priority: n.priority - 1, + } + + n.children = []*node{&child} + // []byte for proper unicode char conversion, see #65 + n.indices = string([]byte{n.path[i]}) + n.path = path[:i] + n.handle = nil + n.wildChild = false + } + + // Make new node a child of this node + if i < len(path) { + path = path[i:] + + if n.wildChild { + n = n.children[0] + n.priority++ + + // Check if the wildcard matches + if len(path) >= len(n.path) && n.path == path[:len(n.path)] && + // Adding a child to a catchAll is not possible + n.nType != catchAll && + // Check for longer wildcard, e.g. :name and :names + (len(n.path) >= len(path) || path[len(n.path)] == '/') { + continue walk + } else { + // Wildcard conflict + pathSeg := path + if n.nType != catchAll { + pathSeg = strings.SplitN(pathSeg, "/", 2)[0] + } + prefix := fullPath[:strings.Index(fullPath, pathSeg)] + n.path + panic("'" + pathSeg + + "' in new path '" + fullPath + + "' conflicts with existing wildcard '" + n.path + + "' in existing prefix '" + prefix + + "'") + } + } + + idxc := path[0] + + // '/' after param + if n.nType == param && idxc == '/' && len(n.children) == 1 { + n = n.children[0] + n.priority++ + continue walk + } + + // Check if a child with the next path byte exists + for i, c := range []byte(n.indices) { + if c == idxc { + i = n.incrementChildPrio(i) + n = n.children[i] + continue walk + } + } + + // Otherwise insert it + if idxc != ':' && idxc != '*' { + // []byte for proper unicode char conversion, see #65 + n.indices += string([]byte{idxc}) + child := &node{} + n.children = append(n.children, child) + n.incrementChildPrio(len(n.indices) - 1) + n = child + } + n.insertChild(path, fullPath, handle) + return + } + + // Otherwise add handle to current node + if n.handle != nil { + panic("a handle is already registered for path '" + fullPath + "'") + } + n.handle = handle + return + } +} + +func (n *node) insertChild(path, fullPath string, handle func(ctx *Context)) { + for { + // Find prefix until first wildcard + wildcard, i, valid := findWildcard(path) + if i < 0 { // No wilcard found + break + } + + // The wildcard name must not contain ':' and '*' + if !valid { + panic("only one wildcard per path segment is allowed, has: '" + + wildcard + "' in path '" + fullPath + "'") + } + + // Check if the wildcard has a name + if len(wildcard) < 2 { + panic("wildcards must be named with a non-empty name in path '" + fullPath + "'") + } + + // Check if this node has existing children which would be + // unreachable if we insert the wildcard here + if len(n.children) > 0 { + panic("wildcard segment '" + wildcard + + "' conflicts with existing children in path '" + fullPath + "'") + } + + // param + if wildcard[0] == ':' { + if i > 0 { + // Insert prefix before the current wildcard + n.path = path[:i] + path = path[i:] + } + + n.wildChild = true + child := &node{ + nType: param, + path: wildcard, + } + n.children = []*node{child} + n = child + n.priority++ + + // If the path doesn't end with the wildcard, then there + // will be another non-wildcard subpath starting with '/' + if len(wildcard) < len(path) { + path = path[len(wildcard):] + child := &node{ + priority: 1, + } + n.children = []*node{child} + n = child + continue + } + + // Otherwise we're done. Insert the handle in the new leaf + n.handle = handle + return + } + + // catchAll + if i+len(wildcard) != len(path) { + panic("catch-all routes are only allowed at the end of the path in path '" + fullPath + "'") + } + + if len(n.path) > 0 && n.path[len(n.path)-1] == '/' { + panic("catch-all conflicts with existing handle for the path segment root in path '" + fullPath + "'") + } + + // Currently fixed width 1 for '/' + i-- + if path[i] != '/' { + panic("no / before catch-all in path '" + fullPath + "'") + } + + n.path = path[:i] + + // First node: catchAll node with empty path + child := &node{ + wildChild: true, + nType: catchAll, + } + n.children = []*node{child} + n.indices = string('/') + n = child + n.priority++ + + // Second node: node holding the variable + child = &node{ + path: path[i:], + nType: catchAll, + handle: handle, + priority: 1, + } + n.children = []*node{child} + + return + } + + // If no wildcard was found, simply insert the path and handle + n.path = path + n.handle = handle +} + +// Returns the handle registered with the given path (key). The values of +// wildcards are saved to a map. +// If no handle can be found, a TSR (trailing slash redirect) recommendation is +// made if a handle exists with an extra (without the) trailing slash for the +// given path. +func (n *node) getValue(path string, ctx *Context) (handle func(ctx *Context), tsr bool) { +walk: // Outer loop for walking the tree + for { + prefix := n.path + if len(path) > len(prefix) { + if path[:len(prefix)] == prefix { + path = path[len(prefix):] + + // If this node does not have a wildcard (param or catchAll) + // child, we can just look up the next child node and continue + // to walk down the tree + if !n.wildChild { + idxc := path[0] + for i, c := range []byte(n.indices) { + if c == idxc { + n = n.children[i] + continue walk + } + } + + // Nothing found. + // We can recommend to redirect to the same URL without a + // trailing slash if a leaf exists for that path. + tsr = (path == "/" && n.handle != nil) + return + } + + // Handle wildcard child + n = n.children[0] + switch n.nType { + case param: + // Find param end (either '/' or path end) + end := 0 + for end < len(path) && path[end] != '/' { + end++ + } + + // Save param value + if ctx != nil { + ctx.setParam(n.path[1:], path[:end]) + } + + // We need to go deeper! + if end < len(path) { + if len(n.children) > 0 { + path = path[end:] + n = n.children[0] + continue walk + } + + // ... but we can't + tsr = (len(path) == end+1) + return + } + + if handle = n.handle; handle != nil { + return + } else if len(n.children) == 1 { + // No handle found. Check if a handle for this path + a + // trailing slash exists for TSR recommendation + n = n.children[0] + tsr = (n.path == "/" && n.handle != nil) || (n.path == "" && n.indices == "/") + } + return + + case catchAll: + // Save param value + if ctx != nil { + ctx.setParam(n.path[2:], path) + } + + handle = n.handle + return + + default: + panic("invalid node type") + } + } + } else if path == prefix { + // We should have reached the node containing the handle. + // Check if this node has a handle registered. + if handle = n.handle; handle != nil { + return + } + + // If there is no handle for this route, but this route has a + // wildcard child, there must be a handle for this path with an + // additional trailing slash + if path == "/" && n.wildChild && n.nType != root { + tsr = true + return + } + + if path == "/" && n.nType == static { + tsr = true + return + } + + // No handle found. Check if a handle for this path + a + // trailing slash exists for trailing slash recommendation + for i, c := range []byte(n.indices) { + if c == '/' { + n = n.children[i] + tsr = (len(n.path) == 1 && n.handle != nil) || + (n.nType == catchAll && n.children[0].handle != nil) + return + } + } + return + } + + // Nothing found. We can recommend to redirect to the same URL with an + // extra trailing slash if a leaf exists for that path + tsr = (path == "/") || + (len(prefix) == len(path)+1 && prefix[len(path)] == '/' && + path == prefix[:len(prefix)-1] && n.handle != nil) + return + } +} + +// Makes a case-insensitive lookup of the given path and tries to find a func(ctx *Context). +// It can optionally also fix trailing slashes. +// It returns the case-corrected path and a bool indicating whether the lookup +// was successful. +func (n *node) findCaseInsensitivePath(path string, fixTrailingSlash bool) (fixedPath string, found bool) { + const stackBufSize = 128 + + // Use a static sized buffer on the stack in the common case. + // If the path is too long, allocate a buffer on the heap instead. + buf := make([]byte, 0, stackBufSize) + if l := len(path) + 1; l > stackBufSize { + buf = make([]byte, 0, l) + } + + ciPath := n.findCaseInsensitivePathRec( + path, + buf, // Preallocate enough memory for new path + [4]byte{}, // Empty rune buffer + fixTrailingSlash, + ) + + return string(ciPath), ciPath != nil +} + +// Shift bytes in array by n bytes left +func shiftNRuneBytes(rb [4]byte, n int) [4]byte { + switch n { + case 0: + return rb + case 1: + return [4]byte{rb[1], rb[2], rb[3], 0} + case 2: + return [4]byte{rb[2], rb[3]} + case 3: + return [4]byte{rb[3]} + default: + return [4]byte{} + } +} + +// Recursive case-insensitive lookup function used by n.findCaseInsensitivePath +func (n *node) findCaseInsensitivePathRec(path string, ciPath []byte, rb [4]byte, fixTrailingSlash bool) []byte { + npLen := len(n.path) + +walk: // Outer loop for walking the tree + for len(path) >= npLen && (npLen == 0 || strings.EqualFold(path[1:npLen], n.path[1:])) { + // Add common prefix to result + oldPath := path + path = path[npLen:] + ciPath = append(ciPath, n.path...) + + if len(path) > 0 { + // If this node does not have a wildcard (param or catchAll) child, + // we can just look up the next child node and continue to walk down + // the tree + if !n.wildChild { + // Skip rune bytes already processed + rb = shiftNRuneBytes(rb, npLen) + + if rb[0] != 0 { + // Old rune not finished + idxc := rb[0] + for i, c := range []byte(n.indices) { + if c == idxc { + // continue with child node + n = n.children[i] + npLen = len(n.path) + continue walk + } + } + } else { + // Process a new rune + var rv rune + + // Find rune start. + // Runes are up to 4 byte long, + // -4 would definitely be another rune. + var off int + for max := min(npLen, 3); off < max; off++ { + if i := npLen - off; utf8.RuneStart(oldPath[i]) { + // read rune from cached path + rv, _ = utf8.DecodeRuneInString(oldPath[i:]) + break + } + } + + // Calculate lowercase bytes of current rune + lo := unicode.ToLower(rv) + utf8.EncodeRune(rb[:], lo) + + // Skip already processed bytes + rb = shiftNRuneBytes(rb, off) + + idxc := rb[0] + for i, c := range []byte(n.indices) { + // Lowercase matches + if c == idxc { + // must use a recursive approach since both the + // uppercase byte and the lowercase byte might exist + // as an index + if out := n.children[i].findCaseInsensitivePathRec( + path, ciPath, rb, fixTrailingSlash, + ); out != nil { + return out + } + break + } + } + + // If we found no match, the same for the uppercase rune, + // if it differs + if up := unicode.ToUpper(rv); up != lo { + utf8.EncodeRune(rb[:], up) + rb = shiftNRuneBytes(rb, off) + + idxc := rb[0] + for i, c := range []byte(n.indices) { + // Uppercase matches + if c == idxc { + // Continue with child node + n = n.children[i] + npLen = len(n.path) + continue walk + } + } + } + } + + // Nothing found. We can recommend to redirect to the same URL + // without a trailing slash if a leaf exists for that path + if fixTrailingSlash && path == "/" && n.handle != nil { + return ciPath + } + return nil + } + + n = n.children[0] + switch n.nType { + case param: + // Find param end (either '/' or path end) + end := 0 + for end < len(path) && path[end] != '/' { + end++ + } + + // Add param value to case insensitive path + ciPath = append(ciPath, path[:end]...) + + // We need to go deeper! + if end < len(path) { + if len(n.children) > 0 { + // Continue with child node + n = n.children[0] + npLen = len(n.path) + path = path[end:] + continue + } + + // ... but we can't + if fixTrailingSlash && len(path) == end+1 { + return ciPath + } + return nil + } + + if n.handle != nil { + return ciPath + } else if fixTrailingSlash && len(n.children) == 1 { + // No handle found. Check if a handle for this path + a + // trailing slash exists + n = n.children[0] + if n.path == "/" && n.handle != nil { + return append(ciPath, '/') + } + } + return nil + + case catchAll: + return append(ciPath, path...) + + default: + panic("invalid node type") + } + } else { + // We should have reached the node containing the handle. + // Check if this node has a handle registered. + if n.handle != nil { + return ciPath + } + + // No handle found. + // Try to fix the path by adding a trailing slash + if fixTrailingSlash { + for i, c := range []byte(n.indices) { + if c == '/' { + n = n.children[i] + if (len(n.path) == 1 && n.handle != nil) || + (n.nType == catchAll && n.children[0].handle != nil) { + return append(ciPath, '/') + } + return nil + } + } + } + return nil + } + } + + // Nothing found. + // Try to fix the path by adding / removing a trailing slash + if fixTrailingSlash { + if path == "/" { + return ciPath + } + if len(path)+1 == npLen && n.path[len(path)] == '/' && + strings.EqualFold(path[1:], n.path[1:len(path)]) && n.handle != nil { + return append(ciPath, n.path...) + } + } + return nil +} diff --git a/yap.go b/yap.go index 33c5647..3dd5563 100644 --- a/yap.go +++ b/yap.go @@ -25,6 +25,7 @@ import ( type H map[string]interface{} type Engine struct { + router Mux *http.ServeMux tpls map[string]Template @@ -39,6 +40,7 @@ func New(fs ...fs.FS) *Engine { e.fs = fs[0] e.tpls = make(map[string]Template) } + e.router.init() return e } @@ -47,6 +49,11 @@ func (p *Engine) NewContext(w http.ResponseWriter, r *http.Request) *Context { return ctx } +// ServeHTTP makes the router implement the http.Handler interface. +func (p *Engine) ServeHTTP(w http.ResponseWriter, req *http.Request) { + p.router.serveHTTP(w, req, p) +} + func (p *Engine) Handle(pattern string, f func(ctx *Context)) { p.Mux.HandleFunc(pattern, func(w http.ResponseWriter, r *http.Request) { f(p.NewContext(w, r)) @@ -54,7 +61,7 @@ func (p *Engine) Handle(pattern string, f func(ctx *Context)) { } func (p *Engine) Run(addr string, mws ...func(h http.Handler) http.Handler) { - h := http.Handler(p.Mux) + h := http.Handler(p) for _, mw := range mws { h = mw(h) }