@@ -8,10 +8,11 @@ import (
88 "sync"
99 "testing"
1010
11- "github.com/stacklok/toolhive/pkg/vmcp/auth/mocks"
1211 "github.com/stretchr/testify/assert"
1312 "github.com/stretchr/testify/require"
1413 "go.uber.org/mock/gomock"
14+
15+ "github.com/stacklok/toolhive/pkg/vmcp/auth/mocks"
1516)
1617
1718type testContextKey struct {}
@@ -164,6 +165,7 @@ func TestDefaultOutgoingAuthenticator_AuthenticateRequest(t *testing.T) {
164165 auth := NewDefaultOutgoingAuthenticator ()
165166 strategy := mocks .NewMockStrategy (ctrl )
166167 strategy .EXPECT ().Name ().Return ("bearer" ).AnyTimes ()
168+ strategy .EXPECT ().Validate (gomock .Any ()).Return (nil )
167169 strategy .EXPECT ().Authenticate (gomock .Any (), gomock .Any (), gomock .Any ()).DoAndReturn (
168170 func (_ context.Context , req * http.Request , _ map [string ]any ) error {
169171 // Add a header to verify the request was modified
@@ -201,6 +203,7 @@ func TestDefaultOutgoingAuthenticator_AuthenticateRequest(t *testing.T) {
201203 strategyErr := errors .New ("authentication failed" )
202204 strategy := mocks .NewMockStrategy (ctrl )
203205 strategy .EXPECT ().Name ().Return ("bearer" ).AnyTimes ()
206+ strategy .EXPECT ().Validate (gomock .Any ()).Return (nil )
204207 strategy .EXPECT ().Authenticate (gomock .Any (), gomock .Any (), gomock .Any ()).Return (strategyErr )
205208 require .NoError (t , auth .RegisterStrategy ("bearer" , strategy ))
206209
@@ -223,6 +226,7 @@ func TestDefaultOutgoingAuthenticator_AuthenticateRequest(t *testing.T) {
223226
224227 strategy := mocks .NewMockStrategy (ctrl )
225228 strategy .EXPECT ().Name ().Return ("bearer" ).AnyTimes ()
229+ strategy .EXPECT ().Validate (gomock .Any ()).Return (nil )
226230 strategy .EXPECT ().Authenticate (gomock .Any (), gomock .Any (), gomock .Any ()).DoAndReturn (
227231 func (ctx context.Context , _ * http.Request , metadata map [string ]any ) error {
228232 receivedCtx = ctx
@@ -246,6 +250,35 @@ func TestDefaultOutgoingAuthenticator_AuthenticateRequest(t *testing.T) {
246250 assert .Equal (t , "test-value" , receivedCtx .Value (testKey ))
247251 assert .Equal (t , metadata , receivedMetadata )
248252 })
253+
254+ t .Run ("validates metadata before authentication" , func (t * testing.T ) {
255+ t .Parallel ()
256+ ctrl := gomock .NewController (t )
257+ t .Cleanup (ctrl .Finish )
258+
259+ auth := NewDefaultOutgoingAuthenticator ()
260+ strategy := mocks .NewMockStrategy (ctrl )
261+ strategy .EXPECT ().Name ().Return ("test-strategy" ).AnyTimes ()
262+
263+ require .NoError (t , auth .RegisterStrategy ("test-strategy" , strategy ))
264+
265+ req := httptest .NewRequest (http .MethodGet , "http://example.com" , nil )
266+ metadata := map [string ]any {"invalid" : "data" }
267+
268+ // Expect Validate to be called and return error
269+ strategy .EXPECT ().
270+ Validate (metadata ).
271+ Return (errors .New ("invalid metadata" ))
272+
273+ // Authenticate should NOT be called if validation fails
274+ // (no EXPECT for Authenticate)
275+
276+ err := auth .AuthenticateRequest (context .Background (), req , "test-strategy" , metadata )
277+
278+ require .Error (t , err )
279+ assert .Contains (t , err .Error (), "invalid metadata for strategy" )
280+ assert .Contains (t , err .Error (), "test-strategy" )
281+ })
249282}
250283
251284func TestDefaultOutgoingAuthenticator_ConcurrentAccess (t * testing.T ) {
@@ -320,6 +353,7 @@ func TestDefaultOutgoingAuthenticator_ConcurrentAccess(t *testing.T) {
320353
321354 strategy := mocks .NewMockStrategy (ctrl )
322355 strategy .EXPECT ().Name ().Return ("bearer" ).AnyTimes ()
356+ strategy .EXPECT ().Validate (gomock .Any ()).Return (nil ).AnyTimes ()
323357 strategy .EXPECT ().Authenticate (gomock .Any (), gomock .Any (), gomock .Any ()).DoAndReturn (
324358 func (_ context.Context , req * http.Request , _ map [string ]any ) error {
325359 authMu .Lock ()
0 commit comments