Skip to content

Commit 3e51d71

Browse files
committed
Add context helpers for Identity propagation in vmcp auth
Adds several context-related helpers that will be used to propagate Identity through vMCP. Related: #2377
1 parent fa30c90 commit 3e51d71

File tree

2 files changed

+144
-0
lines changed

2 files changed

+144
-0
lines changed

pkg/vmcp/auth/context.go

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
package auth
2+
3+
import "context"
4+
5+
// IdentityContextKey is the key used to store Identity in the request context.
6+
// This provides type-safe context storage and retrieval for authenticated identities.
7+
//
8+
// Using an empty struct as the key prevents collisions with other context keys,
9+
// as each empty struct type is distinct even if they have the same name in different packages.
10+
type IdentityContextKey struct{}
11+
12+
// WithIdentity stores an Identity in the context.
13+
// If identity is nil, the original context is returned unchanged.
14+
//
15+
// This function is typically called by authentication middleware after successful
16+
// authentication to make the identity available to downstream handlers.
17+
//
18+
// Example:
19+
//
20+
// identity := &Identity{Subject: "user123", Name: "Alice"}
21+
// ctx = WithIdentity(ctx, identity)
22+
func WithIdentity(ctx context.Context, identity *Identity) context.Context {
23+
if identity == nil {
24+
return ctx
25+
}
26+
return context.WithValue(ctx, IdentityContextKey{}, identity)
27+
}
28+
29+
// IdentityFromContext retrieves an Identity from the context.
30+
// Returns the identity and true if present, nil and false otherwise.
31+
//
32+
// This function is typically called by authorization middleware or handlers that need
33+
// to check who the authenticated user is.
34+
//
35+
// Example:
36+
//
37+
// identity, ok := IdentityFromContext(ctx)
38+
// if !ok {
39+
// return errors.New("no authenticated identity")
40+
// }
41+
// log.Printf("Request from user: %s", identity.Subject)
42+
func IdentityFromContext(ctx context.Context) (*Identity, bool) {
43+
identity, ok := ctx.Value(IdentityContextKey{}).(*Identity)
44+
return identity, ok
45+
}

pkg/vmcp/auth/context_test.go

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
package auth
2+
3+
import (
4+
"context"
5+
"testing"
6+
)
7+
8+
// TestIdentityContext_StoreAndRetrieve verifies basic context storage and retrieval functionality.
9+
func TestIdentityContext_StoreAndRetrieve(t *testing.T) {
10+
t.Parallel()
11+
ctx := context.Background()
12+
13+
// Create a test identity
14+
identity := &Identity{
15+
Subject: "user123",
16+
Name: "Alice Smith",
17+
18+
Groups: []string{"admins", "developers"},
19+
Claims: map[string]any{
20+
"org_id": "org456",
21+
},
22+
Token: "test-token",
23+
TokenType: "Bearer",
24+
Metadata: map[string]string{
25+
"source": "test",
26+
},
27+
}
28+
29+
// Store identity in context
30+
ctx = WithIdentity(ctx, identity)
31+
32+
// Retrieve identity from context
33+
retrieved, ok := IdentityFromContext(ctx)
34+
if !ok {
35+
t.Fatal("expected identity to be present in context")
36+
}
37+
38+
// Verify all fields match
39+
if retrieved.Subject != identity.Subject {
40+
t.Errorf("expected Subject %q, got %q", identity.Subject, retrieved.Subject)
41+
}
42+
if retrieved.Name != identity.Name {
43+
t.Errorf("expected Name %q, got %q", identity.Name, retrieved.Name)
44+
}
45+
if retrieved.Email != identity.Email {
46+
t.Errorf("expected Email %q, got %q", identity.Email, retrieved.Email)
47+
}
48+
if len(retrieved.Groups) != len(identity.Groups) {
49+
t.Errorf("expected %d groups, got %d", len(identity.Groups), len(retrieved.Groups))
50+
}
51+
if retrieved.Token != identity.Token {
52+
t.Errorf("expected Token %q, got %q", identity.Token, retrieved.Token)
53+
}
54+
if retrieved.TokenType != identity.TokenType {
55+
t.Errorf("expected TokenType %q, got %q", identity.TokenType, retrieved.TokenType)
56+
}
57+
}
58+
59+
// TestIdentityContext_NilIdentity verifies that WithIdentity returns the original context
60+
// when called with a nil identity.
61+
func TestIdentityContext_NilIdentity(t *testing.T) {
62+
t.Parallel()
63+
ctx := context.Background()
64+
65+
// Store nil identity - should return original context unchanged
66+
resultCtx := WithIdentity(ctx, nil)
67+
68+
// Verify we got the same context back
69+
if resultCtx != ctx {
70+
t.Error("expected WithIdentity to return original context when identity is nil")
71+
}
72+
73+
// Verify no identity is stored
74+
identity, ok := IdentityFromContext(resultCtx)
75+
if ok {
76+
t.Errorf("expected no identity in context, got %v", identity)
77+
}
78+
if identity != nil {
79+
t.Errorf("expected nil identity, got %v", identity)
80+
}
81+
}
82+
83+
// TestIdentityContext_NotPresent verifies that IdentityFromContext returns nil and false
84+
// when no identity has been stored in the context.
85+
func TestIdentityContext_NotPresent(t *testing.T) {
86+
t.Parallel()
87+
ctx := context.Background()
88+
89+
// Try to retrieve identity from empty context
90+
identity, ok := IdentityFromContext(ctx)
91+
92+
// Verify nothing was found
93+
if ok {
94+
t.Error("expected ok to be false when no identity is present")
95+
}
96+
if identity != nil {
97+
t.Errorf("expected nil identity, got %v", identity)
98+
}
99+
}

0 commit comments

Comments
 (0)