diff --git a/proxyd/proxyd.go b/proxyd/proxyd.go index 6ce7edb0..c354bd2b 100644 --- a/proxyd/proxyd.go +++ b/proxyd/proxyd.go @@ -33,12 +33,6 @@ func Start(config *Config) (*Server, func(), error) { return nil, nil, errors.New("must define at least one RPC method mapping") } - for authKey := range config.Authentication { - if authKey == "none" { - return nil, nil, errors.New("cannot use none as an auth key") - } - } - // redis primary client var redisClient redis.UniversalClient if config.Redis.URL != "" { @@ -193,7 +187,6 @@ func Start(config *Config) (*Server, func(), error) { opts = append(opts, WithStrippedTrailingXFF()) } opts = append(opts, WithProxydIP(os.Getenv("PROXYD_IP"))) - opts = append(opts, WithSkipIsSyncingCheck(cfg.SkipIsSyncingCheck)) opts = append(opts, WithConsensusSkipPeerCountCheck(cfg.ConsensusSkipPeerCountCheck)) opts = append(opts, WithConsensusForcedCandidate(cfg.ConsensusForcedCandidate)) opts = append(opts, WithWeight(cfg.Weight)) @@ -285,16 +278,22 @@ func Start(config *Config) (*Server, func(), error) { } } - var resolvedAuth map[string]string + var resolvedAuth map[string]string = make(map[string]string) if config.Authentication != nil { - resolvedAuth = make(map[string]string) - for secret, alias := range config.Authentication { - resolvedSecret, err := ReadFromEnvOrConfig(secret) + if authURL, ok := config.Authentication["auth_url"]; ok { + resolvedAuth["auth_url"] = authURL + } + for key, alias := range config.Authentication { + if key == "auth_url" { + continue + } + + resolvedKey, err := ReadFromEnvOrConfig(key) if err != nil { return nil, nil, err } - resolvedAuth[resolvedSecret] = alias + resolvedAuth[resolvedKey] = alias } } diff --git a/proxyd/server.go b/proxyd/server.go index 60cfa3dc..f241e9d4 100644 --- a/proxyd/server.go +++ b/proxyd/server.go @@ -1,6 +1,7 @@ package proxyd import ( + "bytes" "context" "crypto/rand" "encoding/hex" @@ -50,6 +51,13 @@ const ( var emptyArrayResponse = json.RawMessage("[]") +type AuthCallbackRequest struct { + Headers map[string][]string `json:"headers"` + Path string `json:"path"` + Body string `json:"body"` + RemoteAddr string `json:"remote_addr"` +} + type Server struct { BackendGroups map[string]*BackendGroup wsBackendGroup *BackendGroup @@ -76,6 +84,7 @@ type Server struct { cache RPCCache srvMu sync.Mutex rateLimitHeader string + authClient *http.Client } type limiterFunc func(method string) bool @@ -100,6 +109,7 @@ func NewServer( maxBatchSize int, limiterFactory limiterFactoryFunc, ) (*Server, error) { + if cache == nil { cache = &NoopRPCCache{} } @@ -191,6 +201,9 @@ func NewServer( limExemptOrigins: limExemptOrigins, limExemptUserAgents: limExemptUserAgents, rateLimitHeader: rateLimitHeader, + authClient: &http.Client{ + Timeout: 5 * time.Second, + }, }, nil } @@ -614,6 +627,7 @@ func (s *Server) populateContext(w http.ResponseWriter, r *http.Request) context vars := mux.Vars(r) authorization := vars["authorization"] xff := r.Header.Get(s.rateLimitHeader) + if xff == "" { ipPort := strings.Split(r.RemoteAddr, ":") if len(ipPort) == 2 { @@ -628,22 +642,70 @@ func (s *Server) populateContext(w http.ResponseWriter, r *http.Request) context ctx = context.WithValue(ctx, ContextKeyOpTxProxyAuth, opTxProxyAuth) // nolint:staticcheck } - if len(s.authenticatedPaths) > 0 { - if authorization == "" || s.authenticatedPaths[authorization] == "" { - log.Info("blocked unauthorized request", "authorization", authorization) - httpResponseCodesTotal.WithLabelValues("401").Inc() - w.WriteHeader(401) + // Check if we have an external auth URL configured + authURL, hasExternalAuth := s.authenticatedPaths["auth_url"] + if hasExternalAuth && authURL != "" { // nolint:staticcheck + // Use external authentication service + alias, err := s.performAuthCallback(r, authURL) + if err != nil || alias == "" { // Check both error and empty alias + writeRPCError(ctx, w, nil, &RPCErr{Code: -32001, Message: "unauthorized"}) return nil } - + ctx = context.WithValue(ctx, ContextKeyAuth, alias) // nolint:staticcheck + } else { ctx = context.WithValue(ctx, ContextKeyAuth, s.authenticatedPaths[authorization]) // nolint:staticcheck } + return context.WithValue(ctx, ContextKeyReqID, randStr(10)) // nolint:staticcheck +} - return context.WithValue( - ctx, - ContextKeyReqID, // nolint:staticcheck - randStr(10), - ) +func (s *Server) performAuthCallback(r *http.Request, authURL string) (string, error) { + // Get the authorization token from the request + authorization := mux.Vars(r)["authorization"] + if authorization == "" { + return "", fmt.Errorf("missing authorization token") + } + + // Read the body first + bodyBytes, err := io.ReadAll(r.Body) + if err != nil { + log.Error("performAuthCallback failed to read request body", "err", err) + return "", fmt.Errorf("failed to read request body: %w", err) + } + defer r.Body.Close() + + // Create new body for original request + r.Body = io.NopCloser(bytes.NewBuffer(bodyBytes)) + + // Append the token to the auth URL path + authURLWithToken := fmt.Sprintf("%s/%s", + strings.TrimRight(authURL, "/"), + authorization) + + // Create new request to auth URL with same method, headers and new body copy + req, err := http.NewRequestWithContext(r.Context(), r.Method, authURLWithToken, + bytes.NewBuffer(bodyBytes)) + if err != nil { + log.Error("performAuthCallback failed to create request", "err", err) + return "", fmt.Errorf("failed to create auth request: %w", err) + } + + // Copy original headers + req.Header = r.Header.Clone() + + // Use the server's auth client + resp, err := s.authClient.Do(req) + if err != nil { + log.Error("performAuthCallback request failed", "err", err) + return "", fmt.Errorf("auth callback failed: %w", err) + } + defer resp.Body.Close() + + // Only reject if we get a 401 + if resp.StatusCode == http.StatusUnauthorized { + return "", fmt.Errorf("unauthorized") + } + + return authorization, nil } func randStr(l int) string { @@ -805,7 +867,6 @@ func GetAuthCtx(ctx context.Context) string { if !ok { return "none" } - return authUser }