Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 6 additions & 9 deletions internal/ghmcp/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,10 +94,13 @@ func createGitHubClients(cfg MCPServerConfig, apiHost apiHost) (*githubClients,

// Construct GraphQL client
// We use NewEnterpriseClient unconditionally since we already parsed the API host
// Layer transports: DefaultTransport -> bearerAuthTransport -> GraphQLFeaturesTransport
gqlHTTPClient := &http.Client{
Transport: &bearerAuthTransport{
transport: http.DefaultTransport,
token: cfg.Token,
Transport: &github.GraphQLFeaturesTransport{
Transport: &bearerAuthTransport{
transport: http.DefaultTransport,
token: cfg.Token,
},
},
}
gqlClient := githubv4.NewEnterpriseClient(apiHost.graphqlURL.String(), gqlHTTPClient)
Expand Down Expand Up @@ -622,12 +625,6 @@ type bearerAuthTransport struct {
func (t *bearerAuthTransport) RoundTrip(req *http.Request) (*http.Response, error) {
req = req.Clone(req.Context())
req.Header.Set("Authorization", "Bearer "+t.token)

// Check for GraphQL-Features in context and add header if present
if features := github.GetGraphQLFeatures(req.Context()); len(features) > 0 {
req.Header.Set("GraphQL-Features", strings.Join(features, ", "))
}

return t.transport.RoundTrip(req)
}

Expand Down
18 changes: 1 addition & 17 deletions pkg/github/issues.go
Original file line number Diff line number Diff line change
Expand Up @@ -1789,7 +1789,7 @@ func AssignCopilotToIssue(t translations.TranslationHelperFunc) inventory.Server

// Add the GraphQL-Features header for the agent assignment API
// The header will be read by the HTTP transport if it's configured to do so
ctxWithFeatures := withGraphQLFeatures(ctx, "issues_copilot_assignment_api_support")
ctxWithFeatures := WithGraphQLFeatures(ctx, "issues_copilot_assignment_api_support")

if err := client.Mutate(
ctxWithFeatures,
Expand Down Expand Up @@ -1913,19 +1913,3 @@ func AssignCodingAgentPrompt(t translations.TranslationHelperFunc) inventory.Ser
},
)
}

// graphQLFeaturesKey is a context key for GraphQL feature flags
type graphQLFeaturesKey struct{}

// withGraphQLFeatures adds GraphQL feature flags to the context
func withGraphQLFeatures(ctx context.Context, features ...string) context.Context {
return context.WithValue(ctx, graphQLFeaturesKey{}, features)
}

// GetGraphQLFeatures retrieves GraphQL feature flags from the context
func GetGraphQLFeatures(ctx context.Context) []string {
if features, ok := ctx.Value(graphQLFeaturesKey{}).([]string); ok {
return features
}
return nil
}
86 changes: 86 additions & 0 deletions pkg/github/transport.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
package github

import (
"context"
"net/http"
"strings"
)

// graphQLFeaturesKey is a context key for GraphQL feature flags.
// These flags enable preview or experimental GitHub API features that are not yet GA.
type graphQLFeaturesKey struct{}

// WithGraphQLFeatures adds GraphQL feature flags to the context.
// The flags are read by GraphQLFeaturesTransport and sent as the GraphQL-Features header.
// This is used by tool handlers that require experimental GitHub API features.
// Remote servers can also use this function in tests to simulate feature flag contexts.
func WithGraphQLFeatures(ctx context.Context, features ...string) context.Context {
return context.WithValue(ctx, graphQLFeaturesKey{}, features)
}

// GetGraphQLFeatures retrieves GraphQL feature flags from the context.
// This function is exported to allow custom HTTP transports (e.g., in remote servers)
// to read feature flags and add them as the "GraphQL-Features" header.
//
// For most use cases, use GraphQLFeaturesTransport instead of calling this directly.
func GetGraphQLFeatures(ctx context.Context) []string {
if features, ok := ctx.Value(graphQLFeaturesKey{}).([]string); ok {
return features
}
return nil
}

// GraphQLFeaturesTransport is an http.RoundTripper that adds GraphQL-Features
// header based on context values set by WithGraphQLFeatures.
//
// This transport should be used in the HTTP client chain for githubv4.Client
// to ensure GraphQL feature flags are properly sent to the GitHub API.
// Without this transport, certain GitHub API features (like Copilot assignment)
// that require feature flags will fail with schema validation errors.
//
// Example usage for local server (layering with auth):
//
// httpClient := &http.Client{
// Transport: &github.GraphQLFeaturesTransport{
// Transport: &authTransport{
// Transport: http.DefaultTransport,
// token: "ghp_...",
// },
// },
// }
// gqlClient := githubv4.NewClient(httpClient)
//
// Example usage for remote server (simple case):
//
// httpClient := &http.Client{
// Transport: &github.GraphQLFeaturesTransport{
// Transport: http.DefaultTransport,
// },
// }
// gqlClient := githubv4.NewClient(httpClient)
//
// The transport reads feature flags from request context using GetGraphQLFeatures.
// Feature flags are added to context by the tool handler via WithGraphQLFeatures.
type GraphQLFeaturesTransport struct {
// Transport is the underlying http.RoundTripper. If nil, http.DefaultTransport is used.
Transport http.RoundTripper
}

// RoundTrip implements http.RoundTripper.
// It adds the GraphQL-Features header if features are present in the request context.
func (t *GraphQLFeaturesTransport) RoundTrip(req *http.Request) (*http.Response, error) {
transport := t.Transport
if transport == nil {
transport = http.DefaultTransport
}

// Clone request to avoid modifying the original
req = req.Clone(req.Context())

// Check for GraphQL-Features in context and add header if present
if features := GetGraphQLFeatures(req.Context()); len(features) > 0 {
req.Header.Set("GraphQL-Features", strings.Join(features, ", "))
}

return transport.RoundTrip(req)
}
142 changes: 142 additions & 0 deletions pkg/github/transport_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
package github

import (
"context"
"net/http"
"net/http/httptest"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestGraphQLFeaturesTransport(t *testing.T) {
tests := []struct {
name string
features []string
expectHeader bool
expectedHeaderVal string
}{
{
name: "adds single feature to header",
features: []string{"issues_copilot_assignment_api_support"},
expectHeader: true,
expectedHeaderVal: "issues_copilot_assignment_api_support",
},
{
name: "adds multiple features to header",
features: []string{"feature1", "feature2", "feature3"},
expectHeader: true,
expectedHeaderVal: "feature1, feature2, feature3",
},
{
name: "no header when no features in context",
features: nil,
expectHeader: false,
},
{
name: "no header when empty features slice",
features: []string{},
expectHeader: false,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Create a test server that captures the request
var capturedReq *http.Request
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
capturedReq = r
w.WriteHeader(http.StatusOK)
}))
defer server.Close()

// Create HTTP client with GraphQLFeaturesTransport
client := &http.Client{
Transport: &GraphQLFeaturesTransport{
Transport: http.DefaultTransport,
},
}

// Create request with or without features in context
ctx := context.Background()
if tt.features != nil {
ctx = WithGraphQLFeatures(ctx, tt.features...)
}

req, err := http.NewRequestWithContext(ctx, "GET", server.URL, nil)
require.NoError(t, err)

// Make request
resp, err := client.Do(req)
require.NoError(t, err)
defer resp.Body.Close()

// Verify header
if tt.expectHeader {
assert.Equal(t, tt.expectedHeaderVal, capturedReq.Header.Get("GraphQL-Features"))
} else {
assert.Empty(t, capturedReq.Header.Get("GraphQL-Features"))
}
})
}
}

func TestGraphQLFeaturesTransport_NilTransport(t *testing.T) {
// Test that nil Transport falls back to http.DefaultTransport
var capturedReq *http.Request
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
capturedReq = r
w.WriteHeader(http.StatusOK)
}))
defer server.Close()

client := &http.Client{
Transport: &GraphQLFeaturesTransport{
Transport: nil, // Explicitly nil
},
}

ctx := WithGraphQLFeatures(context.Background(), "test_feature")
req, err := http.NewRequestWithContext(ctx, "GET", server.URL, nil)
require.NoError(t, err)

resp, err := client.Do(req)
require.NoError(t, err)
defer resp.Body.Close()

assert.Equal(t, "test_feature", capturedReq.Header.Get("GraphQL-Features"))
}

func TestGraphQLFeaturesTransport_PreservesOtherHeaders(t *testing.T) {
// Test that the transport doesn't interfere with other headers
var capturedReq *http.Request
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
capturedReq = r
w.WriteHeader(http.StatusOK)
}))
defer server.Close()

client := &http.Client{
Transport: &GraphQLFeaturesTransport{
Transport: http.DefaultTransport,
},
}

ctx := WithGraphQLFeatures(context.Background(), "feature1")
req, err := http.NewRequestWithContext(ctx, "GET", server.URL, nil)
require.NoError(t, err)

// Add custom headers
req.Header.Set("Authorization", "Bearer test-token")
req.Header.Set("User-Agent", "test-agent")

resp, err := client.Do(req)
require.NoError(t, err)
defer resp.Body.Close()

// Verify all headers are preserved
assert.Equal(t, "feature1", capturedReq.Header.Get("GraphQL-Features"))
assert.Equal(t, "Bearer test-token", capturedReq.Header.Get("Authorization"))
assert.Equal(t, "test-agent", capturedReq.Header.Get("User-Agent"))
}