From 3d44ee15c46fb0058bb00956f814d794cad79057 Mon Sep 17 00:00:00 2001 From: s0up4200 Date: Tue, 26 Nov 2024 02:07:53 +0100 Subject: [PATCH 1/5] fix(auth): exempt oauth callback from csrf protection --- internal/api/middleware/csrf.go | 2 +- internal/api/server.go | 3 +++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/internal/api/middleware/csrf.go b/internal/api/middleware/csrf.go index 85d4a32..0c756d3 100644 --- a/internal/api/middleware/csrf.go +++ b/internal/api/middleware/csrf.go @@ -52,7 +52,7 @@ func DefaultCSRFConfig() *CSRFConfig { HttpOnly: true, MaxAge: int(csrfTokenDuration.Seconds()), ExemptMethods: []string{"GET", "HEAD", "OPTIONS"}, - ExemptPaths: []string{}, + ExemptPaths: []string{"/api/auth/callback"}, } } diff --git a/internal/api/server.go b/internal/api/server.go index ee61675..40fdb49 100644 --- a/internal/api/server.go +++ b/internal/api/server.go @@ -86,6 +86,9 @@ func (s *Server) Handler() http.Handler { } r.Use(middleware.SetupCORS()) + // TODO: Enable secure headers? + //r.Use(middleware.Secure(nil)) + r.Use(middleware.CSRF(nil)) // Added CSRF middleware with default config that includes OAuth callback exemption // Create rate limiters with different configurations apiRateLimiter := middleware.NewRateLimiter(s.cache, time.Minute, 60, "api:") // 60 requests per minute for API From a89e8242d2125a5acdcd71d7efba7a45e4af7ada Mon Sep 17 00:00:00 2001 From: s0up4200 Date: Tue, 26 Nov 2024 08:38:31 +0100 Subject: [PATCH 2/5] fix(cors): allow credentials for oauth token exchange --- internal/api/middleware/cors.go | 14 ++++++++++++-- internal/api/server.go | 4 ++-- 2 files changed, 14 insertions(+), 4 deletions(-) diff --git a/internal/api/middleware/cors.go b/internal/api/middleware/cors.go index d4efc76..b150ebc 100644 --- a/internal/api/middleware/cors.go +++ b/internal/api/middleware/cors.go @@ -28,9 +28,19 @@ func SetupCORS() gin.HandlerFunc { "Content-Type", "Accept", "X-Requested-With", + "X-CSRF-Token", + "Access-Control-Allow-Origin", + "Access-Control-Allow-Headers", + "Access-Control-Allow-Methods", + "Access-Control-Allow-Credentials", }, - ExposeHeaders: []string{"Content-Length", "Content-Type"}, - MaxAge: 12 * time.Hour, + AllowCredentials: true, // Important for OAuth flows + ExposeHeaders: []string{ + "Content-Length", + "Content-Type", + "X-CSRF-Token", + }, + MaxAge: 12 * time.Hour, } return cors.New(config) diff --git a/internal/api/server.go b/internal/api/server.go index 40fdb49..5714f16 100644 --- a/internal/api/server.go +++ b/internal/api/server.go @@ -86,9 +86,9 @@ func (s *Server) Handler() http.Handler { } r.Use(middleware.SetupCORS()) - // TODO: Enable secure headers? + // TODO: Enable secure headers and CSRF? //r.Use(middleware.Secure(nil)) - r.Use(middleware.CSRF(nil)) // Added CSRF middleware with default config that includes OAuth callback exemption + //r.Use(middleware.CSRF(nil)) // Create rate limiters with different configurations apiRateLimiter := middleware.NewRateLimiter(s.cache, time.Minute, 60, "api:") // 60 requests per minute for API From 105aab1a5d50031892c1faff13c56379bddf6563 Mon Sep 17 00:00:00 2001 From: s0up4200 Date: Wed, 27 Nov 2024 10:44:23 +0100 Subject: [PATCH 3/5] fix(auth): handle missing token expiry from OIDC provider Some OIDC providers (like PocketID) don't include token expiry times in their responses. This causes authentication to fail as the session immediately expires. - Add default 24h expiry when token.Expiry is zero --- internal/api/handlers/auth.go | 53 +++++++++++++++++++++++++++++------ 1 file changed, 44 insertions(+), 9 deletions(-) diff --git a/internal/api/handlers/auth.go b/internal/api/handlers/auth.go index 6af71e2..376494b 100644 --- a/internal/api/handlers/auth.go +++ b/internal/api/handlers/auth.go @@ -220,6 +220,11 @@ func (h *AuthHandler) Callback(c *gin.Context) { code := c.Query("code") state := c.Query("state") + log.Debug(). + Bool("has_code", code != ""). + Str("state", state). + Msg("received callback") + if code == "" { log.Error().Msg("no code in callback") c.Redirect(http.StatusTemporaryRedirect, "/login?error=no_code") @@ -235,7 +240,9 @@ func (h *AuthHandler) Callback(c *gin.Context) { return } if err == cache.ErrKeyNotFound { - log.Debug().Msg("state not found or expired") + log.Debug(). + Str("state_key", stateKey). + Msg("state not found or expired") } else { log.Error().Err(err).Msg("failed to get state from cache") } @@ -256,7 +263,8 @@ func (h *AuthHandler) Callback(c *gin.Context) { } } - // Exchange code for token using context + log.Debug().Msg("exchanging code for token") + token, err := h.oauth2Config.Exchange(ctx, code) if err != nil { if ctx.Err() != nil { @@ -264,29 +272,47 @@ func (h *AuthHandler) Callback(c *gin.Context) { c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/login?error=timeout", frontendUrl)) return } - log.Error().Err(err).Msg("code exchange failed") + log.Error(). + Err(err). + Str("token_url", h.oauth2Config.Endpoint.TokenURL). + Msg("code exchange failed") c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/login?error=exchange_failed", frontendUrl)) return } + log.Debug(). + Bool("token_received", token.AccessToken != ""). + Msg("token exchange completed") + rawIDToken, ok := token.Extra("id_token").(string) if !ok { - log.Error().Msg("no id_token in token response") + log.Error(). + Interface("extras", token.Extra("")). + Msg("no id_token in token response") c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/login?error=no_id_token", frontendUrl)) return } + // Set a default expiry if none provided + expiryTime := token.Expiry + if expiryTime.IsZero() { + expiryTime = time.Now().Add(24 * time.Hour) + log.Debug(). + Time("assigned_expiry", expiryTime). + Msg("token had no expiry, assigned default 24 hours") + } + sessionData := types.SessionData{ AccessToken: token.AccessToken, TokenType: token.TokenType, RefreshToken: token.RefreshToken, IDToken: rawIDToken, - ExpiresAt: token.Expiry, + ExpiresAt: expiryTime, AuthType: "oidc", } sessionKey := fmt.Sprintf("oidc:session:%s", token.AccessToken) - if err := h.cache.Set(ctx, sessionKey, sessionData, time.Until(token.Expiry)); err != nil { + if err := h.cache.Set(ctx, sessionKey, sessionData, time.Until(expiryTime)); err != nil { if ctx.Err() != nil { log.Error().Err(ctx.Err()).Msg("Context canceled while storing session") c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/login?error=timeout", frontendUrl)) @@ -297,23 +323,32 @@ func (h *AuthHandler) Callback(c *gin.Context) { return } + log.Debug(). + Str("session_key_prefix", "oidc:session"). + Time("expires_at", expiryTime). + Msg("session stored in cache") + var isSecure = c.GetHeader("X-Forwarded-Proto") == "https" c.SetCookie( "session", token.AccessToken, - int(time.Until(token.Expiry).Seconds()), + int(time.Until(expiryTime).Seconds()), "/", "", isSecure, true, ) - c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s?access_token=%s&id_token=%s", + redirectURL := fmt.Sprintf("%s?access_token=%s&id_token=%s", frontendUrl, token.AccessToken, rawIDToken, - )) + ) + + //log.Debug().Msg("redirecting to frontend") + + c.Redirect(http.StatusTemporaryRedirect, redirectURL) } func (h *AuthHandler) Logout(c *gin.Context) { From 87856a1ac65bc406e09409cd09634a6843bf8aea Mon Sep 17 00:00:00 2001 From: s0up4200 Date: Mon, 2 Dec 2024 16:56:51 +0100 Subject: [PATCH 4/5] fix(auth): use random session ID instead of access token in cookies --- internal/api/handlers/auth.go | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/internal/api/handlers/auth.go b/internal/api/handlers/auth.go index 376494b..be3083e 100644 --- a/internal/api/handlers/auth.go +++ b/internal/api/handlers/auth.go @@ -311,7 +311,15 @@ func (h *AuthHandler) Callback(c *gin.Context) { AuthType: "oidc", } - sessionKey := fmt.Sprintf("oidc:session:%s", token.AccessToken) + // Generate a random session ID instead of using the access token + sessionID, err := generateSecureRandomString(32) + if err != nil { + log.Error().Err(err).Msg("failed to generate session ID") + c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/login?error=session_failed", frontendUrl)) + return + } + + sessionKey := fmt.Sprintf("oidc:session:%s", sessionID) if err := h.cache.Set(ctx, sessionKey, sessionData, time.Until(expiryTime)); err != nil { if ctx.Err() != nil { log.Error().Err(ctx.Err()).Msg("Context canceled while storing session") @@ -323,16 +331,12 @@ func (h *AuthHandler) Callback(c *gin.Context) { return } - log.Debug(). - Str("session_key_prefix", "oidc:session"). - Time("expires_at", expiryTime). - Msg("session stored in cache") - var isSecure = c.GetHeader("X-Forwarded-Proto") == "https" + // Set the random session ID in the cookie instead of the access token c.SetCookie( "session", - token.AccessToken, + sessionID, int(time.Until(expiryTime).Seconds()), "/", "", From 4d1e6e048597184cc49ab924ce0d1331c35ca674 Mon Sep 17 00:00:00 2001 From: s0up4200 Date: Mon, 2 Dec 2024 17:05:51 +0100 Subject: [PATCH 5/5] fix(auth): apply rate limiting to all OIDC endpoints including callback --- internal/api/server.go | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/internal/api/server.go b/internal/api/server.go index 5714f16..a07b9c0 100644 --- a/internal/api/server.go +++ b/internal/api/server.go @@ -93,7 +93,7 @@ func (s *Server) Handler() http.Handler { // Create rate limiters with different configurations apiRateLimiter := middleware.NewRateLimiter(s.cache, time.Minute, 60, "api:") // 60 requests per minute for API healthRateLimiter := middleware.NewRateLimiter(s.cache, time.Minute, 30, "health:") // 30 health checks per minute - authRateLimiter := middleware.NewRateLimiter(s.cache, time.Minute, 30, "auth:") // 30 auth requests per minute + authRateLimiter := middleware.NewRateLimiter(s.cache, time.Second, 1, "auth:") // 1 auth request per second // Special rate limiter for Tailscale services tailscaleRateLimiter := middleware.NewRateLimiter(s.cache, 2*time.Minute, 20, "tailscale:") // 20 requests per 2 minutes @@ -148,13 +148,14 @@ func (s *Server) Handler() http.Handler { // OIDC auth endpoints (only if OIDC is configured) if oidcAuthHandler != nil { - public.GET("/api/auth/callback", oidcAuthHandler.Callback) oidcAuth := public.Group("/api/auth/oidc") oidcAuth.Use(authRateLimiter.RateLimit()) { oidcAuth.GET("/login", oidcAuthHandler.Login) oidcAuth.POST("/logout", oidcAuthHandler.Logout) } + + public.GET("/api/auth/callback", authRateLimiter.RateLimit(), oidcAuthHandler.Callback) } // Built-in auth endpoints