Skip to content

Commit 0534769

Browse files
fharding1elithrar
authored andcommitted
Improve CORS Method Middleware (gorilla#477)
* More sensical CORSMethodMiddleware * Only sets Access-Control-Allow-Methods on valid preflight requests * Does not return after setting the Access-Control-Allow-Methods header * Does not append OPTIONS header to Access-Control-Allow-Methods regardless of whether there is an OPTIONS method matcher * Adds tests for the listed behavior * Add example for CORSMethodMiddleware * Do not check for preflight and add documentation to the README * Use http.MethodOptions instead of "OPTIONS" * Add link to CORSMethodMiddleware section to readme * Add test for unmatching route methods * Rename CORS Method Middleware to Handling CORS Requests in README * Link CORSMethodMiddleware in README to godoc * Break CORSMethodMiddleware doc into bullets for readability * Add comment about specifying OPTIONS to example in README for CORSMethodMiddleware * Document cURL command used for testing CORS Method Middleware * Update comment in example to "Handle the request" * Add explicit comment about OPTIONS matchers to CORSMethodMiddleware doc * Update circleci config to only check gofmt diff on latest go version * Break up gofmt and go vet checks into separate steps. * Use canonical circleci config
1 parent d70f7b4 commit 0534769

File tree

5 files changed

+252
-58
lines changed

5 files changed

+252
-58
lines changed

.circleci/config.yml

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,20 @@ jobs:
1111
- checkout
1212
- run: go version
1313
- run: go get -t -v ./...
14-
- run: diff -u <(echo -n) <(gofmt -d .)
15-
- run: if [[ "$LATEST" = true ]]; then go vet -v .; fi
14+
# Only run gofmt, vet & lint against the latest Go version
15+
- run: >
16+
if [[ "$LATEST" = true ]]; then
17+
go get -u golang.org/x/lint/golint
18+
golint ./...
19+
fi
20+
- run: >
21+
if [[ "$LATEST" = true ]]; then
22+
diff -u <(echo -n) <(gofmt -d .)
23+
fi
24+
- run: >
25+
if [[ "$LATEST" = true ]]; then
26+
go vet -v .
27+
fi
1628
- run: go test -v -race ./...
1729

1830
"latest":

README.md

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ The name mux stands for "HTTP request multiplexer". Like the standard `http.Serv
3030
* [Walking Routes](#walking-routes)
3131
* [Graceful Shutdown](#graceful-shutdown)
3232
* [Middleware](#middleware)
33+
* [Handling CORS Requests](#handling-cors-requests)
3334
* [Testing Handlers](#testing-handlers)
3435
* [Full Example](#full-example)
3536

@@ -492,6 +493,73 @@ r.Use(amw.Middleware)
492493

493494
Note: The handler chain will be stopped if your middleware doesn't call `next.ServeHTTP()` with the corresponding parameters. This can be used to abort a request if the middleware writer wants to. Middlewares _should_ write to `ResponseWriter` if they _are_ going to terminate the request, and they _should not_ write to `ResponseWriter` if they _are not_ going to terminate it.
494495

496+
### Handling CORS Requests
497+
498+
[CORSMethodMiddleware](https://godoc.org/github.com/gorilla/mux#CORSMethodMiddleware) intends to make it easier to strictly set the `Access-Control-Allow-Methods` response header.
499+
500+
* You will still need to use your own CORS handler to set the other CORS headers such as `Access-Control-Allow-Origin`
501+
* The middleware will set the `Access-Control-Allow-Methods` header to all the method matchers (e.g. `r.Methods(http.MethodGet, http.MethodPut, http.MethodOptions)` -> `Access-Control-Allow-Methods: GET,PUT,OPTIONS`) on a route
502+
* If you do not specify any methods, then:
503+
> _Important_: there must be an `OPTIONS` method matcher for the middleware to set the headers.
504+
505+
Here is an example of using `CORSMethodMiddleware` along with a custom `OPTIONS` handler to set all the required CORS headers:
506+
507+
```go
508+
package main
509+
510+
import (
511+
"net/http"
512+
"github.com/gorilla/mux"
513+
)
514+
515+
func main() {
516+
r := mux.NewRouter()
517+
518+
// IMPORTANT: you must specify an OPTIONS method matcher for the middleware to set CORS headers
519+
r.HandleFunc("/foo", fooHandler).Methods(http.MethodGet, http.MethodPut, http.MethodPatch, http.MethodOptions)
520+
r.Use(mux.CORSMethodMiddleware(r))
521+
522+
http.ListenAndServe(":8080", r)
523+
}
524+
525+
func fooHandler(w http.ResponseWriter, r *http.Request) {
526+
w.Header().Set("Access-Control-Allow-Origin", "*")
527+
if r.Method == http.MethodOptions {
528+
return
529+
}
530+
531+
w.Write([]byte("foo"))
532+
}
533+
```
534+
535+
And an request to `/foo` using something like:
536+
537+
```bash
538+
curl localhost:8080/foo -v
539+
```
540+
541+
Would look like:
542+
543+
```bash
544+
* Trying ::1...
545+
* TCP_NODELAY set
546+
* Connected to localhost (::1) port 8080 (#0)
547+
> GET /foo HTTP/1.1
548+
> Host: localhost:8080
549+
> User-Agent: curl/7.59.0
550+
> Accept: */*
551+
>
552+
< HTTP/1.1 200 OK
553+
< Access-Control-Allow-Methods: GET,PUT,PATCH,OPTIONS
554+
< Access-Control-Allow-Origin: *
555+
< Date: Fri, 28 Jun 2019 20:13:30 GMT
556+
< Content-Length: 3
557+
< Content-Type: text/plain; charset=utf-8
558+
<
559+
* Connection #0 to host localhost left intact
560+
foo
561+
```
562+
495563
### Testing Handlers
496564
497565
Testing handlers in a Go web application is straightforward, and _mux_ doesn't complicate this any further. Given two files: `endpoints.go` and `endpoints_test.go`, here's how we'd test an application using _mux_.
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
package mux_test
2+
3+
import (
4+
"fmt"
5+
"net/http"
6+
"net/http/httptest"
7+
8+
"github.com/gorilla/mux"
9+
)
10+
11+
func ExampleCORSMethodMiddleware() {
12+
r := mux.NewRouter()
13+
14+
r.HandleFunc("/foo", func(w http.ResponseWriter, r *http.Request) {
15+
// Handle the request
16+
}).Methods(http.MethodGet, http.MethodPut, http.MethodPatch)
17+
r.HandleFunc("/foo", func(w http.ResponseWriter, r *http.Request) {
18+
w.Header().Set("Access-Control-Allow-Origin", "http://example.com")
19+
w.Header().Set("Access-Control-Max-Age", "86400")
20+
}).Methods(http.MethodOptions)
21+
22+
r.Use(mux.CORSMethodMiddleware(r))
23+
24+
rw := httptest.NewRecorder()
25+
req, _ := http.NewRequest("OPTIONS", "/foo", nil) // needs to be OPTIONS
26+
req.Header.Set("Access-Control-Request-Method", "POST") // needs to be non-empty
27+
req.Header.Set("Access-Control-Request-Headers", "Authorization") // needs to be non-empty
28+
req.Header.Set("Origin", "http://example.com") // needs to be non-empty
29+
30+
r.ServeHTTP(rw, req)
31+
32+
fmt.Println(rw.Header().Get("Access-Control-Allow-Methods"))
33+
fmt.Println(rw.Header().Get("Access-Control-Allow-Origin"))
34+
// Output:
35+
// GET,PUT,PATCH,OPTIONS
36+
// http://example.com
37+
}

middleware.go

Lines changed: 34 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -32,41 +32,48 @@ func (r *Router) useInterface(mw middleware) {
3232
r.middlewares = append(r.middlewares, mw)
3333
}
3434

35-
// CORSMethodMiddleware sets the Access-Control-Allow-Methods response header
36-
// on a request, by matching routes based only on paths. It also handles
37-
// OPTIONS requests, by settings Access-Control-Allow-Methods, and then
38-
// returning without calling the next http handler.
35+
// CORSMethodMiddleware automatically sets the Access-Control-Allow-Methods response header
36+
// on requests for routes that have an OPTIONS method matcher to all the method matchers on
37+
// the route. Routes that do not explicitly handle OPTIONS requests will not be processed
38+
// by the middleware. See examples for usage.
3939
func CORSMethodMiddleware(r *Router) MiddlewareFunc {
4040
return func(next http.Handler) http.Handler {
4141
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
42-
var allMethods []string
43-
44-
err := r.Walk(func(route *Route, _ *Router, _ []*Route) error {
45-
for _, m := range route.matchers {
46-
if _, ok := m.(*routeRegexp); ok {
47-
if m.Match(req, &RouteMatch{}) {
48-
methods, err := route.GetMethods()
49-
if err != nil {
50-
return err
51-
}
52-
53-
allMethods = append(allMethods, methods...)
54-
}
55-
break
56-
}
57-
}
58-
return nil
59-
})
60-
42+
allMethods, err := getAllMethodsForRoute(r, req)
6143
if err == nil {
62-
w.Header().Set("Access-Control-Allow-Methods", strings.Join(append(allMethods, "OPTIONS"), ","))
63-
64-
if req.Method == "OPTIONS" {
65-
return
44+
for _, v := range allMethods {
45+
if v == http.MethodOptions {
46+
w.Header().Set("Access-Control-Allow-Methods", strings.Join(allMethods, ","))
47+
}
6648
}
6749
}
6850

6951
next.ServeHTTP(w, req)
7052
})
7153
}
7254
}
55+
56+
// getAllMethodsForRoute returns all the methods from method matchers matching a given
57+
// request.
58+
func getAllMethodsForRoute(r *Router, req *http.Request) ([]string, error) {
59+
var allMethods []string
60+
61+
err := r.Walk(func(route *Route, _ *Router, _ []*Route) error {
62+
for _, m := range route.matchers {
63+
if _, ok := m.(*routeRegexp); ok {
64+
if m.Match(req, &RouteMatch{}) {
65+
methods, err := route.GetMethods()
66+
if err != nil {
67+
return err
68+
}
69+
70+
allMethods = append(allMethods, methods...)
71+
}
72+
break
73+
}
74+
}
75+
return nil
76+
})
77+
78+
return allMethods, err
79+
}

middleware_test.go

Lines changed: 99 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,7 @@ package mux
22

33
import (
44
"bytes"
5-
"fmt"
65
"net/http"
7-
"net/http/httptest"
86
"testing"
97
)
108

@@ -367,42 +365,114 @@ func TestMiddlewareMethodMismatchSubrouter(t *testing.T) {
367365
}
368366

369367
func TestCORSMethodMiddleware(t *testing.T) {
370-
router := NewRouter()
371-
372-
cases := []struct {
373-
path string
374-
response string
375-
method string
376-
testURL string
377-
expectedAllowedMethods string
368+
testCases := []struct {
369+
name string
370+
registerRoutes func(r *Router)
371+
requestHeader http.Header
372+
requestMethod string
373+
requestPath string
374+
expectedAccessControlAllowMethodsHeader string
375+
expectedResponse string
378376
}{
379-
{"/g/{o}", "a", "POST", "/g/asdf", "POST,PUT,GET,OPTIONS"},
380-
{"/g/{o}", "b", "PUT", "/g/bla", "POST,PUT,GET,OPTIONS"},
381-
{"/g/{o}", "c", "GET", "/g/orilla", "POST,PUT,GET,OPTIONS"},
382-
{"/g", "d", "POST", "/g", "POST,OPTIONS"},
377+
{
378+
name: "does not set without OPTIONS matcher",
379+
registerRoutes: func(r *Router) {
380+
r.HandleFunc("/foo", stringHandler("a")).Methods(http.MethodGet, http.MethodPut, http.MethodPatch)
381+
},
382+
requestMethod: "GET",
383+
requestPath: "/foo",
384+
expectedAccessControlAllowMethodsHeader: "",
385+
expectedResponse: "a",
386+
},
387+
{
388+
name: "sets on non OPTIONS",
389+
registerRoutes: func(r *Router) {
390+
r.HandleFunc("/foo", stringHandler("a")).Methods(http.MethodGet, http.MethodPut, http.MethodPatch)
391+
r.HandleFunc("/foo", stringHandler("b")).Methods(http.MethodOptions)
392+
},
393+
requestMethod: "GET",
394+
requestPath: "/foo",
395+
expectedAccessControlAllowMethodsHeader: "GET,PUT,PATCH,OPTIONS",
396+
expectedResponse: "a",
397+
},
398+
{
399+
name: "sets without preflight headers",
400+
registerRoutes: func(r *Router) {
401+
r.HandleFunc("/foo", stringHandler("a")).Methods(http.MethodGet, http.MethodPut, http.MethodPatch)
402+
r.HandleFunc("/foo", stringHandler("b")).Methods(http.MethodOptions)
403+
},
404+
requestMethod: "OPTIONS",
405+
requestPath: "/foo",
406+
expectedAccessControlAllowMethodsHeader: "GET,PUT,PATCH,OPTIONS",
407+
expectedResponse: "b",
408+
},
409+
{
410+
name: "does not set on error",
411+
registerRoutes: func(r *Router) {
412+
r.HandleFunc("/foo", stringHandler("a"))
413+
},
414+
requestMethod: "OPTIONS",
415+
requestPath: "/foo",
416+
expectedAccessControlAllowMethodsHeader: "",
417+
expectedResponse: "a",
418+
},
419+
{
420+
name: "sets header on valid preflight",
421+
registerRoutes: func(r *Router) {
422+
r.HandleFunc("/foo", stringHandler("a")).Methods(http.MethodGet, http.MethodPut, http.MethodPatch)
423+
r.HandleFunc("/foo", stringHandler("b")).Methods(http.MethodOptions)
424+
},
425+
requestMethod: "OPTIONS",
426+
requestPath: "/foo",
427+
requestHeader: http.Header{
428+
"Access-Control-Request-Method": []string{"GET"},
429+
"Access-Control-Request-Headers": []string{"Authorization"},
430+
"Origin": []string{"http://example.com"},
431+
},
432+
expectedAccessControlAllowMethodsHeader: "GET,PUT,PATCH,OPTIONS",
433+
expectedResponse: "b",
434+
},
435+
{
436+
name: "does not set methods from unmatching routes",
437+
registerRoutes: func(r *Router) {
438+
r.HandleFunc("/foo", stringHandler("c")).Methods(http.MethodDelete)
439+
r.HandleFunc("/foo/bar", stringHandler("a")).Methods(http.MethodGet, http.MethodPut, http.MethodPatch)
440+
r.HandleFunc("/foo/bar", stringHandler("b")).Methods(http.MethodOptions)
441+
},
442+
requestMethod: "OPTIONS",
443+
requestPath: "/foo/bar",
444+
requestHeader: http.Header{
445+
"Access-Control-Request-Method": []string{"GET"},
446+
"Access-Control-Request-Headers": []string{"Authorization"},
447+
"Origin": []string{"http://example.com"},
448+
},
449+
expectedAccessControlAllowMethodsHeader: "GET,PUT,PATCH,OPTIONS",
450+
expectedResponse: "b",
451+
},
383452
}
384453

385-
for _, tt := range cases {
386-
router.HandleFunc(tt.path, stringHandler(tt.response)).Methods(tt.method)
387-
}
454+
for _, tt := range testCases {
455+
t.Run(tt.name, func(t *testing.T) {
456+
router := NewRouter()
388457

389-
router.Use(CORSMethodMiddleware(router))
458+
tt.registerRoutes(router)
390459

391-
for i, tt := range cases {
392-
t.Run(fmt.Sprintf("cases[%d]", i), func(t *testing.T) {
393-
rr := httptest.NewRecorder()
394-
req := newRequest(tt.method, tt.testURL)
460+
router.Use(CORSMethodMiddleware(router))
395461

396-
router.ServeHTTP(rr, req)
462+
rw := NewRecorder()
463+
req := newRequest(tt.requestMethod, tt.requestPath)
464+
req.Header = tt.requestHeader
397465

398-
if rr.Body.String() != tt.response {
399-
t.Errorf("Expected body '%s', found '%s'", tt.response, rr.Body.String())
400-
}
466+
router.ServeHTTP(rw, req)
401467

402-
allowedMethods := rr.Header().Get("Access-Control-Allow-Methods")
468+
actualMethodsHeader := rw.Header().Get("Access-Control-Allow-Methods")
469+
if actualMethodsHeader != tt.expectedAccessControlAllowMethodsHeader {
470+
t.Fatalf("Expected Access-Control-Allow-Methods to equal %s but got %s", tt.expectedAccessControlAllowMethodsHeader, actualMethodsHeader)
471+
}
403472

404-
if allowedMethods != tt.expectedAllowedMethods {
405-
t.Errorf("Expected Access-Control-Allow-Methods '%s', found '%s'", tt.expectedAllowedMethods, allowedMethods)
473+
actualResponse := rw.Body.String()
474+
if actualResponse != tt.expectedResponse {
475+
t.Fatalf("Expected response to equal %s but got %s", tt.expectedResponse, actualResponse)
406476
}
407477
})
408478
}

0 commit comments

Comments
 (0)