Skip to content

Commit c0f1bf0

Browse files
committed
Add metadata validation to AuthenticateRequest
Call strategy.Validate() before strategy.Authenticate() to catch invalid or malicious metadata early. This prevents type confusion, injection attacks, and panics from invalid metadata in strategy implementations. Changes: - Add Validate() call in AuthenticateRequest() - Proper error wrapping with strategy name - Add test verifying validation is enforced - Update existing tests to expect Validate() calls
1 parent 362380b commit c0f1bf0

File tree

2 files changed

+41
-1
lines changed

2 files changed

+41
-1
lines changed

pkg/vmcp/auth/outgoing_authenticator.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,7 @@ func (a *DefaultOutgoingAuthenticator) GetStrategy(name string) (Strategy, error
108108
//
109109
// Returns an error if:
110110
// - The strategy is not found
111+
// - The metadata validation fails
111112
// - The strategy's Authenticate method fails
112113
func (a *DefaultOutgoingAuthenticator) AuthenticateRequest(
113114
ctx context.Context,
@@ -120,5 +121,10 @@ func (a *DefaultOutgoingAuthenticator) AuthenticateRequest(
120121
return err
121122
}
122123

124+
// Validate metadata before using it
125+
if err := strategy.Validate(metadata); err != nil {
126+
return fmt.Errorf("invalid metadata for strategy %q: %w", strategyName, err)
127+
}
128+
123129
return strategy.Authenticate(ctx, req, metadata)
124130
}

pkg/vmcp/auth/outgoing_authenticator_test.go

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,11 @@ import (
88
"sync"
99
"testing"
1010

11-
"github.com/stacklok/toolhive/pkg/vmcp/auth/mocks"
1211
"github.com/stretchr/testify/assert"
1312
"github.com/stretchr/testify/require"
1413
"go.uber.org/mock/gomock"
14+
15+
"github.com/stacklok/toolhive/pkg/vmcp/auth/mocks"
1516
)
1617

1718
type testContextKey struct{}
@@ -164,6 +165,7 @@ func TestDefaultOutgoingAuthenticator_AuthenticateRequest(t *testing.T) {
164165
auth := NewDefaultOutgoingAuthenticator()
165166
strategy := mocks.NewMockStrategy(ctrl)
166167
strategy.EXPECT().Name().Return("bearer").AnyTimes()
168+
strategy.EXPECT().Validate(gomock.Any()).Return(nil)
167169
strategy.EXPECT().Authenticate(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(
168170
func(_ context.Context, req *http.Request, _ map[string]any) error {
169171
// Add a header to verify the request was modified
@@ -201,6 +203,7 @@ func TestDefaultOutgoingAuthenticator_AuthenticateRequest(t *testing.T) {
201203
strategyErr := errors.New("authentication failed")
202204
strategy := mocks.NewMockStrategy(ctrl)
203205
strategy.EXPECT().Name().Return("bearer").AnyTimes()
206+
strategy.EXPECT().Validate(gomock.Any()).Return(nil)
204207
strategy.EXPECT().Authenticate(gomock.Any(), gomock.Any(), gomock.Any()).Return(strategyErr)
205208
require.NoError(t, auth.RegisterStrategy("bearer", strategy))
206209

@@ -223,6 +226,7 @@ func TestDefaultOutgoingAuthenticator_AuthenticateRequest(t *testing.T) {
223226

224227
strategy := mocks.NewMockStrategy(ctrl)
225228
strategy.EXPECT().Name().Return("bearer").AnyTimes()
229+
strategy.EXPECT().Validate(gomock.Any()).Return(nil)
226230
strategy.EXPECT().Authenticate(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(
227231
func(ctx context.Context, _ *http.Request, metadata map[string]any) error {
228232
receivedCtx = ctx
@@ -246,6 +250,35 @@ func TestDefaultOutgoingAuthenticator_AuthenticateRequest(t *testing.T) {
246250
assert.Equal(t, "test-value", receivedCtx.Value(testKey))
247251
assert.Equal(t, metadata, receivedMetadata)
248252
})
253+
254+
t.Run("validates metadata before authentication", func(t *testing.T) {
255+
t.Parallel()
256+
ctrl := gomock.NewController(t)
257+
t.Cleanup(ctrl.Finish)
258+
259+
auth := NewDefaultOutgoingAuthenticator()
260+
strategy := mocks.NewMockStrategy(ctrl)
261+
strategy.EXPECT().Name().Return("test-strategy").AnyTimes()
262+
263+
require.NoError(t, auth.RegisterStrategy("test-strategy", strategy))
264+
265+
req := httptest.NewRequest(http.MethodGet, "http://example.com", nil)
266+
metadata := map[string]any{"invalid": "data"}
267+
268+
// Expect Validate to be called and return error
269+
strategy.EXPECT().
270+
Validate(metadata).
271+
Return(errors.New("invalid metadata"))
272+
273+
// Authenticate should NOT be called if validation fails
274+
// (no EXPECT for Authenticate)
275+
276+
err := auth.AuthenticateRequest(context.Background(), req, "test-strategy", metadata)
277+
278+
require.Error(t, err)
279+
assert.Contains(t, err.Error(), "invalid metadata for strategy")
280+
assert.Contains(t, err.Error(), "test-strategy")
281+
})
249282
}
250283

251284
func TestDefaultOutgoingAuthenticator_ConcurrentAccess(t *testing.T) {
@@ -320,6 +353,7 @@ func TestDefaultOutgoingAuthenticator_ConcurrentAccess(t *testing.T) {
320353

321354
strategy := mocks.NewMockStrategy(ctrl)
322355
strategy.EXPECT().Name().Return("bearer").AnyTimes()
356+
strategy.EXPECT().Validate(gomock.Any()).Return(nil).AnyTimes()
323357
strategy.EXPECT().Authenticate(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(
324358
func(_ context.Context, req *http.Request, _ map[string]any) error {
325359
authMu.Lock()

0 commit comments

Comments
 (0)