diff --git a/client/src/App.tsx b/client/src/App.tsx index 12e9a7bd0..0757c5e00 100644 --- a/client/src/App.tsx +++ b/client/src/App.tsx @@ -1148,21 +1148,50 @@ const App = () => { ? ((response as { _meta?: Record })._meta ?? {}) : undefined; - latestToolResult = { - content: [ + + if (taskStatus.status === "input_required") { + // Per MCP spec: when input_required, call tasks/result to give + // the server a chance to deliver queued elicitation/sampling + // requests. After elicitation is handled, the task transitions + // back to "working" — do NOT set taskCompleted here. + latestToolResult = { + content: [ + { + type: "text", + text: `Task status: input_required${taskStatus.statusMessage ? ` - ${taskStatus.statusMessage}` : ""}. Awaiting input...`, + }, + ], + _meta: { + ...(pollingResponseMeta || {}), + "io.modelcontextprotocol/related-task": { taskId }, + }, + }; + setToolResult(latestToolResult); + await sendMCPRequest( { - type: "text", - text: `Task status: ${taskStatus.status}${taskStatus.statusMessage ? ` - ${taskStatus.statusMessage}` : ""}. Polling...`, + method: "tasks/result", + params: { taskId }, }, - ], - _meta: { - ...(pollingResponseMeta || {}), - "io.modelcontextprotocol/related-task": { taskId }, - }, - }; - setToolResult(latestToolResult); - // Refresh tasks list to show progress - void listTasks(); + CompatibilityCallToolResultSchema, + ); + void listTasks(); + } else { + latestToolResult = { + content: [ + { + type: "text", + text: `Task status: ${taskStatus.status}${taskStatus.statusMessage ? ` - ${taskStatus.statusMessage}` : ""}. Polling...`, + }, + ], + _meta: { + ...(pollingResponseMeta || {}), + "io.modelcontextprotocol/related-task": { taskId }, + }, + }; + setToolResult(latestToolResult); + // Refresh tasks list to show progress + void listTasks(); + } } } catch (pollingError) { console.error("Error polling task status:", pollingError); diff --git a/client/src/__tests__/App.taskPolling.test.tsx b/client/src/__tests__/App.taskPolling.test.tsx new file mode 100644 index 000000000..1eff4951c --- /dev/null +++ b/client/src/__tests__/App.taskPolling.test.tsx @@ -0,0 +1,336 @@ +import { fireEvent, render, screen, waitFor } from "@testing-library/react"; +import "@testing-library/jest-dom"; +import App from "../App"; +import { useConnection } from "../lib/hooks/useConnection"; +import type { Client } from "@modelcontextprotocol/sdk/client/index.js"; + +// Mock auth dependencies +jest.mock("@modelcontextprotocol/sdk/client/auth.js", () => ({ + auth: jest.fn(), +})); + +jest.mock("../lib/oauth-state-machine", () => ({ + OAuthStateMachine: jest.fn(), +})); + +jest.mock("../lib/auth", () => ({ + InspectorOAuthClientProvider: jest.fn().mockImplementation(() => ({ + tokens: jest.fn().mockResolvedValue(null), + clear: jest.fn(), + })), + DebugInspectorOAuthClientProvider: jest.fn(), +})); + +jest.mock("../utils/configUtils", () => ({ + ...jest.requireActual("../utils/configUtils"), + getMCPProxyAddress: jest.fn(() => "http://localhost:6277"), + getMCPProxyAuthToken: jest.fn(() => ({ + token: "", + header: "X-MCP-Proxy-Auth", + })), + getMCPTaskTtl: jest.fn(() => 30000), + getInitialTransportType: jest.fn(() => "stdio"), + getInitialSseUrl: jest.fn(() => "http://localhost:3001/sse"), + getInitialCommand: jest.fn(() => "mcp-server-everything"), + getInitialArgs: jest.fn(() => ""), + initializeInspectorConfig: jest.fn(() => ({})), + saveInspectorConfig: jest.fn(), +})); + +jest.mock("../lib/hooks/useDraggablePane", () => ({ + useDraggablePane: () => ({ + height: 300, + handleDragStart: jest.fn(), + }), + useDraggableSidebar: () => ({ + width: 320, + isDragging: false, + handleDragStart: jest.fn(), + }), +})); + +jest.mock("../components/Sidebar", () => ({ + __esModule: true, + default: () =>
Sidebar
, +})); + +jest.mock("../components/ResourcesTab", () => ({ + __esModule: true, + default: () =>
ResourcesTab
, +})); + +jest.mock("../components/PromptsTab", () => ({ + __esModule: true, + default: () =>
PromptsTab
, +})); + +jest.mock("../components/TasksTab", () => ({ + __esModule: true, + default: () =>
TasksTab
, +})); + +jest.mock("../components/ConsoleTab", () => ({ + __esModule: true, + default: () =>
ConsoleTab
, +})); + +jest.mock("../components/PingTab", () => ({ + __esModule: true, + default: () =>
PingTab
, +})); + +jest.mock("../components/SamplingTab", () => ({ + __esModule: true, + default: () =>
SamplingTab
, +})); + +jest.mock("../components/RootsTab", () => ({ + __esModule: true, + default: () =>
RootsTab
, +})); + +jest.mock("../components/ElicitationTab", () => ({ + __esModule: true, + default: () =>
ElicitationTab
, +})); + +jest.mock("../components/MetadataTab", () => ({ + __esModule: true, + default: () =>
MetadataTab
, +})); + +jest.mock("../components/AuthDebugger", () => ({ + __esModule: true, + default: () =>
AuthDebugger
, +})); + +jest.mock("../components/HistoryAndNotifications", () => ({ + __esModule: true, + default: () =>
HistoryAndNotifications
, +})); + +jest.mock("../components/AppsTab", () => ({ + __esModule: true, + default: () =>
AppsTab
, +})); + +jest.mock("../components/ToolsTab", () => ({ + __esModule: true, + default: ({ + callTool, + toolResult, + }: { + callTool: ( + name: string, + params: Record, + metadata?: Record, + runAsTask?: boolean, + ) => Promise; + toolResult: { content: Array<{ type: string; text: string }> } | null; + }) => ( +
+ + {toolResult && ( +
+ {toolResult.content.map((c, i) => ( + {c.text} + ))} +
+ )} +
+ ), +})); + +global.fetch = jest.fn().mockResolvedValue({ json: () => Promise.resolve({}) }); + +jest.mock("../lib/hooks/useConnection", () => ({ + useConnection: jest.fn(), +})); + +describe("App - task polling with input_required status", () => { + const mockUseConnection = jest.mocked(useConnection); + + beforeEach(() => { + jest.clearAllMocks(); + window.location.hash = "#tools"; + }); + + it("calls tasks/result when polling sees input_required, then continues until completed", async () => { + const taskId = "task-abc-123"; + let tasksGetCallCount = 0; + + const makeRequest = jest.fn(async (request: { method: string }) => { + if (request.method === "tools/list") { + return { + tools: [ + { + name: "myTool", + inputSchema: { type: "object", properties: {} }, + }, + ], + nextCursor: undefined, + }; + } + + if (request.method === "tools/call") { + return { + task: { taskId, status: "input_required", pollInterval: 10 }, + }; + } + + if (request.method === "tasks/get") { + tasksGetCallCount++; + // First poll: still input_required; second poll: completed + if (tasksGetCallCount === 1) { + return { + taskId, + status: "input_required", + statusMessage: "Needs input", + }; + } + return { taskId, status: "completed" }; + } + + if (request.method === "tasks/result") { + return { + content: [{ type: "text", text: "final task result" }], + }; + } + + if (request.method === "tasks/list") { + return { tasks: [] }; + } + + throw new Error(`Unexpected method: ${request.method}`); + }); + + mockUseConnection.mockReturnValue({ + connectionStatus: "connected", + serverCapabilities: { tools: { listChanged: true } }, + serverImplementation: null, + mcpClient: { + request: jest.fn(), + notification: jest.fn(), + close: jest.fn(), + } as unknown as Client, + requestHistory: [], + clearRequestHistory: jest.fn(), + makeRequest, + cancelTask: jest.fn(), + listTasks: jest.fn(), + sendNotification: jest.fn(), + handleCompletion: jest.fn(), + completionsSupported: false, + connect: jest.fn(), + disconnect: jest.fn(), + } as ReturnType); + + render(); + + fireEvent.click(screen.getByRole("button", { name: /run task tool/i })); + + // tasks/result should be called for input_required, then again for completed + await waitFor(() => { + const resultCalls = makeRequest.mock.calls.filter( + ([req]) => (req as { method: string }).method === "tasks/result", + ); + expect(resultCalls.length).toBeGreaterThanOrEqual(1); + }); + + // Final result should be displayed after completed + await waitFor(() => { + expect(screen.getByTestId("tool-result")).toHaveTextContent( + "final task result", + ); + }); + }); + + it("does not call tasks/result while status is working (non-input_required non-terminal)", async () => { + const taskId = "task-working-456"; + let tasksGetCallCount = 0; + + const makeRequest = jest.fn(async (request: { method: string }) => { + if (request.method === "tools/list") { + return { + tools: [ + { + name: "myTool", + inputSchema: { type: "object", properties: {} }, + }, + ], + nextCursor: undefined, + }; + } + + if (request.method === "tools/call") { + return { + task: { taskId, status: "working", pollInterval: 10 }, + }; + } + + if (request.method === "tasks/get") { + tasksGetCallCount++; + if (tasksGetCallCount < 3) { + return { taskId, status: "working" }; + } + return { taskId, status: "completed" }; + } + + if (request.method === "tasks/result") { + return { + content: [{ type: "text", text: "working tool result" }], + }; + } + + if (request.method === "tasks/list") { + return { tasks: [] }; + } + + throw new Error(`Unexpected method: ${request.method}`); + }); + + mockUseConnection.mockReturnValue({ + connectionStatus: "connected", + serverCapabilities: { tools: { listChanged: true } }, + serverImplementation: null, + mcpClient: { + request: jest.fn(), + notification: jest.fn(), + close: jest.fn(), + } as unknown as Client, + requestHistory: [], + clearRequestHistory: jest.fn(), + makeRequest, + cancelTask: jest.fn(), + listTasks: jest.fn(), + sendNotification: jest.fn(), + handleCompletion: jest.fn(), + completionsSupported: false, + connect: jest.fn(), + disconnect: jest.fn(), + } as ReturnType); + + render(); + + fireEvent.click(screen.getByRole("button", { name: /run task tool/i })); + + await waitFor(() => { + expect(screen.getByTestId("tool-result")).toHaveTextContent( + "working tool result", + ); + }); + + // tasks/result should only have been called once — for the completed status, not for working + const resultCalls = makeRequest.mock.calls.filter( + ([req]) => (req as { method: string }).method === "tasks/result", + ); + expect(resultCalls).toHaveLength(1); + }); +});