Skip to content

Commit

Permalink
feat(middleware): adding middleware to protect against csrf attacks
Browse files Browse the repository at this point in the history
  • Loading branch information
bgcicca committed Oct 25, 2024
1 parent e3b5816 commit ed09969
Show file tree
Hide file tree
Showing 3 changed files with 94 additions and 2 deletions.
35 changes: 33 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ So, if you’re looking for a friendly and efficient way to build web apps in Go
- [x] manipulation of the methods (get, post, put, delete ...) 100%
- [x] plugin support 100%
- [x] more detailed error logs 100%
- [ ] middleware (~~authentication~~, ~~timeout~~, anti csrf, ~~logging~~, etc...) 75%
- [x] basical middlewares (~~authentication~~, ~~timeout~~, ~~csrf~~, ~~logging~~, etc...) 100%
- [ ] next func 0 %
- [ ] More complete documentation 0%

Expand Down Expand Up @@ -72,14 +72,45 @@ func main() {
```

### basic use middlewares
csrf middleware

```go
func isValidToken(token string) bool {
return token == "fixed-valid-csrf-token"
}

func main () {
csrfToken := middleware.GenerateCSRFToken()
middleware.SetValidCSRFToken(csrfToken)

app.Use(func(handler func(*req.Request, *req.Response)) func(*req.Request, *req.Response) {
return func(r *req.Request, res *req.Response) {
token := r.Header("X-CSRF-Token")
if token != csrfToken {
res.SetStatus(403)
res.Write([]byte("Invalid CSRF token"))
return
}

handler(r, res)
}
})

app.Get("/example", func(r *req.Request, res *req.Response) {
res.SetHeader("Content-Type", "text/plain")
res.Write([]byte("Hello from /example route with CSRF protection"))
})

}
```

```go
app.Use(middleware.LoggingMiddleware)
app.Use(middleware.TimeoutMiddleware(2 * time.Second))
app.Use(middleware.NewAuthMiddleware)
```

### Basic plugin example
### basic plugin example

```go
package main
Expand Down
24 changes: 24 additions & 0 deletions middleware/csrf.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
package middleware

import (
"crypto/rand"
"encoding/base64"
"net/http"
)

func CSRFMiddleware(next http.HandlerFunc, isValidToken func(string) bool) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
csrfToken := r.Header.Get("X-CSRF-Token")
if csrfToken == "" || !isValidToken(csrfToken) {
http.Error(w, "Invalid CSRF token", http.StatusForbidden)
return
}
next(w, r)
}
}

func GenerateCSRFToken() string {
b := make([]byte, 32)
rand.Read(b)
return base64.StdEncoding.EncodeToString(b)
}
37 changes: 37 additions & 0 deletions middleware/csrf_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
package middleware

import (
"net/http"
"net/http/httptest"
"testing"
)

func TestCSRFMiddleware(t *testing.T) {

validToken := GenerateCSRFToken()
isValidToken := func(token string) bool {
return token == validToken
}

handler := CSRFMiddleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}), isValidToken)

req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set("X-CSRF-Token", validToken)
respRec := httptest.NewRecorder()
handler.ServeHTTP(respRec, req)

if respRec.Code != http.StatusOK {
t.Errorf("Expected status 200, got %d", respRec.Code)
}

invalidReq := httptest.NewRequest(http.MethodGet, "/", nil)
invalidReq.Header.Set("X-CSRF-Token", "invalid-token")
invalidRespRec := httptest.NewRecorder()
handler.ServeHTTP(invalidRespRec, invalidReq)

if invalidRespRec.Code != http.StatusForbidden {
t.Errorf("Expected status 403, got %d", invalidRespRec.Code)
}
}

0 comments on commit ed09969

Please sign in to comment.