Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 58 additions & 18 deletions pkg/cli/logs_timeout_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,25 +109,65 @@ func TestTimeoutLogic(t *testing.T) {
}
}

// TestMCPServerDefaultTimeout tests that the MCP server sets a default timeout
func TestMCPServerDefaultTimeout(t *testing.T) {
// Test that when no timeout is specified, MCP server uses 1 minute
timeoutValue := 0
if timeoutValue == 0 {
timeoutValue = 1
}

if timeoutValue != 1 {
t.Errorf("Expected MCP server default timeout to be 1 but got %d", timeoutValue)
}

// Test that explicit timeout overrides the default
timeoutValue = 5
if timeoutValue == 0 {
timeoutValue = 1
// TestEffectiveMCPLogsToolTimeoutMinutes verifies that the MCP logs tool
// scales its implicit timeout with larger fetch windows while preserving
// explicit user-provided timeouts.
func TestEffectiveMCPLogsToolTimeoutMinutes(t *testing.T) {
tests := []struct {
name string
requestedTimeout int
count int
want int
}{
{
name: "explicit timeout is preserved",
requestedTimeout: 5,
count: 100,
want: 5,
},
{
name: "small fetch window keeps one minute default",
requestedTimeout: 0,
count: 40,
want: 1,
},
{
name: "fetch window above forty runs gets two minutes",
requestedTimeout: 0,
count: 41,
want: 2,
},
{
name: "eighty run fetch window stays in two minute tier",
requestedTimeout: 0,
count: 80,
want: 2,
},
{
name: "eighty one run fetch window enters three minute tier",
requestedTimeout: 0,
count: 81,
want: 3,
},
{
name: "default hundred run window gets three minutes",
requestedTimeout: 0,
count: 100,
want: 3,
},
{
name: "unspecified count falls back to default window size",
requestedTimeout: 0,
count: 0,
want: 3,
},

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The test suite covers the 40 → 41 boundary (1 → 2 min) but misses the 80 → 81 boundary (2 → 3 min). Consider adding test cases for count=80 (expect 2 min) and count=81 (expect 3 min) to fully exercise both tier transitions and guard against accidental formula changes.

@copilot please address this.

}

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[/tdd] The tier 1→2 boundary (count=40/41) is tested but the tier 2→3 boundary (count=80/81) is not. A future off-by-one in the ceiling-division formula would only be caught at the tested boundaries.

💡 Two additional cases to add before line 152
{
	name:             "top of tier 2 keeps two minutes",
	requestedTimeout: 0,
	count:            80,
	want:             2,
},
{
	name:             "bottom of tier 3 gets three minutes",
	requestedTimeout: 0,
	count:            81,
	want:             3,
},

These pin both sides of the second tier boundary, matching the depth of coverage already present for the first boundary (40/41).

@copilot please address this.


if timeoutValue != 5 {
t.Errorf("Expected explicit timeout to be preserved but got %d", timeoutValue)
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := effectiveMCPLogsToolTimeoutMinutes(tt.requestedTimeout, tt.count); got != tt.want {
t.Errorf("effectiveMCPLogsToolTimeoutMinutes(%d, %d) = %d, want %d", tt.requestedTimeout, tt.count, got, tt.want)
}
})
}
}
15 changes: 8 additions & 7 deletions pkg/cli/mcp_server_defaults_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,11 +66,11 @@ func TestMCPToolElicitationDefaults(t *testing.T) {
t.Fatalf("Failed to generate schema: %v", err)
}

// Add defaults as done in createMCPServer
if err := AddSchemaDefault(schema, "count", 100); err != nil {
// Add defaults as done in registerLogsTool
if err := AddSchemaDefault(schema, "count", defaultMCPLogsToolCount); err != nil {
t.Fatalf("Failed to add count default: %v", err)
}
if err := AddSchemaDefault(schema, "timeout", 50); err != nil {
if err := AddSchemaDefault(schema, "timeout", defaultMCPLogsToolTimeoutMinutesForCount(defaultMCPLogsToolCount)); err != nil {
t.Fatalf("Failed to add timeout default: %v", err)
}
if err := AddSchemaDefault(schema, "max_tokens", 12000); err != nil {
Expand All @@ -92,8 +92,8 @@ func TestMCPToolElicitationDefaults(t *testing.T) {
if err := json.Unmarshal(countProp.Default, &countDefault); err != nil {
t.Fatalf("Failed to unmarshal count default: %v", err)
}
if countDefault != 100 {
t.Errorf("Expected count default to be 100, got %v", countDefault)
if countDefault != defaultMCPLogsToolCount {
t.Errorf("Expected count default to be %d, got %v", defaultMCPLogsToolCount, countDefault)
}

// Verify timeout default
Expand All @@ -108,8 +108,9 @@ func TestMCPToolElicitationDefaults(t *testing.T) {
if err := json.Unmarshal(timeoutProp.Default, &timeoutDefault); err != nil {
t.Fatalf("Failed to unmarshal timeout default: %v", err)
}
if timeoutDefault != 50 {
t.Errorf("Expected timeout default to be 50, got %v", timeoutDefault)
expectedTimeoutDefault := defaultMCPLogsToolTimeoutMinutesForCount(defaultMCPLogsToolCount)
if timeoutDefault != expectedTimeoutDefault {
t.Errorf("Expected timeout default to be %d, got %v", expectedTimeoutDefault, timeoutDefault)
}

// Verify max_tokens default (backward-compat field)
Expand Down
54 changes: 41 additions & 13 deletions pkg/cli/mcp_tools_privileged.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,12 @@ import (
"github.com/modelcontextprotocol/go-sdk/mcp"
)

const (
defaultMCPLogsToolCount = 100
defaultMCPLogsTimeoutMinutes = 1
mcpLogsRunsPerDefaultTimeoutMinute = 40
)

// appendRepoFlagFromEnv appends "--repo <owner/repo>" to args when GITHUB_REPOSITORY
// is set in the environment. This allows gh CLI subcommands to identify the repository
// without falling back to git-based detection, which fails in sandboxed environments
Expand All @@ -41,11 +47,34 @@ type logsArgs struct {
Branch string `json:"branch,omitempty" jsonschema:"Filter runs by branch name"`
AfterRunID int64 `json:"after_run_id,omitempty" jsonschema:"Filter runs with database ID after this value (exclusive)"`
BeforeRunID int64 `json:"before_run_id,omitempty" jsonschema:"Filter runs with database ID before this value (exclusive)"`
Timeout int `json:"timeout,omitempty" jsonschema:"Maximum time in minutes to spend downloading logs (default: 1 for MCP server)"`
Timeout int `json:"timeout,omitempty" jsonschema:"Maximum time in minutes to spend downloading logs (default: auto-scales with count in the MCP server, rounded up in 40-run increments; e.g. 1 minute up to 40, 2 minutes for 41-80, 3 minutes for 81-120, and so on)"`
MaxTokens int `json:"max_tokens,omitempty" jsonschema:"Deprecated: accepted for backward compatibility but ignored. Output is always written to a file."`
Artifacts []string `json:"artifacts,omitempty" jsonschema:"Artifact sets to download (default: usage). Valid sets: all, activation, agent, detection, experiment, firewall, github-api, mcp, usage"`
}

func defaultMCPLogsToolTimeoutMinutesForCount(count int) int {
count = effectiveMCPLogsToolCount(count)

// Round up in 40-run increments so requests slightly above the current
// 60-second threshold automatically get an extra minute of budget.
return max(defaultMCPLogsTimeoutMinutes, (count+mcpLogsRunsPerDefaultTimeoutMinute-1)/mcpLogsRunsPerDefaultTimeoutMinute)
}

func effectiveMCPLogsToolCount(count int) int {
if count > 0 {
return count
}
return defaultMCPLogsToolCount
}

func effectiveMCPLogsToolTimeoutMinutes(requestedTimeout, count int) int {
if requestedTimeout > 0 {
return requestedTimeout
}

return defaultMCPLogsToolTimeoutMinutesForCount(count)
}

// The logs tool requires write+ access and checks actor permissions.
// Returns an error if schema generation fails.
func registerLogsTool(server *mcp.Server, execCmd execCmdFunc, actor string, validateActor bool) error {
Expand All @@ -56,10 +85,12 @@ func registerLogsTool(server *mcp.Server, execCmd execCmdFunc, actor string, val
return err
}
// Add elicitation defaults for common parameters
if err := AddSchemaDefault(logsSchema, "count", 100); err != nil {
if err := AddSchemaDefault(logsSchema, "count", defaultMCPLogsToolCount); err != nil {
mcpLog.Printf("Failed to add default for count: %v", err)
}
if err := AddSchemaDefault(logsSchema, "timeout", 1); err != nil {
// Schema default corresponds to defaultMCPLogsToolCount; runtime timeout
// scales with the effective count used for the request.
if err := AddSchemaDefault(logsSchema, "timeout", defaultMCPLogsToolTimeoutMinutesForCount(defaultMCPLogsToolCount)); err != nil {

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The schema default for timeout is pinned to the 100-run tier (3 min). A caller that provides an explicit count=10 but no timeout will get a 1-minute effective timeout at runtime, despite the schema hint suggesting 3. The jsonschema tag description on the field already documents the tiers, but the schema default could mislead callers who assume it applies to any count value.

Consider adding a small comment here like:

// Schema default corresponds to defaultMCPLogsToolCount; actual runtime value
// scales down for smaller count values — see defaultMCPLogsToolTimeoutMinutesForCount.

@copilot please address this.

mcpLog.Printf("Failed to add default for timeout: %v", err)
}
if err := AddSchemaDefault(logsSchema, "max_tokens", 12000); err != nil {
Expand Down Expand Up @@ -141,9 +172,8 @@ from where the previous request stopped due to timeout.`,
if args.WorkflowName != "" {
cmdArgs = append(cmdArgs, args.WorkflowName)
}
if args.Count > 0 {
cmdArgs = append(cmdArgs, "-c", strconv.Itoa(args.Count))
}
effectiveCount := effectiveMCPLogsToolCount(args.Count)
cmdArgs = append(cmdArgs, "-c", strconv.Itoa(effectiveCount))
if args.StartDate != "" {
cmdArgs = append(cmdArgs, "--start-date", args.StartDate)
}
Expand Down Expand Up @@ -179,19 +209,17 @@ from where the previous request stopped due to timeout.`,

cmdArgs = appendRepoFlagFromEnv(cmdArgs)

// Set timeout to 1 minute for MCP server if not explicitly specified
timeoutValue := args.Timeout
if timeoutValue == 0 {
timeoutValue = 1
}
// Scale the implicit MCP timeout with the requested fetch window so
// larger fleet-wide requests do not hit the 60s server deadline by default.
timeoutValue := effectiveMCPLogsToolTimeoutMinutes(args.Timeout, effectiveCount)
cmdArgs = append(cmdArgs, "--timeout", strconv.Itoa(timeoutValue))
Comment on lines 210 to 215

// Always use --json mode in MCP server
cmdArgs = append(cmdArgs, "--json")

// Log the command being executed for debugging
mcpLog.Printf("Executing logs tool: workflow=%s, count=%d, firewall=%v, no_firewall=%v, filtered_integrity=%v, timeout=%d, command_args=%v",
args.WorkflowName, args.Count, args.Firewall, args.NoFirewall, args.FilteredIntegrity, timeoutValue, cmdArgs)
mcpLog.Printf("Executing logs tool: workflow=%s, requested_count=%d, effective_count=%d, firewall=%v, no_firewall=%v, filtered_integrity=%v, timeout=%d, command_args=%v",
args.WorkflowName, args.Count, effectiveCount, args.Firewall, args.NoFirewall, args.FilteredIntegrity, timeoutValue, cmdArgs)

notifyProgress(ctx, req, 0, 100, "Downloading workflow logs...")

Expand Down
33 changes: 33 additions & 0 deletions pkg/cli/mcp_tools_privileged_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"os"
"os/exec"
"slices"
"strconv"
"strings"
"sync"
"testing"
Expand Down Expand Up @@ -261,6 +262,38 @@ func TestLogsToolPassesArtifactsArgument(t *testing.T) {
t.Fatal("expected --artifacts flag in command args")
}

func TestLogsToolUsesEffectiveCountForTimeoutScaling(t *testing.T) {
t.Run("omitted count uses MCP default for both -c and --timeout", func(t *testing.T) {
var capturedArgs []string
mockExecCmd := func(ctx context.Context, args ...string) *exec.Cmd {
capturedArgs = append([]string(nil), args...)
return exec.CommandContext(ctx, "sh", "-c", `printf '%s' "$1"`, "sh", `{"file_path":"/tmp/gh-aw/aw-mcp/logs/runs.json"}`)
}

server := mcp.NewServer(&mcp.Implementation{Name: "test", Version: "1.0"}, nil)
err := registerLogsTool(server, mockExecCmd, "", false)
require.NoError(t, err, "registerLogsTool should succeed")

session := connectInMemory(t, server)
_, err = session.CallTool(context.Background(), &mcp.CallToolParams{
Name: "logs",
Arguments: map[string]any{},
})
require.NoError(t, err, "logs tool should succeed")

countIndex := slices.Index(capturedArgs, "-c")
require.NotEqual(t, -1, countIndex, "logs tool should pass -c to keep MCP/CLI defaults aligned")
require.Less(t, countIndex+1, len(capturedArgs), "-c should have a value")
assert.Equal(t, strconv.Itoa(defaultMCPLogsToolCount), capturedArgs[countIndex+1])

timeoutIndex := slices.Index(capturedArgs, "--timeout")
require.NotEqual(t, -1, timeoutIndex, "logs tool should pass --timeout")
require.Less(t, timeoutIndex+1, len(capturedArgs), "--timeout should have a value")
assert.Equal(t, strconv.Itoa(defaultMCPLogsToolTimeoutMinutesForCount(defaultMCPLogsToolCount)), capturedArgs[timeoutIndex+1])
})

}

// TestAuditToolPassesGithubRepositoryAsRepoFlag verifies that the audit MCP tool
// appends --repo <owner/repo> to the subprocess command when GITHUB_REPOSITORY
// is set, allowing the audit command to resolve the repository without git.
Expand Down
Loading