Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
128 changes: 128 additions & 0 deletions internal/ghmcp/server_test.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
package ghmcp

import (
"context"
"net/http"
"net/http/httptest"
"testing"

"github.com/github/github-mcp-server/pkg/github"
"github.com/github/github-mcp-server/pkg/translations"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
Expand Down Expand Up @@ -110,3 +114,127 @@ func TestResolveEnabledToolsets(t *testing.T) {
})
}
}

// TestBearerAuthTransport_AddsGraphQLFeaturesHeader verifies that the bearerAuthTransport
// properly reads GraphQL features from context and adds them as a header.
func TestBearerAuthTransport_AddsGraphQLFeaturesHeader(t *testing.T) {
t.Parallel()

tests := []struct {
name string
features []string
expectHeader bool
expectedHeaderValue string
}{
{
name: "single feature",
features: []string{"issues_copilot_assignment_api_support"},
expectHeader: true,
expectedHeaderValue: "issues_copilot_assignment_api_support",
},
{
name: "multiple features",
features: []string{"feature1", "feature2", "feature3"},
expectHeader: true,
expectedHeaderValue: "feature1, feature2, feature3",
},
{
name: "no features",
features: []string{},
expectHeader: false,
},
{
name: "nil features",
features: nil,
expectHeader: false,
},
}

for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
// Create a test server that records the request
var capturedRequest *http.Request
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
capturedRequest = r
w.WriteHeader(http.StatusOK)
}))
defer testServer.Close()

// Create the transport chain
transport := &bearerAuthTransport{
transport: http.DefaultTransport,
token: "test-token",
}

// Create an HTTP client with the transport
client := &http.Client{Transport: transport}

// Create a context with GraphQL features
ctx := context.Background()
if tc.features != nil {
ctx = github.WithGraphQLFeatures(ctx, tc.features...)
}

// Make a request with the context
req, err := http.NewRequestWithContext(ctx, "POST", testServer.URL, nil)
require.NoError(t, err)

resp, err := client.Do(req)
require.NoError(t, err)
defer resp.Body.Close()

// Verify the Authorization header is set
assert.Equal(t, "Bearer test-token", capturedRequest.Header.Get("Authorization"))

// Verify the GraphQL-Features header
if tc.expectHeader {
assert.Equal(t, tc.expectedHeaderValue, capturedRequest.Header.Get("GraphQL-Features"))
} else {
assert.Empty(t, capturedRequest.Header.Get("GraphQL-Features"))
}
})
}
}

// TestUserAgentTransport_PreservesGraphQLFeatures verifies that the userAgentTransport
// doesn't interfere with GraphQL features set by bearerAuthTransport.
func TestUserAgentTransport_PreservesGraphQLFeatures(t *testing.T) {
t.Parallel()

// Create a test server that records the request
var capturedRequest *http.Request
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
capturedRequest = r
w.WriteHeader(http.StatusOK)
}))
defer testServer.Close()

// Create the transport chain (same as in production)
// userAgentTransport -> bearerAuthTransport -> http.DefaultTransport
transport := &userAgentTransport{
transport: &bearerAuthTransport{
transport: http.DefaultTransport,
token: "test-token",
},
agent: "test-agent/1.0.0",
}

// Create an HTTP client with the transport chain
client := &http.Client{Transport: transport}

// Create a context with GraphQL features
ctx := github.WithGraphQLFeatures(context.Background(), "issues_copilot_assignment_api_support")

// Make a request with the context
req, err := http.NewRequestWithContext(ctx, "POST", testServer.URL, nil)
require.NoError(t, err)

resp, err := client.Do(req)
require.NoError(t, err)
defer resp.Body.Close()

// Verify all headers are set correctly
assert.Equal(t, "test-agent/1.0.0", capturedRequest.Header.Get("User-Agent"))
assert.Equal(t, "Bearer test-token", capturedRequest.Header.Get("Authorization"))
assert.Equal(t, "issues_copilot_assignment_api_support", capturedRequest.Header.Get("GraphQL-Features"))
}
6 changes: 3 additions & 3 deletions pkg/github/issues.go
Original file line number Diff line number Diff line change
Expand Up @@ -1789,7 +1789,7 @@ func AssignCopilotToIssue(t translations.TranslationHelperFunc) inventory.Server

// Add the GraphQL-Features header for the agent assignment API
// The header will be read by the HTTP transport if it's configured to do so
ctxWithFeatures := withGraphQLFeatures(ctx, "issues_copilot_assignment_api_support")
ctxWithFeatures := WithGraphQLFeatures(ctx, "issues_copilot_assignment_api_support")

if err := client.Mutate(
ctxWithFeatures,
Expand Down Expand Up @@ -1917,8 +1917,8 @@ func AssignCodingAgentPrompt(t translations.TranslationHelperFunc) inventory.Ser
// graphQLFeaturesKey is a context key for GraphQL feature flags
type graphQLFeaturesKey struct{}

// withGraphQLFeatures adds GraphQL feature flags to the context
func withGraphQLFeatures(ctx context.Context, features ...string) context.Context {
// WithGraphQLFeatures adds GraphQL feature flags to the context
func WithGraphQLFeatures(ctx context.Context, features ...string) context.Context {
return context.WithValue(ctx, graphQLFeaturesKey{}, features)
}

Expand Down