Skip to content

Commit 59e8bdf

Browse files
committed
add middleware wrapping tollboth/v8
1 parent bb0c485 commit 59e8bdf

File tree

4 files changed

+194
-0
lines changed

4 files changed

+194
-0
lines changed

go.mod

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,14 @@ module github.com/go-pkgz/rest
33
go 1.21
44

55
require (
6+
github.com/didip/tollbooth/v8 v8.0.0
67
github.com/stretchr/testify v1.10.0
78
golang.org/x/crypto v0.31.0
89
)
910

1011
require (
1112
github.com/davecgh/go-spew v1.1.1 // indirect
13+
github.com/go-pkgz/expirable-cache/v3 v3.0.0 // indirect
1214
github.com/pmezard/go-difflib v1.0.0 // indirect
1315
golang.org/x/sys v0.28.0 // indirect
1416
gopkg.in/yaml.v3 v3.0.1 // indirect

go.sum

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,11 @@
11
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
22
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
3+
github.com/didip/tollbooth/v8 v8.0.0 h1:AMr7m7TlXLpxAgKfshyoBMLuq6uv6EU3CZ9HBAGQIvQ=
4+
github.com/didip/tollbooth/v8 v8.0.0/go.mod h1:oEd9l+ep373d7DmvKLc0a5gasPOev2mTewi6KPQBGJ4=
5+
github.com/go-pkgz/expirable-cache/v3 v3.0.0 h1:u3/gcu3sabLYiTCevoRKv+WzjIn5oo7P8XtiXBeRDLw=
6+
github.com/go-pkgz/expirable-cache/v3 v3.0.0/go.mod h1:2OQiDyEGQalYecLWmXprm3maPXeVb5/6/X7yRPYTzec=
7+
github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs4luLUK2k=
8+
github.com/hashicorp/golang-lru/v2 v2.0.7/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM=
39
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
410
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
511
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=

tollbooth.go

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
package rest
2+
3+
import (
4+
"net/http"
5+
6+
"github.com/didip/tollbooth/v8"
7+
"github.com/didip/tollbooth/v8/limiter"
8+
)
9+
10+
// based on https://github.com/didip/tollbooth_chi/blob/master/tollbooth_chi.go
11+
// added support of v8 and simplified, removed chi dependency
12+
// one notable difference is that this middleware sets IP lookup to RemoteAddr by default,
13+
// however, it can be overridden by setting it in the limiter
14+
15+
// LimitHandler wraps http.Handler with tollbooth limiter
16+
func LimitHandler(lmt *limiter.Limiter) func(http.Handler) http.Handler {
17+
// // set IP lookup only if not set
18+
if lmt.GetIPLookup().Name == "" {
19+
lmt.SetIPLookup(limiter.IPLookup{Name: "RemoteAddr"})
20+
}
21+
22+
return func(next http.Handler) http.Handler {
23+
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
24+
select {
25+
case <-r.Context().Done():
26+
http.Error(w, "Context was canceled", http.StatusServiceUnavailable)
27+
return
28+
default:
29+
if httpError := tollbooth.LimitByRequest(lmt, w, r); httpError != nil {
30+
lmt.ExecOnLimitReached(w, r)
31+
w.Header().Add("Content-Type", lmt.GetMessageContentType())
32+
w.WriteHeader(httpError.StatusCode)
33+
w.Write([]byte(httpError.Message))
34+
return
35+
}
36+
next.ServeHTTP(w, r)
37+
}
38+
})
39+
}
40+
}

tollbooth_test.go

Lines changed: 146 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,146 @@
1+
package rest
2+
3+
import (
4+
"context"
5+
"net/http"
6+
"net/http/httptest"
7+
"testing"
8+
9+
"github.com/didip/tollbooth/v8"
10+
"github.com/didip/tollbooth/v8/limiter"
11+
"github.com/stretchr/testify/assert"
12+
)
13+
14+
func TestLimitHandler(t *testing.T) {
15+
16+
t.Run("basic request", func(t *testing.T) {
17+
lmt := tollbooth.NewLimiter(1, nil)
18+
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
19+
w.WriteHeader(http.StatusOK)
20+
})
21+
wrapped := LimitHandler(lmt)(handler)
22+
w := httptest.NewRecorder()
23+
r := httptest.NewRequest(http.MethodGet, "/test", nil)
24+
r.RemoteAddr = "127.0.0.1:12345"
25+
wrapped.ServeHTTP(w, r)
26+
assert.Equal(t, http.StatusOK, w.Code)
27+
})
28+
29+
t.Run("rate limit exceeded", func(t *testing.T) {
30+
lmt := tollbooth.NewLimiter(0.1, nil) // only allow one request per 10 seconds
31+
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
32+
w.WriteHeader(http.StatusOK)
33+
})
34+
wrapped := LimitHandler(lmt)(handler)
35+
36+
// first request
37+
w1 := httptest.NewRecorder()
38+
r1 := httptest.NewRequest(http.MethodGet, "/test", nil)
39+
r1.RemoteAddr = "127.0.0.1:12345"
40+
wrapped.ServeHTTP(w1, r1)
41+
42+
// immediate second request should fail
43+
w2 := httptest.NewRecorder()
44+
r2 := httptest.NewRequest(http.MethodGet, "/test", nil)
45+
r2.RemoteAddr = "127.0.0.1:12345"
46+
wrapped.ServeHTTP(w2, r2)
47+
48+
assert.Equal(t, http.StatusTooManyRequests, w2.Code)
49+
assert.Contains(t, w2.Body.String(), "maximum request limit")
50+
})
51+
52+
t.Run("context cancelled", func(t *testing.T) {
53+
lmt := tollbooth.NewLimiter(1, nil)
54+
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
55+
w.WriteHeader(http.StatusOK)
56+
})
57+
wrapped := LimitHandler(lmt)(handler)
58+
w := httptest.NewRecorder()
59+
r := httptest.NewRequest(http.MethodGet, "/test", nil)
60+
ctx, cancel := context.WithCancel(r.Context())
61+
cancel()
62+
r = r.WithContext(ctx)
63+
wrapped.ServeHTTP(w, r)
64+
assert.Equal(t, http.StatusServiceUnavailable, w.Code)
65+
assert.Contains(t, w.Body.String(), "Context was canceled")
66+
})
67+
68+
t.Run("custom error handler", func(t *testing.T) {
69+
lmt := tollbooth.NewLimiter(0.1, nil) // only allow one request per 10 seconds
70+
customMsg := "custom limit reached"
71+
lmt.SetMessage(customMsg)
72+
73+
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
74+
w.WriteHeader(http.StatusOK)
75+
})
76+
wrapped := LimitHandler(lmt)(handler)
77+
78+
// first request
79+
w1 := httptest.NewRecorder()
80+
r1 := httptest.NewRequest(http.MethodGet, "/test", nil)
81+
r1.RemoteAddr = "127.0.0.1:12345"
82+
wrapped.ServeHTTP(w1, r1)
83+
84+
// immediate second request should fail
85+
w2 := httptest.NewRecorder()
86+
r2 := httptest.NewRequest(http.MethodGet, "/test", nil)
87+
r2.RemoteAddr = "127.0.0.1:12345"
88+
wrapped.ServeHTTP(w2, r2)
89+
90+
assert.Equal(t, http.StatusTooManyRequests, w2.Code)
91+
assert.Contains(t, w2.Body.String(), customMsg)
92+
})
93+
94+
t.Run("default IP lookup", func(t *testing.T) {
95+
lmt := tollbooth.NewLimiter(0.1, nil)
96+
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
97+
w.WriteHeader(http.StatusOK)
98+
})
99+
wrapped := LimitHandler(lmt)(handler)
100+
101+
// first request
102+
w1 := httptest.NewRecorder()
103+
r1 := httptest.NewRequest(http.MethodGet, "/test", nil)
104+
r1.RemoteAddr = "127.0.0.1:12345"
105+
wrapped.ServeHTTP(w1, r1)
106+
107+
// second request should fail as default RemoteAddr will be used
108+
w2 := httptest.NewRecorder()
109+
r2 := httptest.NewRequest(http.MethodGet, "/test", nil)
110+
r2.RemoteAddr = "127.0.0.1:12345"
111+
wrapped.ServeHTTP(w2, r2)
112+
113+
assert.Equal(t, http.StatusTooManyRequests, w2.Code)
114+
})
115+
116+
t.Run("custom IP lookup", func(t *testing.T) {
117+
lmt := tollbooth.NewLimiter(0.1, nil)
118+
lmt.SetIPLookup(limiter.IPLookup{Name: "X-Real-IP"})
119+
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
120+
w.WriteHeader(http.StatusOK)
121+
})
122+
wrapped := LimitHandler(lmt)(handler)
123+
124+
// first request
125+
w1 := httptest.NewRecorder()
126+
r1 := httptest.NewRequest(http.MethodGet, "/test", nil)
127+
r1.Header.Set("X-Real-IP", "5.5.5.5")
128+
wrapped.ServeHTTP(w1, r1)
129+
130+
// second request with same X-Real-IP should fail
131+
w2 := httptest.NewRecorder()
132+
r2 := httptest.NewRequest(http.MethodGet, "/test", nil)
133+
r2.Header.Set("X-Real-IP", "5.5.5.5")
134+
wrapped.ServeHTTP(w2, r2)
135+
136+
assert.Equal(t, http.StatusTooManyRequests, w2.Code)
137+
138+
// request with different X-Real-IP should pass
139+
w3 := httptest.NewRecorder()
140+
r3 := httptest.NewRequest(http.MethodGet, "/test", nil)
141+
r3.Header.Set("X-Real-IP", "6.6.6.6")
142+
wrapped.ServeHTTP(w3, r3)
143+
144+
assert.Equal(t, http.StatusOK, w3.Code)
145+
})
146+
}

0 commit comments

Comments
 (0)