Skip to content

Commit 5e55a4a

Browse files
fharding1elithrar
authored andcommitted
Add CORSMethodMiddleware (#366)
CORSMethodMiddleware sets the Access-Control-Allow-Methods response header on a request, by matching routes based only on paths. It also handles OPTIONS requests, by settings Access-Control-Allow-Methods, and then returning without calling the next HTTP handler.
1 parent ded0c29 commit 5e55a4a

File tree

3 files changed

+92
-1
lines changed

3 files changed

+92
-1
lines changed

middleware.go

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
package mux
22

3-
import "net/http"
3+
import (
4+
"net/http"
5+
"strings"
6+
)
47

58
// MiddlewareFunc is a function which receives an http.Handler and returns another http.Handler.
69
// Typically, the returned handler is a closure which does something with the http.ResponseWriter and http.Request passed
@@ -28,3 +31,42 @@ func (r *Router) Use(mwf ...MiddlewareFunc) {
2831
func (r *Router) useInterface(mw middleware) {
2932
r.middlewares = append(r.middlewares, mw)
3033
}
34+
35+
// CORSMethodMiddleware sets the Access-Control-Allow-Methods response header
36+
// on a request, by matching routes based only on paths. It also handles
37+
// OPTIONS requests, by settings Access-Control-Allow-Methods, and then
38+
// returning without calling the next http handler.
39+
func CORSMethodMiddleware(r *Router) MiddlewareFunc {
40+
return func(next http.Handler) http.Handler {
41+
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
42+
var allMethods []string
43+
44+
err := r.Walk(func(route *Route, _ *Router, _ []*Route) error {
45+
for _, m := range route.matchers {
46+
if _, ok := m.(*routeRegexp); ok {
47+
if m.Match(req, &RouteMatch{}) {
48+
methods, err := route.GetMethods()
49+
if err != nil {
50+
return err
51+
}
52+
53+
allMethods = append(allMethods, methods...)
54+
}
55+
break
56+
}
57+
}
58+
return nil
59+
})
60+
61+
if err == nil {
62+
w.Header().Set("Access-Control-Allow-Methods", strings.Join(append(allMethods, "OPTIONS"), ","))
63+
64+
if req.Method == "OPTIONS" {
65+
return
66+
}
67+
}
68+
69+
next.ServeHTTP(w, req)
70+
})
71+
}
72+
}

middleware_test.go

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package mux
33
import (
44
"bytes"
55
"net/http"
6+
"net/http/httptest"
67
"testing"
78
)
89

@@ -334,3 +335,43 @@ func TestMiddlewareMethodMismatchSubrouter(t *testing.T) {
334335
t.Fatal("Middleware was called for a method mismatch")
335336
}
336337
}
338+
339+
func TestCORSMethodMiddleware(t *testing.T) {
340+
router := NewRouter()
341+
342+
cases := []struct {
343+
path string
344+
response string
345+
method string
346+
testURL string
347+
expectedAllowedMethods string
348+
}{
349+
{"/g/{o}", "a", "POST", "/g/asdf", "POST,PUT,GET,OPTIONS"},
350+
{"/g/{o}", "b", "PUT", "/g/bla", "POST,PUT,GET,OPTIONS"},
351+
{"/g/{o}", "c", "GET", "/g/orilla", "POST,PUT,GET,OPTIONS"},
352+
{"/g", "d", "POST", "/g", "POST,OPTIONS"},
353+
}
354+
355+
for _, tt := range cases {
356+
router.HandleFunc(tt.path, stringHandler(tt.response)).Methods(tt.method)
357+
}
358+
359+
router.Use(CORSMethodMiddleware(router))
360+
361+
for _, tt := range cases {
362+
rr := httptest.NewRecorder()
363+
req := newRequest(tt.method, tt.testURL)
364+
365+
router.ServeHTTP(rr, req)
366+
367+
if rr.Body.String() != tt.response {
368+
t.Errorf("Expected body '%s', found '%s'", tt.response, rr.Body.String())
369+
}
370+
371+
allowedMethods := rr.HeaderMap.Get("Access-Control-Allow-Methods")
372+
373+
if allowedMethods != tt.expectedAllowedMethods {
374+
t.Errorf("Expected Access-Control-Allow-Methods '%s', found '%s'", tt.expectedAllowedMethods, allowedMethods)
375+
}
376+
}
377+
}

mux_test.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2315,6 +2315,14 @@ func stringMapEqual(m1, m2 map[string]string) bool {
23152315
return true
23162316
}
23172317

2318+
// stringHandler returns a handler func that writes a message 's' to the
2319+
// http.ResponseWriter.
2320+
func stringHandler(s string) http.HandlerFunc {
2321+
return func(w http.ResponseWriter, r *http.Request) {
2322+
w.Write([]byte(s))
2323+
}
2324+
}
2325+
23182326
// newRequest is a helper function to create a new request with a method and url.
23192327
// The request returned is a 'server' request as opposed to a 'client' one through
23202328
// simulated write onto the wire and read off of the wire.

0 commit comments

Comments
 (0)