diff --git a/src/mcp-client/local.ts b/src/mcp-client/local.ts index a4b1992..0569440 100644 --- a/src/mcp-client/local.ts +++ b/src/mcp-client/local.ts @@ -3,29 +3,70 @@ import { StdioClientTransport } from "@modelcontextprotocol/sdk/client/stdio.js" import type { LocalMCPServerConfig } from "./types"; import type { MCPClient } from "./types"; +/** + * Transport-like interface for DI/testing + */ +export interface Transport { + close(): Promise; +} + +/** + * Client-like interface for DI/testing + */ +export interface LocalClientLike { + connect(transport: Transport): Promise; + listTools(): Promise<{ tools: any[] }>; + callTool(request: { name: string; arguments: Record }): Promise; +} + +/** + * Options for LocalMCPClient including DI seams for testing + */ +export interface LocalMCPClientOptions { + /** Override Client creation for testing */ + clientFactory?: (name: string) => LocalClientLike; + /** Override transport creation for testing */ + transportFactory?: (opts: { + command: string; + args: string[]; + env: Record; + stderr: "pipe" | "inherit" | "ignore"; + }) => Transport; +} + /** * Local MCP client using stdio transport */ export class LocalMCPClient implements MCPClient { - private client: Client; - private transport: StdioClientTransport | null; + private client: LocalClientLike; + private transport: Transport | null; private toolsCache: any[] | null; private name: string; private config: LocalMCPServerConfig; + private transportFactory: NonNullable; - constructor(config: { name: string } & LocalMCPServerConfig) { + constructor( + config: { name: string } & LocalMCPServerConfig, + options?: LocalMCPClientOptions + ) { this.transport = null; this.toolsCache = null; this.name = config.name; this.config = config; - this.client = new Client( - { - name: `opencode-toolbox-client-${this.name}`, - version: "0.1.0", - }, - {} + // Use provided factories or defaults + const clientFactory = options?.clientFactory ?? ((name: string) => + new Client( + { name: `opencode-toolbox-client-${name}`, version: "0.1.0" }, + {} + ) + ); + + this.transportFactory = options?.transportFactory ?? ((opts) => + new StdioClientTransport(opts) ); + + this.client = clientFactory(this.name); } async connect(): Promise { @@ -33,14 +74,14 @@ export class LocalMCPClient implements MCPClient { throw new Error(`Local MCP server ${this.name} has no command`); } - this.transport = new StdioClientTransport({ + this.transport = this.transportFactory({ command: this.config.command[0]!, args: this.config.command.slice(1), env: { ...(process.env as Record), ...this.config.environment, }, - stderr: "pipe", + stderr: "pipe" as const, }); await this.client.connect(this.transport); diff --git a/src/mcp-client/remote.ts b/src/mcp-client/remote.ts index 951ebc0..05d2b81 100644 --- a/src/mcp-client/remote.ts +++ b/src/mcp-client/remote.ts @@ -4,31 +4,65 @@ import { StreamableHTTPClientTransport } from "@modelcontextprotocol/sdk/client/ import type { RemoteMCPServerConfig } from "./types"; import type { MCPClient } from "./types"; -type RemoteTransport = SSEClientTransport | StreamableHTTPClientTransport; +/** + * Transport-like interface for DI/testing + */ +export interface RemoteTransport { + close(): Promise; +} + +/** + * Client-like interface for DI/testing + */ +export interface RemoteClientLike { + connect(transport: RemoteTransport): Promise; + listTools(): Promise<{ tools: any[] }>; + callTool(request: { name: string; arguments: Record }): Promise; +} + +/** + * Options for RemoteMCPClient including DI seams for testing + */ +export interface RemoteMCPClientOptions { + /** Override Client creation for testing */ + clientFactory?: (name: string) => RemoteClientLike; + /** Override StreamableHTTP transport creation for testing */ + streamableTransportFactory?: (url: URL, headers?: Record) => RemoteTransport; + /** Override SSE transport creation for testing */ + sseTransportFactory?: (url: URL, headers: Record) => RemoteTransport; +} /** * Remote MCP client with auto-detection * Tries Streamable HTTP first (newer), falls back to SSE (legacy) */ export class RemoteMCPClient implements MCPClient { - private client: Client; + private client: RemoteClientLike; private transport: RemoteTransport | null; private toolsCache: any[] | null; private name: string; private config: RemoteMCPServerConfig; private transportType: "streamable-http" | "sse" | null; + private options: RemoteMCPClientOptions; - constructor(config: { name: string } & RemoteMCPServerConfig) { + constructor( + config: { name: string } & RemoteMCPServerConfig, + options?: RemoteMCPClientOptions + ) { this.transport = null; this.toolsCache = null; this.name = config.name; this.config = config; this.transportType = null; + this.options = options ?? {}; this.client = this.createClient(); } - private createClient(): Client { + private createClient(): RemoteClientLike { + if (this.options.clientFactory) { + return this.options.clientFactory(this.name); + } return new Client( { name: `opencode-toolbox-client-${this.name}`, @@ -38,6 +72,28 @@ export class RemoteMCPClient implements MCPClient { ); } + private createStreamableTransport(url: URL): RemoteTransport { + if (this.options.streamableTransportFactory) { + return this.options.streamableTransportFactory(url, this.config.headers); + } + return new StreamableHTTPClientTransport(url, { + requestInit: { + headers: this.config.headers, + }, + }); + } + + private createSSETransport(url: URL, headers: Record): RemoteTransport { + if (this.options.sseTransportFactory) { + return this.options.sseTransportFactory(url, headers); + } + return new SSEClientTransport(url, { + requestInit: { + headers, + }, + }); + } + async connect(): Promise { if (!this.config.url) { throw new Error(`Remote MCP server ${this.name} has no URL`); @@ -47,13 +103,9 @@ export class RemoteMCPClient implements MCPClient { this.transportType = null; // Try Streamable HTTP first (newer protocol) - let streamableTransport: StreamableHTTPClientTransport | null = null; + let streamableTransport: RemoteTransport | null = null; try { - streamableTransport = new StreamableHTTPClientTransport(url, { - requestInit: { - headers: this.config.headers, - }, - }); + streamableTransport = this.createStreamableTransport(url); await this.client.connect(streamableTransport); this.transport = streamableTransport; @@ -71,18 +123,14 @@ export class RemoteMCPClient implements MCPClient { } // Fallback to SSE transport (legacy) - let sseTransport: SSEClientTransport | null = null; + let sseTransport: RemoteTransport | null = null; try { const sseHeaders = { Accept: "text/event-stream", ...this.config.headers, }; - sseTransport = new SSEClientTransport(url, { - requestInit: { - headers: sseHeaders, - }, - }); + sseTransport = this.createSSETransport(url, sseHeaders); await this.client.connect(sseTransport); this.transport = sseTransport; diff --git a/test/unit/bm25.test.ts b/test/unit/bm25.test.ts index 5b6bd82..ea8b7b6 100644 --- a/test/unit/bm25.test.ts +++ b/test/unit/bm25.test.ts @@ -170,3 +170,242 @@ test("BM25 search signature includes arguments", () => { expect(results.length).toBeGreaterThan(0); expect(results[0]?.signature).toBe("send_email(to, subject, body)"); }); + +// --- Incremental operations --- + +test("addTool() adds single tool incrementally", () => { + const index = new BM25Index(); + + const tool = createMockTool("gmail", "send_email", "Send an email"); + index.addTool(tool); + + expect(index.size).toBe(1); + expect(index.has("gmail_send_email")).toBe(true); + + // Search should find it + const results = index.search("email"); + expect(results.length).toBe(1); + expect(results[0]?.idString).toBe("gmail_send_email"); +}); + +test("addTool() skips duplicates", () => { + const index = new BM25Index(); + + const tool = createMockTool("gmail", "send_email", "Send an email"); + index.addTool(tool); + index.addTool(tool); // Add again + + expect(index.size).toBe(1); +}); + +test("addToolsBatch() adds multiple tools", () => { + const index = new BM25Index(); + + const tools = [ + createMockTool("gmail", "send_email", "Send an email"), + createMockTool("github", "create_pr", "Create PR"), + ]; + + index.addToolsBatch(tools); + + expect(index.size).toBe(2); +}); + +test("removeTool() removes existing tool", () => { + const index = new BM25Index(); + + const tool1 = createMockTool("gmail", "send_email", "Send an email"); + const tool2 = createMockTool("github", "create_pr", "Create PR"); + + index.addToolsBatch([tool1, tool2]); + expect(index.size).toBe(2); + + const removed = index.removeTool("gmail_send_email"); + expect(removed).toBe(true); + expect(index.size).toBe(1); + expect(index.has("gmail_send_email")).toBe(false); + expect(index.has("github_create_pr")).toBe(true); +}); + +test("removeTool() returns false for non-existent tool", () => { + const index = new BM25Index(); + + const removed = index.removeTool("nonexistent_tool"); + expect(removed).toBe(false); +}); + +test("removeTool() updates document frequencies correctly", () => { + const index = new BM25Index(); + + // Add two tools with overlapping terms + const tool1 = createMockTool("a", "email_tool", "Send email messages"); + const tool2 = createMockTool("b", "email_util", "Email utility"); + + index.addToolsBatch([tool1, tool2]); + + // Remove one - the shared term "email" should still be searchable + index.removeTool("a_email_tool"); + + const results = index.search("email"); + expect(results.length).toBe(1); + expect(results[0]?.idString).toBe("b_email_util"); +}); + +test("removeTool() handles term frequency going to zero", () => { + const index = new BM25Index(); + + const tool = createMockTool("unique", "special_tool", "Unique description"); + index.addTool(tool); + + index.removeTool("unique_special_tool"); + + // Search for unique term should return nothing + const results = index.search("unique special"); + expect(results.length).toBe(0); +}); + +test("has() returns true for indexed tool", () => { + const index = new BM25Index(); + + index.addTool(createMockTool("gmail", "send", "Send")); + + expect(index.has("gmail_send")).toBe(true); +}); + +test("has() returns false for non-indexed tool", () => { + const index = new BM25Index(); + + expect(index.has("nonexistent_tool")).toBe(false); +}); + +test("getStats() returns correct statistics", () => { + const index = new BM25Index(); + + const tools = [ + createMockTool("a", "tool1", "Description one"), + createMockTool("b", "tool2", "Description two words"), + createMockTool("c", "tool3", "Description three total words"), + ]; + + index.indexTools(tools); + + const stats = index.getStats(); + expect(stats.docCount).toBe(3); + expect(stats.termCount).toBeGreaterThan(0); + expect(stats.avgDocLength).toBeGreaterThan(0); +}); + +// --- Async operations --- + +test("indexToolsAsync() indexes tools asynchronously", async () => { + const index = new BM25Index(); + + const tools = [ + createMockTool("a", "tool1", "First tool"), + createMockTool("b", "tool2", "Second tool"), + createMockTool("c", "tool3", "Third tool"), + ]; + + await index.indexToolsAsync(tools, 1); // Small chunk size to force multiple yields + + expect(index.size).toBe(3); + + const results = index.search("tool"); + expect(results.length).toBe(3); +}); + +test("indexToolsAsync() clears existing index", async () => { + const index = new BM25Index(); + + // Add initial tools + index.addTool(createMockTool("old", "old_tool", "Old tool")); + expect(index.size).toBe(1); + + // Async index should clear and replace + const newTools = [ + createMockTool("new", "new_tool", "New tool"), + ]; + + await index.indexToolsAsync(newTools); + + expect(index.size).toBe(1); + expect(index.has("new_new_tool")).toBe(true); + expect(index.has("old_old_tool")).toBe(false); +}); + +test("addToolsAsync() adds tools incrementally", async () => { + const index = new BM25Index(); + + // Add first batch + index.addTool(createMockTool("a", "tool1", "First")); + + // Add more async + const moreTools = [ + createMockTool("b", "tool2", "Second"), + createMockTool("c", "tool3", "Third"), + ]; + + await index.addToolsAsync(moreTools, 1); + + expect(index.size).toBe(3); +}); + +test("addToolsAsync() yields between chunks", async () => { + const index = new BM25Index(); + + // Create many tools to force multiple chunks + const tools: CatalogTool[] = []; + for (let i = 0; i < 10; i++) { + tools.push(createMockTool(`server${i}`, `tool${i}`, `Description ${i}`)); + } + + // Use small chunk size + await index.addToolsAsync(tools, 3); + + expect(index.size).toBe(10); +}); + +test("async and sync indexing produce same results", async () => { + const tools = [ + createMockTool("a", "send_email", "Send email"), + createMockTool("b", "search_web", "Search the web"), + ]; + + const syncIndex = new BM25Index(); + syncIndex.indexTools(tools); + + const asyncIndex = new BM25Index(); + await asyncIndex.indexToolsAsync(tools); + + expect(asyncIndex.size).toBe(syncIndex.size); + + const syncResults = syncIndex.search("email", 5); + const asyncResults = asyncIndex.search("email", 5); + + expect(asyncResults.length).toBe(syncResults.length); + expect(asyncResults[0]?.idString).toBe(syncResults[0]?.idString); +}); + +test("incremental avg doc length calculation", () => { + const index = new BM25Index(); + + // Add tools one at a time + index.addTool(createMockTool("a", "short", "A")); + index.addTool(createMockTool("b", "medium", "A B C")); + index.addTool(createMockTool("c", "long", "A B C D E F G H")); + + const stats = index.getStats(); + expect(stats.avgDocLength).toBeGreaterThan(0); + expect(stats.docCount).toBe(3); +}); + +test("avg doc length is zero after clearing", () => { + const index = new BM25Index(); + + index.addTool(createMockTool("a", "tool", "Description")); + index.clear(); + + const stats = index.getStats(); + expect(stats.avgDocLength).toBe(0); + expect(stats.docCount).toBe(0); +}); diff --git a/test/unit/local-client.test.ts b/test/unit/local-client.test.ts new file mode 100644 index 0000000..9d2ae63 --- /dev/null +++ b/test/unit/local-client.test.ts @@ -0,0 +1,339 @@ +import { test, expect, describe } from "bun:test"; +import { LocalMCPClient, type Transport, type LocalMCPClientOptions } from "../../src/mcp-client/local"; + +/** + * Create a mock client factory for testing + */ +function createMockClientFactory(options?: { + failConnect?: boolean; + failListTools?: boolean; + tools?: any[]; + callToolResult?: any; +}) { + return (name: string) => { + let connected = false; + return { + async connect(transport: Transport): Promise { + if (options?.failConnect) { + throw new Error("Connection failed"); + } + connected = true; + }, + async listTools(): Promise<{ tools: any[] }> { + if (options?.failListTools) { + throw new Error("List tools failed"); + } + return { tools: options?.tools ?? [] }; + }, + async callTool(request: { name: string; arguments: Record }): Promise { + return options?.callToolResult ?? { content: [{ type: "text", text: "ok" }] }; + }, + }; + }; +} + +/** + * Create a mock transport factory for testing + */ +function createMockTransportFactory(options?: { + failClose?: boolean; + onClose?: () => void; +}) { + return (opts: { + command: string; + args: string[]; + env: Record; + stderr: "pipe" | "inherit" | "ignore"; + }): Transport => { + return { + async close(): Promise { + if (options?.failClose) { + throw new Error("Close failed"); + } + options?.onClose?.(); + }, + }; + }; +} + +describe("LocalMCPClient", () => { + describe("constructor", () => { + test("creates client with default factories", () => { + // This should not throw + const client = new LocalMCPClient({ + name: "test", + type: "local", + command: ["echo", "hello"], + }); + expect(client).toBeDefined(); + }); + + test("accepts custom client factory", () => { + const mockFactory = createMockClientFactory(); + const client = new LocalMCPClient( + { name: "test", type: "local", command: ["echo"] }, + { clientFactory: mockFactory } + ); + expect(client).toBeDefined(); + }); + + test("accepts custom transport factory", () => { + const mockTransport = createMockTransportFactory(); + const client = new LocalMCPClient( + { name: "test", type: "local", command: ["echo"] }, + { transportFactory: mockTransport } + ); + expect(client).toBeDefined(); + }); + }); + + describe("connect", () => { + test("throws when command is empty", async () => { + const client = new LocalMCPClient( + { name: "test", type: "local", command: [] }, + { clientFactory: createMockClientFactory() } + ); + + await expect(client.connect()).rejects.toThrow("has no command"); + }); + + test("throws when command is undefined", async () => { + const client = new LocalMCPClient( + { name: "test", type: "local" }, + { clientFactory: createMockClientFactory() } + ); + + await expect(client.connect()).rejects.toThrow("has no command"); + }); + + test("connects successfully with valid command", async () => { + const client = new LocalMCPClient( + { name: "test", type: "local", command: ["node", "server.js"] }, + { + clientFactory: createMockClientFactory(), + transportFactory: createMockTransportFactory(), + } + ); + + await client.connect(); + // Should not throw + }); + + test("creates transport with correct config", async () => { + let capturedOpts: any = null; + const client = new LocalMCPClient( + { + name: "test", + type: "local", + command: ["node", "server.js", "--port", "3000"], + environment: { MY_VAR: "value" }, + }, + { + clientFactory: createMockClientFactory(), + transportFactory: (opts) => { + capturedOpts = opts; + return { close: async () => {} }; + }, + } + ); + + await client.connect(); + + expect(capturedOpts).not.toBeNull(); + expect(capturedOpts.command).toBe("node"); + expect(capturedOpts.args).toEqual(["server.js", "--port", "3000"]); + expect(capturedOpts.env.MY_VAR).toBe("value"); + expect(capturedOpts.stderr).toBe("pipe"); + }); + + test("merges environment with process.env", async () => { + let capturedEnv: Record = {}; + const client = new LocalMCPClient( + { + name: "test", + type: "local", + command: ["node"], + environment: { CUSTOM: "value" }, + }, + { + clientFactory: createMockClientFactory(), + transportFactory: (opts) => { + capturedEnv = opts.env; + return { close: async () => {} }; + }, + } + ); + + await client.connect(); + + // Should include process.env PATH (or similar) + expect(capturedEnv.CUSTOM).toBe("value"); + // Should include at least some process.env vars + expect(Object.keys(capturedEnv).length).toBeGreaterThan(1); + }); + + test("propagates connection errors", async () => { + const client = new LocalMCPClient( + { name: "test", type: "local", command: ["node"] }, + { + clientFactory: createMockClientFactory({ failConnect: true }), + transportFactory: createMockTransportFactory(), + } + ); + + await expect(client.connect()).rejects.toThrow("Connection failed"); + }); + }); + + describe("listTools", () => { + test("returns tools from client", async () => { + const mockTools = [ + { name: "tool1", description: "Tool 1" }, + { name: "tool2", description: "Tool 2" }, + ]; + + const client = new LocalMCPClient( + { name: "test", type: "local", command: ["node"] }, + { + clientFactory: createMockClientFactory({ tools: mockTools }), + transportFactory: createMockTransportFactory(), + } + ); + + await client.connect(); + const tools = await client.listTools(); + + expect(tools).toEqual(mockTools); + }); + + test("caches tools after first call", async () => { + const mockTools = [{ name: "tool1" }]; + const client = new LocalMCPClient( + { name: "test", type: "local", command: ["node"] }, + { + clientFactory: createMockClientFactory({ tools: mockTools }), + transportFactory: createMockTransportFactory(), + } + ); + + await client.connect(); + await client.listTools(); + + const cached = client.getCachedTools(); + expect(cached).toEqual(mockTools); + }); + + test("propagates listTools errors", async () => { + const client = new LocalMCPClient( + { name: "test", type: "local", command: ["node"] }, + { + clientFactory: createMockClientFactory({ failListTools: true }), + transportFactory: createMockTransportFactory(), + } + ); + + await client.connect(); + await expect(client.listTools()).rejects.toThrow("List tools failed"); + }); + }); + + describe("callTool", () => { + test("forwards call to underlying client", async () => { + const expectedResult = { content: [{ type: "text", text: "result" }] }; + const client = new LocalMCPClient( + { name: "test", type: "local", command: ["node"] }, + { + clientFactory: createMockClientFactory({ callToolResult: expectedResult }), + transportFactory: createMockTransportFactory(), + } + ); + + await client.connect(); + const result = await client.callTool("test_tool", { arg1: "value" }); + + expect(result).toEqual(expectedResult); + }); + }); + + describe("close", () => { + test("closes transport and clears cache", async () => { + let transportClosed = false; + const client = new LocalMCPClient( + { name: "test", type: "local", command: ["node"] }, + { + clientFactory: createMockClientFactory({ tools: [{ name: "t1" }] }), + transportFactory: createMockTransportFactory({ + onClose: () => { + transportClosed = true; + }, + }), + } + ); + + await client.connect(); + await client.listTools(); + expect(client.getCachedTools()).not.toBeNull(); + + await client.close(); + + expect(transportClosed).toBe(true); + expect(client.getCachedTools()).toBeNull(); + }); + + test("is safe to call multiple times", async () => { + const client = new LocalMCPClient( + { name: "test", type: "local", command: ["node"] }, + { + clientFactory: createMockClientFactory(), + transportFactory: createMockTransportFactory(), + } + ); + + await client.connect(); + await client.close(); + await client.close(); // Should not throw + }); + + test("is safe to call without connect", async () => { + const client = new LocalMCPClient( + { name: "test", type: "local", command: ["node"] }, + { + clientFactory: createMockClientFactory(), + transportFactory: createMockTransportFactory(), + } + ); + + await client.close(); // Should not throw + }); + }); + + describe("getCachedTools", () => { + test("returns null before listTools is called", async () => { + const client = new LocalMCPClient( + { name: "test", type: "local", command: ["node"] }, + { + clientFactory: createMockClientFactory(), + transportFactory: createMockTransportFactory(), + } + ); + + expect(client.getCachedTools()).toBeNull(); + }); + + test("returns tools after listTools is called", async () => { + const mockTools = [{ name: "tool1" }]; + const client = new LocalMCPClient( + { name: "test", type: "local", command: ["node"] }, + { + clientFactory: createMockClientFactory({ tools: mockTools }), + transportFactory: createMockTransportFactory(), + } + ); + + await client.connect(); + await client.listTools(); + + expect(client.getCachedTools()).toEqual(mockTools); + }); + }); +}); diff --git a/test/unit/profiler.test.ts b/test/unit/profiler.test.ts new file mode 100644 index 0000000..5c6f319 --- /dev/null +++ b/test/unit/profiler.test.ts @@ -0,0 +1,279 @@ +import { test, expect, describe, beforeEach } from "bun:test"; +import { Profiler, globalProfiler } from "../../src/profiler"; + +describe("Profiler", () => { + let profiler: Profiler; + + beforeEach(() => { + profiler = new Profiler(); + }); + + describe("mark and measure", () => { + test("mark() stores timestamp", () => { + profiler.mark("test-mark"); + // Measure should work after mark + const duration = profiler.measure("test-measure", "test-mark"); + expect(duration).toBeGreaterThanOrEqual(0); + }); + + test("measure() returns -1 for missing mark", () => { + const duration = profiler.measure("test-measure", "nonexistent"); + expect(duration).toBe(-1); + }); + + test("measure() calculates duration correctly", async () => { + profiler.mark("start"); + await new Promise((resolve) => setTimeout(resolve, 10)); + const duration = profiler.measure("elapsed", "start"); + expect(duration).toBeGreaterThanOrEqual(9); + expect(duration).toBeLessThan(100); + }); + + test("measure() stores measurement for later retrieval", () => { + profiler.mark("m1"); + profiler.measure("duration", "m1"); + const stats = profiler.getStats("duration"); + expect(stats).not.toBeNull(); + expect(stats!.count).toBe(1); + }); + }); + + describe("record", () => { + test("record() adds direct measurement", () => { + profiler.record("direct", 100); + profiler.record("direct", 200); + const stats = profiler.getStats("direct"); + expect(stats).not.toBeNull(); + expect(stats!.count).toBe(2); + expect(stats!.min).toBe(100); + expect(stats!.max).toBe(200); + }); + }); + + describe("getStats", () => { + test("getStats() returns null for unknown metric", () => { + const stats = profiler.getStats("unknown"); + expect(stats).toBeNull(); + }); + + test("getStats() calculates min/max/avg correctly", () => { + profiler.record("test", 10); + profiler.record("test", 20); + profiler.record("test", 30); + const stats = profiler.getStats("test"); + expect(stats).not.toBeNull(); + expect(stats!.min).toBe(10); + expect(stats!.max).toBe(30); + expect(stats!.avg).toBe(20); + expect(stats!.total).toBe(60); + }); + + test("getStats() calculates percentiles", () => { + // Add 100 values from 1 to 100 + for (let i = 1; i <= 100; i++) { + profiler.record("perc", i); + } + const stats = profiler.getStats("perc"); + expect(stats).not.toBeNull(); + expect(stats!.p50).toBe(50); + expect(stats!.p95).toBe(95); + expect(stats!.p99).toBe(99); + }); + + test("percentile handles empty array", () => { + // This is covered by the calculateStats null return for empty + const stats = profiler.getStats("empty"); + expect(stats).toBeNull(); + }); + + test("percentile handles single element", () => { + profiler.record("single", 42); + const stats = profiler.getStats("single"); + expect(stats).not.toBeNull(); + expect(stats!.p50).toBe(42); + expect(stats!.p95).toBe(42); + expect(stats!.p99).toBe(42); + }); + }); + + describe("initialization tracking", () => { + test("initStart() sets state to initializing", () => { + expect(profiler.getInitState()).toBe("idle"); + profiler.initStart(); + expect(profiler.getInitState()).toBe("initializing"); + }); + + test("initComplete() sets final state", () => { + profiler.initStart(); + profiler.initComplete("ready"); + expect(profiler.getInitState()).toBe("ready"); + }); + + test("initComplete() with degraded state", () => { + profiler.initStart(); + profiler.initComplete("degraded"); + expect(profiler.getInitState()).toBe("degraded"); + }); + + test("initComplete() with partial state", () => { + profiler.initStart(); + profiler.initComplete("partial"); + expect(profiler.getInitState()).toBe("partial"); + }); + + test("getInitDuration() returns null before init", () => { + expect(profiler.getInitDuration()).toBeNull(); + }); + + test("getInitDuration() returns ongoing duration during init", async () => { + profiler.initStart(); + await new Promise((resolve) => setTimeout(resolve, 10)); + const duration = profiler.getInitDuration(); + expect(duration).not.toBeNull(); + expect(duration!).toBeGreaterThanOrEqual(9); + }); + + test("getInitDuration() returns final duration after init", async () => { + profiler.initStart(); + await new Promise((resolve) => setTimeout(resolve, 10)); + profiler.initComplete("ready"); + const duration = profiler.getInitDuration(); + expect(duration).not.toBeNull(); + expect(duration!).toBeGreaterThanOrEqual(9); + }); + }); + + describe("server metrics", () => { + test("recordServerConnect stores connected metrics", () => { + profiler.recordServerConnect("server1", 100, 5, "connected"); + const report = profiler.export(); + expect(report.initialization.servers).toHaveLength(1); + expect(report.initialization.servers[0]).toEqual({ + name: "server1", + connectTime: 100, + toolCount: 5, + status: "connected", + error: undefined, + }); + }); + + test("recordServerConnect stores error metrics", () => { + profiler.recordServerConnect("server2", -1, 0, "error", "Connection failed"); + const report = profiler.export(); + const server = report.initialization.servers.find((s) => s.name === "server2"); + expect(server).toBeDefined(); + expect(server!.status).toBe("error"); + expect(server!.error).toBe("Connection failed"); + }); + }); + + describe("indexing metrics", () => { + test("recordIndexBuild stores build time and count", () => { + profiler.recordIndexBuild(50, 100); + const report = profiler.export(); + expect(report.indexing.buildTime).toBe(50); + expect(report.indexing.toolCount).toBe(100); + }); + + test("recordIncrementalUpdate increments counter and tool count", () => { + profiler.recordIncrementalUpdate(10); + profiler.recordIncrementalUpdate(5); + const report = profiler.export(); + expect(report.indexing.incrementalUpdates).toBe(2); + expect(report.indexing.toolCount).toBe(15); + }); + }); + + describe("export", () => { + test("export() returns full report structure", () => { + profiler.initStart(); + profiler.recordServerConnect("test", 100, 3, "connected"); + profiler.recordIndexBuild(20, 10); + profiler.record("search.bm25", 5); + profiler.record("search.regex", 3); + profiler.record("tool.execute", 50); + profiler.initComplete("ready"); + + const report = profiler.export(); + + expect(report.timestamp).toBeDefined(); + expect(typeof report.uptime).toBe("number"); + expect(report.initialization.state).toBe("ready"); + expect(report.initialization.duration).not.toBeNull(); + expect(report.indexing.buildTime).toBe(20); + expect(report.searches.bm25).not.toBeNull(); + expect(report.searches.regex).not.toBeNull(); + expect(report.executions).not.toBeNull(); + }); + + test("export() handles no searches/executions", () => { + const report = profiler.export(); + expect(report.searches.bm25).toBeNull(); + expect(report.searches.regex).toBeNull(); + expect(report.executions).toBeNull(); + }); + }); + + describe("reset", () => { + test("reset() clears all state", () => { + profiler.mark("test"); + profiler.record("metric", 100); + profiler.recordServerConnect("srv", 50, 2, "connected"); + profiler.initStart(); + profiler.initComplete("ready"); + profiler.recordIndexBuild(10, 5); + profiler.recordIncrementalUpdate(3); + + profiler.reset(); + + expect(profiler.getInitState()).toBe("idle"); + expect(profiler.getInitDuration()).toBeNull(); + expect(profiler.getStats("metric")).toBeNull(); + + const report = profiler.export(); + expect(report.initialization.servers).toHaveLength(0); + expect(report.indexing.buildTime).toBeNull(); + expect(report.indexing.toolCount).toBe(0); + expect(report.indexing.incrementalUpdates).toBe(0); + }); + }); + + describe("startTimer", () => { + test("startTimer() returns function that records duration", async () => { + const done = profiler.startTimer("timed-op"); + await new Promise((resolve) => setTimeout(resolve, 10)); + const duration = done(); + + expect(duration).toBeGreaterThanOrEqual(9); + + const stats = profiler.getStats("timed-op"); + expect(stats).not.toBeNull(); + expect(stats!.count).toBe(1); + expect(stats!.min).toBeGreaterThanOrEqual(9); + }); + + test("startTimer() can be called multiple times", () => { + const done1 = profiler.startTimer("multi"); + const done2 = profiler.startTimer("multi"); + done1(); + done2(); + + const stats = profiler.getStats("multi"); + expect(stats!.count).toBe(2); + }); + }); +}); + +describe("globalProfiler", () => { + test("globalProfiler is a Profiler instance", () => { + expect(globalProfiler).toBeInstanceOf(Profiler); + }); + + test("globalProfiler can record metrics", () => { + // Just verify it works - don't pollute global state too much + const timer = globalProfiler.startTimer("global-test"); + timer(); + // If we get here without error, it works + expect(true).toBe(true); + }); +}); diff --git a/test/unit/remote-client.test.ts b/test/unit/remote-client.test.ts new file mode 100644 index 0000000..a2e965c --- /dev/null +++ b/test/unit/remote-client.test.ts @@ -0,0 +1,484 @@ +import { test, expect, describe } from "bun:test"; +import { + RemoteMCPClient, + type RemoteTransport, + type RemoteClientLike, + type RemoteMCPClientOptions, +} from "../../src/mcp-client/remote"; + +/** + * Create a mock client factory for testing + */ +function createMockClientFactory(options?: { + failConnect?: boolean; + failListTools?: boolean; + tools?: any[]; + callToolResult?: any; +}): (name: string) => RemoteClientLike { + return (name: string) => ({ + async connect(transport: RemoteTransport): Promise { + if (options?.failConnect) { + throw new Error("Connection failed"); + } + }, + async listTools(): Promise<{ tools: any[] }> { + if (options?.failListTools) { + throw new Error("List tools failed"); + } + return { tools: options?.tools ?? [] }; + }, + async callTool(request: { name: string; arguments: Record }): Promise { + return options?.callToolResult ?? { content: [{ type: "text", text: "ok" }] }; + }, + }); +} + +/** + * Create a mock transport factory + */ +function createMockTransportFactory(options?: { + failConnect?: boolean; + onClose?: () => void; +}): () => RemoteTransport { + return () => ({ + async close(): Promise { + options?.onClose?.(); + }, + }); +} + +describe("RemoteMCPClient", () => { + describe("constructor", () => { + test("creates client with default factories", () => { + const client = new RemoteMCPClient({ + name: "test", + type: "remote", + url: "https://example.com/mcp", + }); + expect(client).toBeDefined(); + }); + + test("accepts custom client factory", () => { + const client = new RemoteMCPClient( + { name: "test", type: "remote", url: "https://example.com" }, + { clientFactory: createMockClientFactory() } + ); + expect(client).toBeDefined(); + }); + }); + + describe("connect", () => { + test("throws when URL is missing", async () => { + const client = new RemoteMCPClient( + { name: "test", type: "remote" }, + { clientFactory: createMockClientFactory() } + ); + + await expect(client.connect()).rejects.toThrow("has no URL"); + }); + + test("connects successfully with streamable HTTP", async () => { + const client = new RemoteMCPClient( + { name: "test", type: "remote", url: "https://example.com/mcp" }, + { + clientFactory: createMockClientFactory(), + streamableTransportFactory: createMockTransportFactory(), + } + ); + + await client.connect(); + expect(client.getTransportType()).toBe("streamable-http"); + }); + + test("falls back to SSE when streamable fails", async () => { + let streamableAttempts = 0; + let sseAttempts = 0; + + const client = new RemoteMCPClient( + { name: "test", type: "remote", url: "https://example.com/mcp" }, + { + clientFactory: (name) => ({ + async connect(transport: RemoteTransport): Promise { + // Fail on first attempt (streamable), succeed on second (SSE) + if (streamableAttempts === 0 && sseAttempts === 0) { + streamableAttempts++; + throw new Error("Streamable not supported"); + } + sseAttempts++; + }, + async listTools() { + return { tools: [] }; + }, + async callTool() { + return {}; + }, + }), + streamableTransportFactory: () => ({ + async close() {}, + }), + sseTransportFactory: () => ({ + async close() {}, + }), + } + ); + + await client.connect(); + + expect(streamableAttempts).toBe(1); + expect(sseAttempts).toBe(1); + expect(client.getTransportType()).toBe("sse"); + }); + + test("throws when both transports fail", async () => { + const client = new RemoteMCPClient( + { name: "test", type: "remote", url: "https://example.com/mcp" }, + { + clientFactory: createMockClientFactory({ failConnect: true }), + streamableTransportFactory: createMockTransportFactory(), + sseTransportFactory: createMockTransportFactory(), + } + ); + + await expect(client.connect()).rejects.toThrow("Connection failed"); + expect(client.getTransportType()).toBeNull(); + }); + + test("passes headers to streamable transport", async () => { + let capturedUrl: URL | null = null; + let capturedHeaders: Record | undefined; + + const client = new RemoteMCPClient( + { + name: "test", + type: "remote", + url: "https://example.com/mcp", + headers: { Authorization: "Bearer token123" }, + }, + { + clientFactory: createMockClientFactory(), + streamableTransportFactory: (url, headers) => { + capturedUrl = url; + capturedHeaders = headers; + return { close: async () => {} }; + }, + } + ); + + await client.connect(); + + expect(capturedUrl?.href).toBe("https://example.com/mcp"); + expect(capturedHeaders?.Authorization).toBe("Bearer token123"); + }); + + test("passes headers to SSE transport with Accept header", async () => { + let capturedHeaders: Record | null = null; + + const client = new RemoteMCPClient( + { + name: "test", + type: "remote", + url: "https://example.com/mcp", + headers: { "X-Custom": "value" }, + }, + { + clientFactory: (name) => ({ + async connect(transport: RemoteTransport): Promise { + // Fail first (streamable), succeed second (SSE) + if (capturedHeaders === null) { + throw new Error("Streamable failed"); + } + }, + async listTools() { + return { tools: [] }; + }, + async callTool() { + return {}; + }, + }), + streamableTransportFactory: () => ({ + async close() {}, + }), + sseTransportFactory: (url, headers) => { + capturedHeaders = headers; + return { close: async () => {} }; + }, + } + ); + + await client.connect(); + + expect(capturedHeaders).not.toBeNull(); + expect(capturedHeaders!.Accept).toBe("text/event-stream"); + expect(capturedHeaders!["X-Custom"]).toBe("value"); + }); + + test("cleans up streamable transport on failure", async () => { + let streamableClosed = false; + + const client = new RemoteMCPClient( + { name: "test", type: "remote", url: "https://example.com/mcp" }, + { + clientFactory: createMockClientFactory({ failConnect: true }), + streamableTransportFactory: () => ({ + async close() { + streamableClosed = true; + }, + }), + sseTransportFactory: createMockTransportFactory(), + } + ); + + await expect(client.connect()).rejects.toThrow(); + expect(streamableClosed).toBe(true); + }); + + test("cleans up SSE transport on failure", async () => { + let sseClosed = false; + + const client = new RemoteMCPClient( + { name: "test", type: "remote", url: "https://example.com/mcp" }, + { + clientFactory: createMockClientFactory({ failConnect: true }), + streamableTransportFactory: createMockTransportFactory(), + sseTransportFactory: () => ({ + async close() { + sseClosed = true; + }, + }), + } + ); + + await expect(client.connect()).rejects.toThrow(); + expect(sseClosed).toBe(true); + }); + + test("resets client between transport attempts", async () => { + let clientCreations = 0; + + const client = new RemoteMCPClient( + { name: "test", type: "remote", url: "https://example.com/mcp" }, + { + clientFactory: (name) => { + clientCreations++; + return { + async connect(transport: RemoteTransport): Promise { + // Fail on first attempt + if (clientCreations === 1) { + throw new Error("Streamable failed"); + } + }, + async listTools() { + return { tools: [] }; + }, + async callTool() { + return {}; + }, + }; + }, + streamableTransportFactory: createMockTransportFactory(), + sseTransportFactory: createMockTransportFactory(), + } + ); + + await client.connect(); + + // Should have created client twice: initial + after streamable failure + expect(clientCreations).toBe(2); + }); + }); + + describe("listTools", () => { + test("returns tools from client", async () => { + const mockTools = [{ name: "tool1" }, { name: "tool2" }]; + + const client = new RemoteMCPClient( + { name: "test", type: "remote", url: "https://example.com/mcp" }, + { + clientFactory: createMockClientFactory({ tools: mockTools }), + streamableTransportFactory: createMockTransportFactory(), + } + ); + + await client.connect(); + const tools = await client.listTools(); + + expect(tools).toEqual(mockTools); + }); + + test("caches tools", async () => { + const mockTools = [{ name: "tool1" }]; + + const client = new RemoteMCPClient( + { name: "test", type: "remote", url: "https://example.com/mcp" }, + { + clientFactory: createMockClientFactory({ tools: mockTools }), + streamableTransportFactory: createMockTransportFactory(), + } + ); + + await client.connect(); + await client.listTools(); + + expect(client.getCachedTools()).toEqual(mockTools); + }); + }); + + describe("callTool", () => { + test("forwards call to client", async () => { + const expectedResult = { result: "success" }; + + const client = new RemoteMCPClient( + { name: "test", type: "remote", url: "https://example.com/mcp" }, + { + clientFactory: createMockClientFactory({ callToolResult: expectedResult }), + streamableTransportFactory: createMockTransportFactory(), + } + ); + + await client.connect(); + const result = await client.callTool("my_tool", { arg: "value" }); + + expect(result).toEqual(expectedResult); + }); + }); + + describe("close", () => { + test("closes transport and clears state", async () => { + let transportClosed = false; + + const client = new RemoteMCPClient( + { name: "test", type: "remote", url: "https://example.com/mcp" }, + { + clientFactory: createMockClientFactory({ tools: [{ name: "t1" }] }), + streamableTransportFactory: () => ({ + async close() { + transportClosed = true; + }, + }), + } + ); + + await client.connect(); + await client.listTools(); + expect(client.getCachedTools()).not.toBeNull(); + expect(client.getTransportType()).toBe("streamable-http"); + + await client.close(); + + expect(transportClosed).toBe(true); + expect(client.getCachedTools()).toBeNull(); + expect(client.getTransportType()).toBeNull(); + }); + + test("is safe to call multiple times", async () => { + const client = new RemoteMCPClient( + { name: "test", type: "remote", url: "https://example.com/mcp" }, + { + clientFactory: createMockClientFactory(), + streamableTransportFactory: createMockTransportFactory(), + } + ); + + await client.connect(); + await client.close(); + await client.close(); // Should not throw + }); + + test("is safe to call without connect", async () => { + const client = new RemoteMCPClient( + { name: "test", type: "remote", url: "https://example.com/mcp" }, + { + clientFactory: createMockClientFactory(), + streamableTransportFactory: createMockTransportFactory(), + } + ); + + await client.close(); // Should not throw + }); + }); + + describe("getTransportType", () => { + test("returns null before connect", () => { + const client = new RemoteMCPClient( + { name: "test", type: "remote", url: "https://example.com/mcp" }, + { clientFactory: createMockClientFactory() } + ); + + expect(client.getTransportType()).toBeNull(); + }); + + test("returns 'streamable-http' on success", async () => { + const client = new RemoteMCPClient( + { name: "test", type: "remote", url: "https://example.com/mcp" }, + { + clientFactory: createMockClientFactory(), + streamableTransportFactory: createMockTransportFactory(), + } + ); + + await client.connect(); + expect(client.getTransportType()).toBe("streamable-http"); + }); + + test("returns 'sse' on fallback", async () => { + let isFirstAttempt = true; + + const client = new RemoteMCPClient( + { name: "test", type: "remote", url: "https://example.com/mcp" }, + { + clientFactory: (name) => ({ + async connect(transport: RemoteTransport): Promise { + if (isFirstAttempt) { + isFirstAttempt = false; + throw new Error("Streamable failed"); + } + }, + async listTools() { + return { tools: [] }; + }, + async callTool() { + return {}; + }, + }), + streamableTransportFactory: createMockTransportFactory(), + sseTransportFactory: createMockTransportFactory(), + } + ); + + await client.connect(); + expect(client.getTransportType()).toBe("sse"); + }); + }); + + describe("getCachedTools", () => { + test("returns null before listTools", async () => { + const client = new RemoteMCPClient( + { name: "test", type: "remote", url: "https://example.com/mcp" }, + { + clientFactory: createMockClientFactory(), + streamableTransportFactory: createMockTransportFactory(), + } + ); + + await client.connect(); + expect(client.getCachedTools()).toBeNull(); + }); + + test("returns tools after listTools", async () => { + const mockTools = [{ name: "cached_tool" }]; + + const client = new RemoteMCPClient( + { name: "test", type: "remote", url: "https://example.com/mcp" }, + { + clientFactory: createMockClientFactory({ tools: mockTools }), + streamableTransportFactory: createMockTransportFactory(), + } + ); + + await client.connect(); + await client.listTools(); + + expect(client.getCachedTools()).toEqual(mockTools); + }); + }); +});