Skip to content
12 changes: 4 additions & 8 deletions internal/ghmcp/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,8 +99,10 @@ func createGitHubClients(cfg MCPServerConfig, apiHost apiHost) (*githubClients,
// We use NewEnterpriseClient unconditionally since we already parsed the API host
gqlHTTPClient := &http.Client{
Transport: &bearerAuthTransport{
transport: http.DefaultTransport,
token: cfg.Token,
transport: &github.GraphQLFeaturesTransport{
Transport: http.DefaultTransport,
},
token: cfg.Token,
},
}
gqlClient := githubv4.NewEnterpriseClient(apiHost.graphqlURL.String(), gqlHTTPClient)
Expand Down Expand Up @@ -639,12 +641,6 @@ type bearerAuthTransport struct {
func (t *bearerAuthTransport) RoundTrip(req *http.Request) (*http.Response, error) {
req = req.Clone(req.Context())
req.Header.Set("Authorization", "Bearer "+t.token)

// Check for GraphQL-Features in context and add header if present
if features := github.GetGraphQLFeatures(req.Context()); len(features) > 0 {
req.Header.Set("GraphQL-Features", strings.Join(features, ", "))
}

return t.transport.RoundTrip(req)
}

Expand Down
176 changes: 174 additions & 2 deletions pkg/github/issues.go
Original file line number Diff line number Diff line change
Expand Up @@ -1609,6 +1609,104 @@ func (d *mvpDescription) String() string {
return sb.String()
}

// linkedPullRequest represents a PR linked to an issue by Copilot.
type linkedPullRequest struct {
Number int
URL string
Title string
State string
CreatedAt time.Time
}

// pollConfigKey is a context key for polling configuration.
type pollConfigKey struct{}

// PollConfig configures the PR polling behavior.
type PollConfig struct {
MaxAttempts int
Delay time.Duration
}

// ContextWithPollConfig returns a context with polling configuration.
// Use this in tests to reduce or disable polling.
func ContextWithPollConfig(ctx context.Context, config PollConfig) context.Context {
return context.WithValue(ctx, pollConfigKey{}, config)
}

// getPollConfig returns the polling configuration from context, or defaults.
func getPollConfig(ctx context.Context) PollConfig {
if config, ok := ctx.Value(pollConfigKey{}).(PollConfig); ok {
return config
}
// Default: 9 attempts with 1s delay = 8s max wait
// Based on observed latency in remote server: p50 ~5s, p90 ~7s
return PollConfig{MaxAttempts: 9, Delay: 1 * time.Second}
}

// findLinkedCopilotPR searches for a PR created by the copilot-swe-agent bot that references the given issue.
// It queries the issue's timeline for CrossReferencedEvent items from PRs authored by copilot-swe-agent.
// The createdAfter parameter filters to only return PRs created after the specified time.
func findLinkedCopilotPR(ctx context.Context, client *githubv4.Client, owner, repo string, issueNumber int, createdAfter time.Time) (*linkedPullRequest, error) {
// Query timeline items looking for CrossReferencedEvent from PRs by copilot-swe-agent
var query struct {
Repository struct {
Issue struct {
TimelineItems struct {
Nodes []struct {
TypeName string `graphql:"__typename"`
CrossReferencedEvent struct {
Source struct {
PullRequest struct {
Number int
URL string
Title string
State string
CreatedAt githubv4.DateTime
Author struct {
Login string
}
} `graphql:"... on PullRequest"`
}
} `graphql:"... on CrossReferencedEvent"`
}
} `graphql:"timelineItems(first: 20, itemTypes: [CROSS_REFERENCED_EVENT])"`
} `graphql:"issue(number: $number)"`
} `graphql:"repository(owner: $owner, name: $name)"`
}

variables := map[string]any{
"owner": githubv4.String(owner),
"name": githubv4.String(repo),
"number": githubv4.Int(issueNumber), //nolint:gosec // Issue numbers are always small positive integers
}

if err := client.Query(ctx, &query, variables); err != nil {
return nil, err
}

// Look for a PR from copilot-swe-agent created after the assignment time
for _, node := range query.Repository.Issue.TimelineItems.Nodes {
if node.TypeName != "CrossReferencedEvent" {
continue
}
pr := node.CrossReferencedEvent.Source.PullRequest
if pr.Number > 0 && pr.Author.Login == "copilot-swe-agent" {
// Only return PRs created after the assignment time
if pr.CreatedAt.Time.After(createdAfter) {
return &linkedPullRequest{
Number: pr.Number,
URL: pr.URL,
Title: pr.Title,
State: pr.State,
CreatedAt: pr.CreatedAt.Time,
}, nil
}
}
}

return nil, nil
}

func AssignCopilotToIssue(t translations.TranslationHelperFunc) inventory.ServerTool {
description := mvpDescription{
summary: "Assign Copilot to a specific issue in a GitHub repository.",
Expand Down Expand Up @@ -1659,7 +1757,7 @@ func AssignCopilotToIssue(t translations.TranslationHelperFunc) inventory.Server
},
},
[]scopes.Scope{scopes.Repo},
func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) {
func(ctx context.Context, deps ToolDependencies, request *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) {
var params struct {
Owner string `mapstructure:"owner"`
Repo string `mapstructure:"repo"`
Expand Down Expand Up @@ -1802,6 +1900,9 @@ func AssignCopilotToIssue(t translations.TranslationHelperFunc) inventory.Server
// The header will be read by the HTTP transport if it's configured to do so
ctxWithFeatures := withGraphQLFeatures(ctx, "issues_copilot_assignment_api_support")

// Capture the time before assignment to filter out older PRs during polling
assignmentTime := time.Now().UTC()

if err := client.Mutate(
ctxWithFeatures,
&updateIssueMutation,
Expand All @@ -1815,7 +1916,78 @@ func AssignCopilotToIssue(t translations.TranslationHelperFunc) inventory.Server
return nil, nil, fmt.Errorf("failed to update issue with agent assignment: %w", err)
}

return utils.NewToolResultText("successfully assigned copilot to issue"), nil, nil
// Poll for a linked PR created by Copilot after the assignment
pollConfig := getPollConfig(ctx)

// Get progress token from request for sending progress notifications
progressToken := request.Params.GetProgressToken()

// Send initial progress notification that assignment succeeded and polling is starting
if progressToken != nil && request.Session != nil && pollConfig.MaxAttempts > 0 {
_ = request.Session.NotifyProgress(ctx, &mcp.ProgressNotificationParams{
ProgressToken: progressToken,
Progress: 0,
Total: float64(pollConfig.MaxAttempts),
Message: "Copilot assigned to issue, waiting for PR creation...",
})
}

var linkedPR *linkedPullRequest
for attempt := range pollConfig.MaxAttempts {
if attempt > 0 {
time.Sleep(pollConfig.Delay)
}

// Send progress notification if progress token is available
if progressToken != nil && request.Session != nil {
_ = request.Session.NotifyProgress(ctx, &mcp.ProgressNotificationParams{
ProgressToken: progressToken,
Progress: float64(attempt + 1),
Total: float64(pollConfig.MaxAttempts),
Message: fmt.Sprintf("Waiting for Copilot to create PR... (attempt %d/%d)", attempt+1, pollConfig.MaxAttempts),
})
}

pr, err := findLinkedCopilotPR(ctx, client, params.Owner, params.Repo, int(params.IssueNumber), assignmentTime)
if err != nil {
// Polling errors are non-fatal, continue to next attempt
continue
}
if pr != nil {
linkedPR = pr
break
}
}

// Build the result
result := map[string]any{
"message": "successfully assigned copilot to issue",
"issue_number": int(updateIssueMutation.UpdateIssue.Issue.Number),
"issue_url": string(updateIssueMutation.UpdateIssue.Issue.URL),
"owner": params.Owner,
"repo": params.Repo,
}

// Add PR info if found during polling
if linkedPR != nil {
result["pull_request"] = map[string]any{
"number": linkedPR.Number,
"url": linkedPR.URL,
"title": linkedPR.Title,
"state": linkedPR.State,
}
result["message"] = "successfully assigned copilot to issue - pull request created"
} else {
result["message"] = "successfully assigned copilot to issue - pull request pending"
result["note"] = "The pull request may still be in progress. Once created, the PR number can be used to check job status, or check the issue timeline for updates."
}

r, err := json.Marshal(result)
if err != nil {
return utils.NewToolResultError(fmt.Sprintf("failed to marshal response: %s", err)), nil, nil
}

return utils.NewToolResultText(string(r)), result, nil
})
}

Expand Down
17 changes: 15 additions & 2 deletions pkg/github/issues_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2765,8 +2765,12 @@ func TestAssignCopilotToIssue(t *testing.T) {
// Create call request
request := createMCPRequest(tc.requestArgs)

// Disable polling in tests to avoid timeouts
ctx := ContextWithPollConfig(context.Background(), PollConfig{MaxAttempts: 0})
ctx = ContextWithDeps(ctx, deps)

// Call handler
result, err := handler(ContextWithDeps(context.Background(), deps), &request)
result, err := handler(ctx, &request)
require.NoError(t, err)

textContent := getTextResult(t, result)
Expand All @@ -2778,7 +2782,16 @@ func TestAssignCopilotToIssue(t *testing.T) {
}

require.False(t, result.IsError, fmt.Sprintf("expected there to be no tool error, text was %s", textContent.Text))
require.Equal(t, textContent.Text, "successfully assigned copilot to issue")

// Verify the JSON response contains expected fields
var response map[string]any
err = json.Unmarshal([]byte(textContent.Text), &response)
require.NoError(t, err, "response should be valid JSON")
assert.Equal(t, float64(123), response["issue_number"])
assert.Equal(t, "https://github.com/owner/repo/issues/123", response["issue_url"])
assert.Equal(t, "owner", response["owner"])
assert.Equal(t, "repo", response["repo"])
assert.Contains(t, response["message"], "successfully assigned copilot to issue")
})
}
}
Expand Down
47 changes: 47 additions & 0 deletions pkg/github/transport.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
package github

import (
"net/http"
"strings"
)

// GraphQLFeaturesTransport is an http.RoundTripper that adds GraphQL-Features
// header to requests based on context values. This is required for using
// non-GA GraphQL API features like the agent assignment API.
//
// This transport is used internally by the MCP server and is also exported
// for library consumers who need to build their own HTTP clients with
// GraphQL feature flag support.
//
// Usage:
//
// httpClient := &http.Client{
// Transport: &github.GraphQLFeaturesTransport{
// Transport: http.DefaultTransport,
// },
// }
// gqlClient := githubv4.NewClient(httpClient)
//
// Then use withGraphQLFeatures(ctx, "feature_name") when calling GraphQL operations.
type GraphQLFeaturesTransport struct {
// Transport is the underlying HTTP transport. If nil, http.DefaultTransport is used.
Transport http.RoundTripper
}

// RoundTrip implements http.RoundTripper.
func (t *GraphQLFeaturesTransport) RoundTrip(req *http.Request) (*http.Response, error) {
transport := t.Transport
if transport == nil {
transport = http.DefaultTransport
}

// Clone the request to avoid mutating the original
req = req.Clone(req.Context())

// Check for GraphQL-Features in context and add header if present
if features := GetGraphQLFeatures(req.Context()); len(features) > 0 {
req.Header.Set("GraphQL-Features", strings.Join(features, ", "))
}

return transport.RoundTrip(req)
}
Loading