diff --git a/auth/auth_test.go b/auth/auth_test.go index ef8ea7b3..a0d40226 100644 --- a/auth/auth_test.go +++ b/auth/auth_test.go @@ -8,6 +8,7 @@ import ( "context" "errors" "net/http" + "net/http/httptest" "testing" "time" ) @@ -76,3 +77,105 @@ func TestVerify(t *testing.T) { }) } } + +// Integration tests for Security Best Practices conformance. +// 2.2 Token Passthrough. +// https://modelcontextprotocol.io/specification/2025-06-18/basic/security_best_practices. +// Table-driven middleware tests covering invalid tokens, scope enforcement, and OK path. +func TestBearerMiddleware(t *testing.T) { + const resourceMetadata = "https://auth.example/meta" + verifier := func(_ context.Context, tok string, _ *http.Request) (*TokenInfo, error) { + switch tok { + case "valid": + return &TokenInfo{Expiration: time.Now().Add(time.Hour)}, nil + default: + return nil, ErrInvalidToken + } + } + + tests := []struct { + name string + token string + scopes []string + wantCode int + wantCalled bool + }{ + {name: "invalid-aud", token: "bad-aud", wantCode: http.StatusUnauthorized, wantCalled: false}, + {name: "unknown-issuer", token: "unknown-issuer", wantCode: http.StatusUnauthorized, wantCalled: false}, + {name: "missing-scope", token: "valid", scopes: []string{"s1"}, wantCode: http.StatusForbidden, wantCalled: false}, + {name: "ok", token: "valid", wantCode: http.StatusOK, wantCalled: true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + called := false + h := RequireBearerToken(verifier, &RequireBearerTokenOptions{ + ResourceMetadataURL: resourceMetadata, + Scopes: tt.scopes, + })(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + called = true + w.WriteHeader(http.StatusOK) + })) + + req := httptest.NewRequest(http.MethodGet, "http://example.com/", nil) + req.Header.Set("Authorization", "Bearer "+tt.token) + rw := httptest.NewRecorder() + + h.ServeHTTP(rw, req) + + if rw.Code != tt.wantCode { + t.Fatalf("got status %d, want %d", rw.Code, tt.wantCode) + } + if called != tt.wantCalled { + t.Fatalf("handler called=%v, want %v", called, tt.wantCalled) + } + if tt.wantCode == http.StatusUnauthorized || tt.wantCode == http.StatusForbidden { + want := "Bearer resource_metadata=" + resourceMetadata + if rw.Header().Get("WWW-Authenticate") != want { + t.Fatalf("unexpected WWW-Authenticate header: %q", rw.Header().Get("WWW-Authenticate")) + } + } + }) + } +} + +func TestHTTPMiddleware_NoTokenPassthrough(t *testing.T) { + // Downstream fake API that records the incoming Authorization header. + var gotAuth string + downstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotAuth = r.Header.Get("Authorization") + w.WriteHeader(http.StatusOK) + })) + defer downstream.Close() + + // Verifier accepts the incoming client token. + verifier := func(_ context.Context, token string, _ *http.Request) (*TokenInfo, error) { + if token != "client-token" { + return nil, ErrInvalidToken + } + return &TokenInfo{Expiration: time.Now().Add(time.Hour)}, nil + } + + wrapped := RequireBearerToken(verifier, nil)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Simulate proxy-like behavior: perform a downstream request without + // forwarding the client's Authorization header. + resp, err := http.Get(downstream.URL) + if err != nil { + t.Fatalf("downstream request failed: %v", err) + } + resp.Body.Close() + w.WriteHeader(http.StatusOK) + })) + + req := httptest.NewRequest(http.MethodGet, "http://example.com/", nil) + req.Header.Set("Authorization", "Bearer client-token") + rw := httptest.NewRecorder() + wrapped.ServeHTTP(rw, req) + + if rw.Code != http.StatusOK { + t.Fatalf("got status %d, want %d", rw.Code, http.StatusOK) + } + if gotAuth != "" { + t.Fatalf("downstream Authorization header should be empty; got %q", gotAuth) + } +}