diff --git a/internal/ghmcp/server.go b/internal/ghmcp/server.go index 250f6b4cc..4272fe4e8 100644 --- a/internal/ghmcp/server.go +++ b/internal/ghmcp/server.go @@ -94,10 +94,13 @@ func createGitHubClients(cfg MCPServerConfig, apiHost apiHost) (*githubClients, // Construct GraphQL client // We use NewEnterpriseClient unconditionally since we already parsed the API host + // Layer transports: DefaultTransport -> bearerAuthTransport -> GraphQLFeaturesTransport gqlHTTPClient := &http.Client{ - Transport: &bearerAuthTransport{ - transport: http.DefaultTransport, - token: cfg.Token, + Transport: &github.GraphQLFeaturesTransport{ + Transport: &bearerAuthTransport{ + transport: http.DefaultTransport, + token: cfg.Token, + }, }, } gqlClient := githubv4.NewEnterpriseClient(apiHost.graphqlURL.String(), gqlHTTPClient) @@ -622,12 +625,6 @@ type bearerAuthTransport struct { func (t *bearerAuthTransport) RoundTrip(req *http.Request) (*http.Response, error) { req = req.Clone(req.Context()) req.Header.Set("Authorization", "Bearer "+t.token) - - // Check for GraphQL-Features in context and add header if present - if features := github.GetGraphQLFeatures(req.Context()); len(features) > 0 { - req.Header.Set("GraphQL-Features", strings.Join(features, ", ")) - } - return t.transport.RoundTrip(req) } diff --git a/pkg/github/issues.go b/pkg/github/issues.go index 63174c9e9..912e62194 100644 --- a/pkg/github/issues.go +++ b/pkg/github/issues.go @@ -1789,7 +1789,7 @@ func AssignCopilotToIssue(t translations.TranslationHelperFunc) inventory.Server // Add the GraphQL-Features header for the agent assignment API // The header will be read by the HTTP transport if it's configured to do so - ctxWithFeatures := withGraphQLFeatures(ctx, "issues_copilot_assignment_api_support") + ctxWithFeatures := WithGraphQLFeatures(ctx, "issues_copilot_assignment_api_support") if err := client.Mutate( ctxWithFeatures, @@ -1913,19 +1913,3 @@ func AssignCodingAgentPrompt(t translations.TranslationHelperFunc) inventory.Ser }, ) } - -// graphQLFeaturesKey is a context key for GraphQL feature flags -type graphQLFeaturesKey struct{} - -// withGraphQLFeatures adds GraphQL feature flags to the context -func withGraphQLFeatures(ctx context.Context, features ...string) context.Context { - return context.WithValue(ctx, graphQLFeaturesKey{}, features) -} - -// GetGraphQLFeatures retrieves GraphQL feature flags from the context -func GetGraphQLFeatures(ctx context.Context) []string { - if features, ok := ctx.Value(graphQLFeaturesKey{}).([]string); ok { - return features - } - return nil -} diff --git a/pkg/github/transport.go b/pkg/github/transport.go new file mode 100644 index 000000000..5c14ba014 --- /dev/null +++ b/pkg/github/transport.go @@ -0,0 +1,86 @@ +package github + +import ( + "context" + "net/http" + "strings" +) + +// graphQLFeaturesKey is a context key for GraphQL feature flags. +// These flags enable preview or experimental GitHub API features that are not yet GA. +type graphQLFeaturesKey struct{} + +// WithGraphQLFeatures adds GraphQL feature flags to the context. +// The flags are read by GraphQLFeaturesTransport and sent as the GraphQL-Features header. +// This is used by tool handlers that require experimental GitHub API features. +// Remote servers can also use this function in tests to simulate feature flag contexts. +func WithGraphQLFeatures(ctx context.Context, features ...string) context.Context { + return context.WithValue(ctx, graphQLFeaturesKey{}, features) +} + +// GetGraphQLFeatures retrieves GraphQL feature flags from the context. +// This function is exported to allow custom HTTP transports (e.g., in remote servers) +// to read feature flags and add them as the "GraphQL-Features" header. +// +// For most use cases, use GraphQLFeaturesTransport instead of calling this directly. +func GetGraphQLFeatures(ctx context.Context) []string { + if features, ok := ctx.Value(graphQLFeaturesKey{}).([]string); ok { + return features + } + return nil +} + +// GraphQLFeaturesTransport is an http.RoundTripper that adds GraphQL-Features +// header based on context values set by WithGraphQLFeatures. +// +// This transport should be used in the HTTP client chain for githubv4.Client +// to ensure GraphQL feature flags are properly sent to the GitHub API. +// Without this transport, certain GitHub API features (like Copilot assignment) +// that require feature flags will fail with schema validation errors. +// +// Example usage for local server (layering with auth): +// +// httpClient := &http.Client{ +// Transport: &github.GraphQLFeaturesTransport{ +// Transport: &authTransport{ +// Transport: http.DefaultTransport, +// token: "ghp_...", +// }, +// }, +// } +// gqlClient := githubv4.NewClient(httpClient) +// +// Example usage for remote server (simple case): +// +// httpClient := &http.Client{ +// Transport: &github.GraphQLFeaturesTransport{ +// Transport: http.DefaultTransport, +// }, +// } +// gqlClient := githubv4.NewClient(httpClient) +// +// The transport reads feature flags from request context using GetGraphQLFeatures. +// Feature flags are added to context by the tool handler via WithGraphQLFeatures. +type GraphQLFeaturesTransport struct { + // Transport is the underlying http.RoundTripper. If nil, http.DefaultTransport is used. + Transport http.RoundTripper +} + +// RoundTrip implements http.RoundTripper. +// It adds the GraphQL-Features header if features are present in the request context. +func (t *GraphQLFeaturesTransport) RoundTrip(req *http.Request) (*http.Response, error) { + transport := t.Transport + if transport == nil { + transport = http.DefaultTransport + } + + // Clone request to avoid modifying the original + req = req.Clone(req.Context()) + + // Check for GraphQL-Features in context and add header if present + if features := GetGraphQLFeatures(req.Context()); len(features) > 0 { + req.Header.Set("GraphQL-Features", strings.Join(features, ", ")) + } + + return transport.RoundTrip(req) +} diff --git a/pkg/github/transport_test.go b/pkg/github/transport_test.go new file mode 100644 index 000000000..ed1fc7353 --- /dev/null +++ b/pkg/github/transport_test.go @@ -0,0 +1,142 @@ +package github + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestGraphQLFeaturesTransport(t *testing.T) { + tests := []struct { + name string + features []string + expectHeader bool + expectedHeaderVal string + }{ + { + name: "adds single feature to header", + features: []string{"issues_copilot_assignment_api_support"}, + expectHeader: true, + expectedHeaderVal: "issues_copilot_assignment_api_support", + }, + { + name: "adds multiple features to header", + features: []string{"feature1", "feature2", "feature3"}, + expectHeader: true, + expectedHeaderVal: "feature1, feature2, feature3", + }, + { + name: "no header when no features in context", + features: nil, + expectHeader: false, + }, + { + name: "no header when empty features slice", + features: []string{}, + expectHeader: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Create a test server that captures the request + var capturedReq *http.Request + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + capturedReq = r + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + // Create HTTP client with GraphQLFeaturesTransport + client := &http.Client{ + Transport: &GraphQLFeaturesTransport{ + Transport: http.DefaultTransport, + }, + } + + // Create request with or without features in context + ctx := context.Background() + if tt.features != nil { + ctx = WithGraphQLFeatures(ctx, tt.features...) + } + + req, err := http.NewRequestWithContext(ctx, "GET", server.URL, nil) + require.NoError(t, err) + + // Make request + resp, err := client.Do(req) + require.NoError(t, err) + defer resp.Body.Close() + + // Verify header + if tt.expectHeader { + assert.Equal(t, tt.expectedHeaderVal, capturedReq.Header.Get("GraphQL-Features")) + } else { + assert.Empty(t, capturedReq.Header.Get("GraphQL-Features")) + } + }) + } +} + +func TestGraphQLFeaturesTransport_NilTransport(t *testing.T) { + // Test that nil Transport falls back to http.DefaultTransport + var capturedReq *http.Request + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + capturedReq = r + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + client := &http.Client{ + Transport: &GraphQLFeaturesTransport{ + Transport: nil, // Explicitly nil + }, + } + + ctx := WithGraphQLFeatures(context.Background(), "test_feature") + req, err := http.NewRequestWithContext(ctx, "GET", server.URL, nil) + require.NoError(t, err) + + resp, err := client.Do(req) + require.NoError(t, err) + defer resp.Body.Close() + + assert.Equal(t, "test_feature", capturedReq.Header.Get("GraphQL-Features")) +} + +func TestGraphQLFeaturesTransport_PreservesOtherHeaders(t *testing.T) { + // Test that the transport doesn't interfere with other headers + var capturedReq *http.Request + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + capturedReq = r + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + client := &http.Client{ + Transport: &GraphQLFeaturesTransport{ + Transport: http.DefaultTransport, + }, + } + + ctx := WithGraphQLFeatures(context.Background(), "feature1") + req, err := http.NewRequestWithContext(ctx, "GET", server.URL, nil) + require.NoError(t, err) + + // Add custom headers + req.Header.Set("Authorization", "Bearer test-token") + req.Header.Set("User-Agent", "test-agent") + + resp, err := client.Do(req) + require.NoError(t, err) + defer resp.Body.Close() + + // Verify all headers are preserved + assert.Equal(t, "feature1", capturedReq.Header.Get("GraphQL-Features")) + assert.Equal(t, "Bearer test-token", capturedReq.Header.Get("Authorization")) + assert.Equal(t, "test-agent", capturedReq.Header.Get("User-Agent")) +}