diff --git a/knip.json b/knip.json index b777d8c2a..036404ee4 100644 --- a/knip.json +++ b/knip.json @@ -7,7 +7,6 @@ "docs" ], "workspaces": { - "packages/appkit": {}, "packages/appkit-ui": { "ignoreDependencies": ["tailwindcss", "tw-animate-css"] } @@ -17,6 +16,11 @@ "**/*.example.tsx", "**/*.css", "packages/appkit/src/plugins/vector-search/**", + "packages/appkit/src/plugin/index.ts", + "packages/appkit/src/plugins/agents/index.ts", + "packages/appkit/src/plugins/agents/tools/index.ts", + "packages/appkit/src/plugins/agents/from-plugin.ts", + "packages/appkit/src/plugins/agents/load-agents.ts", "template/**", "tools/**", "docs/**" diff --git a/packages/appkit/src/agents/databricks.ts b/packages/appkit/src/agents/databricks.ts new file mode 100644 index 000000000..3b902c2ef --- /dev/null +++ b/packages/appkit/src/agents/databricks.ts @@ -0,0 +1,799 @@ +import type { + AgentAdapter, + AgentEvent, + AgentInput, + AgentRunContext, + AgentToolDefinition, +} from "shared"; +import { stream as servingStream } from "../connectors/serving/client"; + +/** Default cap for a single incomplete SSE line tail (DoS guard). */ +const DEFAULT_MAX_SSE_LINE_CHARS = 1024 * 1024; + +/** Default cap for accumulated assistant text from `delta.content`. */ +const DEFAULT_MAX_STREAM_TEXT_CHARS = 4 * 1024 * 1024; + +/** Default cap for accumulated JSON arguments per streamed tool call index. */ +const DEFAULT_MAX_TOOL_ARGUMENT_CHARS = 2 * 1024 * 1024; + +/** Cap text length before running Python-style tool-call regex (ReDoS guard). */ +const PYTHON_STYLE_TOOL_PARSE_MAX_INPUT = 64 * 1024; + +/** Fallback HTTP timeout when the raw fetch adapter path receives no AbortSignal from the runner. 
*/ +const RAW_FETCH_DEFAULT_TIMEOUT_MS = 120_000; + +function isRecord(value: unknown): value is Record { + return typeof value === "object" && value !== null; +} + +function extractLlamaToolJsonSlice(text: string): string | undefined { + const start = text.indexOf("[{"); + if (start < 0) return undefined; + const endBracket = text.lastIndexOf("}]"); + if (endBracket < start) return undefined; + return text.slice(start, endBracket + 2); +} + +/** OpenAI SSE payload: `{ choices: [{ delta }] }`. */ +function openAiChoicesDelta(parsed: unknown): unknown { + if (!isRecord(parsed)) return undefined; + const choices = parsed.choices; + if (!Array.isArray(choices) || choices.length < 1) return undefined; + const first = choices[0]; + if (!isRecord(first)) return undefined; + return first.delta; +} + +function isStreamingDeltaToolCall(value: unknown): value is DeltaToolCall { + if (!isRecord(value)) return false; + return typeof value.index === "number"; +} + +function throwIfExceedsStreamLimit( + label: string, + currentLength: number, + chunk: string, + max: number, +): void { + if (currentLength + chunk.length > max) { + throw new Error( + `DatabricksAdapter: ${label} exceeds configured limit (${max} UTF-16 code units)`, + ); + } +} + +/** + * Transport shim: given an OpenAI-compatible request body, returns the raw + * SSE byte stream from the serving endpoint. Injected at construction time so + * callers can swap in the workspace SDK (factory paths), a bare `fetch` + * (the raw constructor), or a test fake. + */ +type StreamBody = ( + body: Record, + signal?: AbortSignal, +) => Promise>; + +/** + * Escape-hatch options: provide an `endpointUrl` + `authenticate()` and the + * adapter uses a bare `fetch()` to call it. Useful for tests and for pointing + * the adapter at non-workspace endpoints (reverse proxies, mocks). 
+ */ +interface RawFetchAdapterOptions { + endpointUrl: string; + authenticate: () => Promise>; + maxSteps?: number; + maxTokens?: number; + /** Max length of one SSE line (including an incomplete tail in the buffer). */ + maxSseLineChars?: number; + /** Max total length of assistant `delta.content` across the stream. */ + maxStreamTextChars?: number; + /** Max length of streamed `function.arguments` per tool call index. */ + maxToolArgumentsChars?: number; +} + +/** + * Preferred options: caller provides the transport function directly. + * The `fromServingEndpoint` / `fromModelServing` factories use this to route + * through `connectors/serving/stream`, which centralises URL encoding, auth + * via the SDK's `apiClient.request`, and any future retries/telemetry. + */ +interface StreamBodyAdapterOptions { + streamBody: StreamBody; + maxSteps?: number; + maxTokens?: number; + maxSseLineChars?: number; + maxStreamTextChars?: number; + maxToolArgumentsChars?: number; +} + +type DatabricksAdapterOptions = + | RawFetchAdapterOptions + | StreamBodyAdapterOptions; + +function isStreamBodyOptions( + o: DatabricksAdapterOptions, +): o is StreamBodyAdapterOptions { + return "streamBody" in o; +} + +/** + * Duck-typed subset of the Databricks SDK `WorkspaceClient`. Callers of + * `fromServingEndpoint` and `fromModelServing` pass a real `WorkspaceClient`, + * but we only need the `apiClient.request` surface — so we declare the minimal + * interface rather than importing the SDK type directly. This keeps the adapter + * free of a hard compile-time dependency on `@databricks/sdk-experimental`. 
+ */ +interface WorkspaceClientLike { + apiClient: { + request(options: Record): Promise; + }; +} + +interface ServingEndpointOptions { + workspaceClient: WorkspaceClientLike; + endpointName: string; + maxSteps?: number; + maxTokens?: number; + maxSseLineChars?: number; + maxStreamTextChars?: number; + maxToolArgumentsChars?: number; +} + +interface ModelServingOptions { + maxSteps?: number; + maxTokens?: number; + workspaceClient?: WorkspaceClientLike; + maxSseLineChars?: number; + maxStreamTextChars?: number; + maxToolArgumentsChars?: number; +} + +interface OpenAIMessage { + role: "system" | "user" | "assistant" | "tool"; + content: string | null; + tool_calls?: OpenAIToolCall[]; + tool_call_id?: string; +} + +interface OpenAIToolCall { + id: string; + type: "function"; + function: { name: string; arguments: string }; +} + +interface OpenAITool { + type: "function"; + function: { + name: string; + description: string; + parameters: unknown; + }; +} + +interface DeltaToolCall { + index: number; + id?: string; + type?: string; + function?: { name?: string; arguments?: string }; +} + +/** + * Adapter that talks directly to Databricks Model Serving `/invocations` endpoint. + * + * No dependency on the Vercel AI SDK or LangChain. Uses raw `fetch()` to POST + * OpenAI-compatible payloads and parses the SSE stream itself. Calls + * `authenticate()` per-request so tokens are always fresh. + * + * Handles both structured `tool_calls` responses and text-based tool call + * fallback parsing for models that output tool calls as text. 
+ * + * @example Using the factory (recommended) + * ```ts + * import { createApp, createAgent, agents } from "@databricks/appkit"; + * import { DatabricksAdapter } from "@databricks/appkit/beta"; + * import { WorkspaceClient } from "@databricks/sdk-experimental"; + * + * const adapter = DatabricksAdapter.fromServingEndpoint({ + * workspaceClient: new WorkspaceClient({}), + * endpointName: "my-endpoint", + * }); + * + * await createApp({ + * plugins: [ + * agents({ + * agents: { + * assistant: createAgent({ + * instructions: "You are a helpful assistant.", + * model: adapter, + * }), + * }, + * }), + * ], + * }); + * ``` + * + * @example Using the raw constructor + * ```ts + * const adapter = new DatabricksAdapter({ + * endpointUrl: "https://host/serving-endpoints/my-endpoint/invocations", + * authenticate: async () => ({ Authorization: `Bearer ${token}` }), + * }); + * ``` + */ +export class DatabricksAdapter implements AgentAdapter { + private streamBody: StreamBody; + private maxSteps: number; + private maxTokens: number; + private maxSseLineChars: number; + private maxStreamTextChars: number; + private maxToolArgumentsChars: number; + + constructor(options: DatabricksAdapterOptions) { + this.maxSteps = options.maxSteps ?? 10; + this.maxTokens = options.maxTokens ?? 4096; + this.maxSseLineChars = + options.maxSseLineChars ?? DEFAULT_MAX_SSE_LINE_CHARS; + this.maxStreamTextChars = + options.maxStreamTextChars ?? DEFAULT_MAX_STREAM_TEXT_CHARS; + this.maxToolArgumentsChars = + options.maxToolArgumentsChars ?? DEFAULT_MAX_TOOL_ARGUMENT_CHARS; + + if (isStreamBodyOptions(options)) { + this.streamBody = options.streamBody; + } else { + const { endpointUrl, authenticate } = options; + this.streamBody = async (body, signal) => { + const fetchSignal = + signal ?? 
AbortSignal.timeout(RAW_FETCH_DEFAULT_TIMEOUT_MS); + const authHeaders = await authenticate(); + const response = await fetch(endpointUrl, { + method: "POST", + headers: { + "Content-Type": "application/json", + ...authHeaders, + }, + body: JSON.stringify(body), + signal: fetchSignal, + }); + if (!response.ok) { + const errorText = await response.text().catch(() => "Unknown error"); + throw new Error( + `Databricks API error (${response.status}): ${errorText}`, + ); + } + if (!response.body) throw new Error("No response body"); + return response.body; + }; + } + } + + /** + * Creates a DatabricksAdapter for a Databricks Model Serving endpoint. + * + * Routes through the shared `connectors/serving/stream` helper, which + * delegates to the SDK's `apiClient.request({ raw: true })`. That gives the + * adapter centralised URL encoding + authentication with the rest of the + * serving surface — no bespoke `fetch()` + `authenticate()` plumbing. + */ + static async fromServingEndpoint( + options: ServingEndpointOptions, + ): Promise { + const { + workspaceClient, + endpointName, + maxSteps, + maxTokens, + maxSseLineChars, + maxStreamTextChars, + maxToolArgumentsChars, + } = options; + return new DatabricksAdapter({ + streamBody: (body, signal) => + // Cast through the structural shape: the connector types + // `workspaceClient` as the SDK's concrete `WorkspaceClient`, but we + // only need `apiClient.request`. + servingStream( + workspaceClient as unknown as Parameters[0], + endpointName, + body, + signal, + ), + maxSteps, + maxTokens, + maxSseLineChars, + maxStreamTextChars, + maxToolArgumentsChars, + }); + } + + /** + * Creates a DatabricksAdapter from a Model Serving endpoint name. + * Auto-creates a WorkspaceClient internally. Reads the endpoint name + * from the argument or the `DATABRICKS_SERVING_ENDPOINT_NAME` env var. 
+ * + * @example + * ```ts + * // Reads endpoint from DATABRICKS_SERVING_ENDPOINT_NAME env var + * const adapter = await DatabricksAdapter.fromModelServing(); + * + * // Explicit endpoint + * const adapter = await DatabricksAdapter.fromModelServing("my-endpoint"); + * + * // With options + * const adapter = await DatabricksAdapter.fromModelServing("my-endpoint", { + * maxSteps: 5, + * maxTokens: 2048, + * }); + * ``` + */ + static async fromModelServing( + endpointName?: string, + options?: ModelServingOptions, + ): Promise { + const resolvedEndpoint = + endpointName ?? process.env.DATABRICKS_SERVING_ENDPOINT_NAME; + + if (!resolvedEndpoint) { + throw new Error( + "No endpoint name provided and DATABRICKS_SERVING_ENDPOINT_NAME env var is not set. " + + "Pass an endpoint name or set DATABRICKS_SERVING_ENDPOINT_NAME.", + ); + } + + let workspaceClient: WorkspaceClientLike | undefined = + options?.workspaceClient; + if (!workspaceClient) { + const sdk = await import("@databricks/sdk-experimental"); + workspaceClient = new sdk.WorkspaceClient( + {}, + ) as unknown as WorkspaceClientLike; + } + + return DatabricksAdapter.fromServingEndpoint({ + workspaceClient, + endpointName: resolvedEndpoint, + maxSteps: options?.maxSteps, + maxTokens: options?.maxTokens, + maxSseLineChars: options?.maxSseLineChars, + maxStreamTextChars: options?.maxStreamTextChars, + maxToolArgumentsChars: options?.maxToolArgumentsChars, + }); + } + + async *run( + input: AgentInput, + context: AgentRunContext, + ): AsyncGenerator { + // Databricks API requires tool names to match [a-zA-Z0-9_-]. + // Our tool names use dots (e.g. "analytics.query"), so we swap dots + // for double-underscores in the wire format and map back on receipt. 
+ const nameToWire = new Map(); + const wireToName = new Map(); + for (const tool of input.tools) { + const wire = tool.name.replace(/\./g, "__"); + if (wireToName.has(wire) && wireToName.get(wire) !== tool.name) { + throw new Error( + `Tool name collision: '${tool.name}' and '${wireToName.get(wire)}' both map to wire name '${wire}'`, + ); + } + nameToWire.set(tool.name, wire); + wireToName.set(wire, tool.name); + } + + const tools = this.buildTools(input.tools, nameToWire); + const messages = this.buildMessages(input.messages, nameToWire); + + yield { type: "status", status: "running" }; + + for (let step = 0; step < this.maxSteps; step++) { + if (context.signal?.aborted) break; + + const { text, toolCalls } = yield* this.streamCompletion( + messages, + tools, + context, + ); + + if (toolCalls.length === 0) { + const parsed = parseTextToolCalls(text); + if (parsed.length > 0) { + yield* this.executeToolCalls(parsed, messages, context, nameToWire); + continue; + } + break; + } + + messages.push({ + role: "assistant", + content: text || null, + tool_calls: toolCalls, + }); + + for (const tc of toolCalls) { + const wireName = tc.function.name; + const originalName = wireToName.get(wireName) ?? wireName; + yield* this.executeSingleTool(tc, originalName, messages, context); + } + } + } + + /** Parse wire arguments, emit tool_call / tool_result, append tool messages. */ + private async *executeSingleTool( + tc: OpenAIToolCall, + originalName: string, + messages: OpenAIMessage[], + context: AgentRunContext, + ): AsyncGenerator { + let args: unknown; + try { + args = JSON.parse(tc.function.arguments); + } catch { + args = {}; + } + + yield { type: "tool_call", callId: tc.id, name: originalName, args }; + + try { + const result = await context.executeTool(originalName, args); + const resultStr = + typeof result === "string" ? 
result : JSON.stringify(result); + + yield { type: "tool_result", callId: tc.id, result }; + + messages.push({ + role: "tool", + content: resultStr, + tool_call_id: tc.id, + }); + } catch (error) { + const errMsg = + error instanceof Error ? error.message : "Tool execution failed"; + + yield { + type: "tool_result", + callId: tc.id, + result: null, + error: errMsg, + }; + + messages.push({ + role: "tool", + content: JSON.stringify({ error: errMsg }), + tool_call_id: tc.id, + }); + } + } + + private async *streamCompletion( + messages: OpenAIMessage[], + tools: OpenAITool[], + context: AgentRunContext, + ): AsyncGenerator< + AgentEvent, + { text: string; toolCalls: OpenAIToolCall[] }, + unknown + > { + const body: Record = { + messages, + stream: true, + max_tokens: this.maxTokens, + }; + + if (tools.length > 0) { + body.tools = tools; + } + + let responseBody: ReadableStream; + try { + responseBody = await this.streamBody(body, context.signal); + } catch (err) { + const msg = err instanceof Error ? err.message : "Stream request failed"; + yield { type: "status", status: "error", error: msg }; + throw err; + } + + const reader = responseBody.getReader(); + + const decoder = new TextDecoder(); + let buffer = ""; + let fullText = ""; + const toolCallAccumulator = new Map< + number, + { id: string; name: string; arguments: string } + >(); + + try { + while (true) { + if (context.signal?.aborted) break; + + const { done, value } = await reader.read(); + if (done) break; + + buffer += decoder.decode(value, { stream: true }); + const lines = buffer.split("\n"); + buffer = lines.pop() ?? 
""; + + if (buffer.length > this.maxSseLineChars) { + throw new Error( + `DatabricksAdapter: SSE line buffer exceeds configured limit (${this.maxSseLineChars} UTF-16 code units)`, + ); + } + + for (const line of lines) { + if (line.length > this.maxSseLineChars) { + throw new Error( + `DatabricksAdapter: SSE line exceeds configured limit (${this.maxSseLineChars} UTF-16 code units)`, + ); + } + + const trimmed = line.trim(); + if (!trimmed.startsWith("data: ")) continue; + const data = trimmed.slice(6); + if (data === "[DONE]") continue; + + let parsed: unknown; + try { + parsed = JSON.parse(data); + } catch (parseErr) { + console.debug( + "[DatabricksAdapter] malformed SSE data line JSON", + { line: `${data.slice(0, 256)}${data.length > 256 ? "…" : ""}` }, + parseErr, + ); + continue; + } + + const deltaUnknown = openAiChoicesDelta(parsed); + if (!isRecord(deltaUnknown)) continue; + + if (typeof deltaUnknown.content === "string") { + const content = deltaUnknown.content; + throwIfExceedsStreamLimit( + "streamed assistant text", + fullText.length, + content, + this.maxStreamTextChars, + ); + fullText += content; + yield { type: "message_delta" as const, content }; + } + + const toolCallsRaw = deltaUnknown.tool_calls; + if (!Array.isArray(toolCallsRaw)) continue; + + for (const tc of toolCallsRaw) { + if (!isStreamingDeltaToolCall(tc)) continue; + const existing = toolCallAccumulator.get(tc.index); + if (existing) { + if (tc.function?.arguments) { + throwIfExceedsStreamLimit( + "tool call arguments", + existing.arguments.length, + tc.function.arguments, + this.maxToolArgumentsChars, + ); + existing.arguments += tc.function.arguments; + } + } else { + const initial = tc.function?.arguments ?? ""; + if (initial.length > this.maxToolArgumentsChars) { + throw new Error( + `DatabricksAdapter: tool call arguments exceed configured limit (${this.maxToolArgumentsChars} UTF-16 code units)`, + ); + } + toolCallAccumulator.set(tc.index, { + id: tc.id ?? 
`call_${tc.index}`, + name: tc.function?.name ?? "", + arguments: initial, + }); + } + } + } + } + } finally { + try { + await reader.cancel(); + } catch (cancelErr) { + console.debug( + "[DatabricksAdapter] reader.cancel() failed during teardown", + cancelErr, + ); + } + try { + reader.releaseLock(); + } catch (unlockErr) { + console.debug( + "[DatabricksAdapter] reader.releaseLock() failed during teardown", + unlockErr, + ); + } + } + + const toolCalls: OpenAIToolCall[] = Array.from( + toolCallAccumulator.values(), + ).map((tc) => ({ + id: tc.id, + type: "function" as const, + function: { name: tc.name, arguments: tc.arguments || "{}" }, + })); + + return { text: fullText, toolCalls }; + } + + private async *executeToolCalls( + calls: Array<{ name: string; args: unknown }>, + messages: OpenAIMessage[], + context: AgentRunContext, + nameToWire: Map, + ): AsyncGenerator { + const wireToolName = (name: string) => + nameToWire.get(name) ?? name.replace(/\./g, "__"); + + const toolCallObjs: OpenAIToolCall[] = calls.map((c, i) => ({ + id: `text_call_${i}`, + type: "function" as const, + function: { + name: wireToolName(c.name), + arguments: JSON.stringify(c.args), + }, + })); + + messages.push({ + role: "assistant", + content: null, + tool_calls: toolCallObjs, + }); + + for (let i = 0; i < toolCallObjs.length; i++) { + const tc = toolCallObjs[i]; + const originalName = calls[i]?.name ?? tc.function.name; + yield* this.executeSingleTool(tc, originalName, messages, context); + } + } + + /** + * Maps AppKit {@link AgentInput} messages into OpenAI-compatible wire messages. + * Preserves multi-turn tool state (`toolCalls` → `tool_calls`, `toolCallId` → + * `tool_call_id`) so resumed threads and hydrated history reach the model. + */ + private buildMessages( + messages: AgentInput["messages"], + nameToWire: Map, + ): OpenAIMessage[] { + const wireToolName = (name: string) => + nameToWire.get(name) ?? 
name.replace(/\./g, "__"); + + return messages.map((m) => { + let content: string | null = m.content; + if ( + m.role === "assistant" && + m.toolCalls && + m.toolCalls.length > 0 && + (!m.content || m.content.trim() === "") + ) { + content = null; + } + + const out: OpenAIMessage = { + role: m.role as OpenAIMessage["role"], + content, + }; + + if (m.toolCallId) { + out.tool_call_id = m.toolCallId; + } + + if (m.toolCalls && m.toolCalls.length > 0) { + out.tool_calls = m.toolCalls.map((tc) => ({ + id: tc.id, + type: "function" as const, + function: { + name: wireToolName(tc.name), + arguments: + typeof tc.args === "string" + ? tc.args + : JSON.stringify(tc.args ?? {}), + }, + })); + } + + return out; + }); + } + + private buildTools( + definitions: AgentToolDefinition[], + nameToWire: Map, + ): OpenAITool[] { + return definitions.map((def) => ({ + type: "function" as const, + function: { + name: nameToWire.get(def.name) ?? def.name, + description: def.description, + parameters: def.parameters, + }, + })); + } +} + +// --------------------------------------------------------------------------- +// Text-based tool call parsing (fallback) +// --------------------------------------------------------------------------- + +/** + * Parses text-based tool calls from model output. + * + * Handles two formats: + * 1. Llama native: `[{"name": "tool_name", "parameters": {"arg": "val"}}]` + * 2. 
Python-style: `[tool_name(arg1='val1', arg2='val2')]` + */ +export function parseTextToolCalls( + text: string, +): Array<{ name: string; args: unknown }> { + const trimmed = text.trim(); + + const jsonResult = tryParseLlamaJsonToolCalls(trimmed); + if (jsonResult.length > 0) return jsonResult; + + const pyResult = tryParsePythonStyleToolCalls(trimmed); + if (pyResult.length > 0) return pyResult; + + return []; +} + +function isLlamaToolJsonItem(value: unknown): value is Record< + string, + unknown +> & { + name: string; +} { + if (!isRecord(value)) return false; + return typeof value.name === "string"; +} + +function tryParseLlamaJsonToolCalls( + text: string, +): Array<{ name: string; args: unknown }> { + const slice = extractLlamaToolJsonSlice(text); + if (!slice) return []; + + try { + const parsed: unknown = JSON.parse(slice); + if (!Array.isArray(parsed)) return []; + + return parsed.filter(isLlamaToolJsonItem).map((item) => ({ + name: item.name, + args: item.parameters ?? item.arguments ?? item.args ?? {}, + })); + } catch { + return []; + } +} + +function tryParsePythonStyleToolCalls( + text: string, +): Array<{ name: string; args: unknown }> { + if (text.length > PYTHON_STYLE_TOOL_PARSE_MAX_INPUT) { + return []; + } + + const pattern = /\[?([a-zA-Z_][\w.]*)\(([^)]*)\)\]?/g; + const results: Array<{ name: string; args: unknown }> = []; + + for (const match of text.matchAll(pattern)) { + const name = match[1]; + const argsStr = match[2]; + + const args: Record = {}; + const argPattern = /(\w+)\s*=\s*(?:'([^']*)'|"([^"]*)"|(\S+))/g; + for (const argMatch of argsStr.matchAll(argPattern)) { + const key = argMatch[1]; + const value = argMatch[2] ?? argMatch[3] ?? 
argMatch[4]; + args[key] = value; + } + + results.push({ name, args }); + } + + return results; +} diff --git a/packages/appkit/src/agents/tests/databricks.test.ts b/packages/appkit/src/agents/tests/databricks.test.ts new file mode 100644 index 000000000..fd51bc0fc --- /dev/null +++ b/packages/appkit/src/agents/tests/databricks.test.ts @@ -0,0 +1,882 @@ +import type { AgentEvent, AgentToolDefinition, Message } from "shared"; +import { afterEach, beforeEach, describe, expect, test, vi } from "vitest"; +import { DatabricksAdapter, parseTextToolCalls } from "../databricks"; + +const mockAuthenticate = vi + .fn() + .mockResolvedValue({ Authorization: "Bearer test-token" }); + +function sseChunk(data: string): string { + return `data: ${data}\n\n`; +} + +function textDelta(content: string): string { + return sseChunk( + JSON.stringify({ + choices: [{ delta: { content } }], + }), + ); +} + +function toolCallDelta( + index: number, + id: string | undefined, + name: string | undefined, + args: string, +): string { + return sseChunk( + JSON.stringify({ + choices: [ + { + delta: { + tool_calls: [ + { + index, + ...(id && { id }), + ...(name && { type: "function" }), + function: { + ...(name && { name }), + arguments: args, + }, + }, + ], + }, + }, + ], + }), + ); +} + +function createReadableStream(chunks: string[]): ReadableStream { + const encoder = new TextEncoder(); + let i = 0; + return new ReadableStream({ + pull(controller) { + if (i < chunks.length) { + controller.enqueue(encoder.encode(chunks[i])); + i++; + } else { + controller.close(); + } + }, + }); +} + +function mockFetch(chunks: string[]): typeof globalThis.fetch { + return vi.fn().mockResolvedValue({ + ok: true, + body: createReadableStream(chunks), + text: () => Promise.resolve(""), + }); +} + +function createTestMessages(): Message[] { + return [{ id: "1", role: "user", content: "Hello", createdAt: new Date() }]; +} + +function createTestTools(): AgentToolDefinition[] { + return [ + { + name: 
"analytics.query", + description: "Run SQL", + parameters: { + type: "object", + properties: { query: { type: "string" } }, + required: ["query"], + }, + }, + ]; +} + +function createAdapter(overrides?: { + endpointUrl?: string; + authenticate?: () => Promise>; + maxSteps?: number; + maxTokens?: number; + maxSseLineChars?: number; + maxStreamTextChars?: number; + maxToolArgumentsChars?: number; +}) { + return new DatabricksAdapter({ + endpointUrl: + "https://test.databricks.com/serving-endpoints/my-endpoint/invocations", + authenticate: mockAuthenticate, + ...overrides, + }); +} + +describe("DatabricksAdapter", () => { + const originalFetch = globalThis.fetch; + + afterEach(() => { + globalThis.fetch = originalFetch; + mockAuthenticate.mockClear(); + }); + + test("streams text deltas from the model", async () => { + globalThis.fetch = mockFetch([ + textDelta("Hello"), + textDelta(" world"), + sseChunk("[DONE]"), + ]); + + const adapter = createAdapter(); + const events: AgentEvent[] = []; + + for await (const event of adapter.run( + { messages: createTestMessages(), tools: [], threadId: "t1" }, + { executeTool: vi.fn() }, + )) { + events.push(event); + } + + expect(events[0]).toEqual({ type: "status", status: "running" }); + expect(events[1]).toEqual({ type: "message_delta", content: "Hello" }); + expect(events[2]).toEqual({ type: "message_delta", content: " world" }); + }); + + test("calls authenticate() per request for fresh headers", async () => { + globalThis.fetch = mockFetch([textDelta("Hi"), sseChunk("[DONE]")]); + + const adapter = createAdapter(); + + for await (const _ of adapter.run( + { messages: createTestMessages(), tools: [], threadId: "t1" }, + { executeTool: vi.fn() }, + )) { + // drain + } + + expect(mockAuthenticate).toHaveBeenCalledTimes(1); + + const [, init] = (globalThis.fetch as any).mock.calls[0]; + expect(init.headers.Authorization).toBe("Bearer test-token"); + }); + + test("throws when two tool names map to the same wire format", async () 
=> { + const adapter = createAdapter(); + const conflictingTools: AgentToolDefinition[] = [ + { + name: "foo.bar", + description: "one", + parameters: { type: "object", properties: {} }, + }, + { + name: "foo__bar", + description: "two", + parameters: { type: "object", properties: {} }, + }, + ]; + + await expect(async () => { + for await (const _ of adapter.run( + { + messages: createTestMessages(), + tools: conflictingTools, + threadId: "t1", + }, + { executeTool: vi.fn() }, + )) { + // drain + } + }).rejects.toThrow( + /Tool name collision: .* both map to wire name 'foo__bar'/, + ); + }); + + test("handles structured tool calls and executes them", async () => { + const executeTool = vi.fn().mockResolvedValue([{ trip_id: 1 }]); + + let callCount = 0; + globalThis.fetch = vi.fn().mockImplementation(() => { + callCount++; + if (callCount === 1) { + return Promise.resolve({ + ok: true, + body: createReadableStream([ + toolCallDelta(0, "call_1", "analytics__query", ""), + toolCallDelta(0, undefined, undefined, '{"query":'), + toolCallDelta(0, undefined, undefined, '"SELECT 1"}'), + sseChunk("[DONE]"), + ]), + }); + } + return Promise.resolve({ + ok: true, + body: createReadableStream([ + textDelta("Here are the results"), + sseChunk("[DONE]"), + ]), + }); + }); + + const adapter = createAdapter(); + const events: AgentEvent[] = []; + + for await (const event of adapter.run( + { + messages: createTestMessages(), + tools: createTestTools(), + threadId: "t1", + }, + { executeTool }, + )) { + events.push(event); + } + + expect(events).toContainEqual({ + type: "tool_call", + callId: "call_1", + name: "analytics.query", + args: { query: "SELECT 1" }, + }); + + expect(executeTool).toHaveBeenCalledWith("analytics.query", { + query: "SELECT 1", + }); + + expect(events).toContainEqual( + expect.objectContaining({ + type: "tool_result", + callId: "call_1", + result: [{ trip_id: 1 }], + }), + ); + + expect(events).toContainEqual({ + type: "message_delta", + content: "Here are 
the results", + }); + + // authenticate() called once per streamCompletion + expect(mockAuthenticate).toHaveBeenCalledTimes(2); + }); + + test("text-parsed tool calls use wire names on follow-up requests", async () => { + const executeTool = vi.fn().mockResolvedValue({ ok: true }); + let callCount = 0; + + const llamaToolJson = + '[{"name": "analytics.query", "parameters": {"query": "SELECT 1"}}]'; + + globalThis.fetch = vi.fn().mockImplementation(() => { + callCount++; + if (callCount === 1) { + return Promise.resolve({ + ok: true, + body: createReadableStream([ + textDelta(llamaToolJson), + sseChunk("[DONE]"), + ]), + }); + } + return Promise.resolve({ + ok: true, + body: createReadableStream([textDelta("Done."), sseChunk("[DONE]")]), + }); + }); + + const adapter = createAdapter(); + + for await (const _ of adapter.run( + { + messages: createTestMessages(), + tools: createTestTools(), + threadId: "t1", + }, + { executeTool }, + )) { + // drain + } + + expect(executeTool).toHaveBeenCalledWith("analytics.query", { + query: "SELECT 1", + }); + + expect(globalThis.fetch).toHaveBeenCalledTimes(2); + const [, secondInit] = (globalThis.fetch as any).mock.calls[1]; + const secondBody = JSON.parse(secondInit.body); + + expect(secondBody.messages[1]).toEqual({ + role: "assistant", + content: null, + tool_calls: [ + { + id: "text_call_0", + type: "function", + function: { + name: "analytics__query", + arguments: JSON.stringify({ query: "SELECT 1" }), + }, + }, + ], + }); + + expect(secondBody.messages[2]).toEqual({ + role: "tool", + content: JSON.stringify({ ok: true }), + tool_call_id: "text_call_0", + }); + }); + + test("respects maxSteps limit", async () => { + globalThis.fetch = vi.fn().mockImplementation(() => + Promise.resolve({ + ok: true, + body: createReadableStream([ + toolCallDelta( + 0, + "call_loop", + "analytics__query", + '{"query":"SELECT 1"}', + ), + sseChunk("[DONE]"), + ]), + }), + ); + + const adapter = createAdapter({ maxSteps: 2 }); + const events: 
AgentEvent[] = []; + + for await (const event of adapter.run( + { + messages: createTestMessages(), + tools: createTestTools(), + threadId: "t1", + }, + { executeTool: vi.fn().mockResolvedValue("ok") }, + )) { + events.push(event); + } + + expect(globalThis.fetch).toHaveBeenCalledTimes(2); + }); + + test("sends correct request to endpoint URL", async () => { + globalThis.fetch = mockFetch([textDelta("Hi"), sseChunk("[DONE]")]); + + const adapter = createAdapter(); + + for await (const _ of adapter.run( + { + messages: createTestMessages(), + tools: createTestTools(), + threadId: "t1", + }, + { executeTool: vi.fn() }, + )) { + // drain + } + + const [url, init] = (globalThis.fetch as any).mock.calls[0]; + expect(url).toBe( + "https://test.databricks.com/serving-endpoints/my-endpoint/invocations", + ); + + const body = JSON.parse(init.body); + expect(body.stream).toBe(true); + expect(body.tools).toHaveLength(1); + expect(body.tools[0].function.name).toBe("analytics__query"); + expect(body.messages[0]).toEqual({ + role: "user", + content: "Hello", + }); + }); + + test("forwards tool thread fields from input messages to the request body", async () => { + globalThis.fetch = mockFetch([textDelta("Done"), sseChunk("[DONE]")]); + + const adapter = createAdapter(); + + const threadMessages: Message[] = [ + { id: "1", role: "user", content: "Run SQL", createdAt: new Date() }, + { + id: "2", + role: "assistant", + content: "", + createdAt: new Date(), + toolCalls: [ + { + id: "call_1", + name: "analytics.query", + args: { query: "SELECT 1" }, + }, + ], + }, + { + id: "3", + role: "tool", + content: '{"rows":[]}', + createdAt: new Date(), + toolCallId: "call_1", + }, + ]; + + for await (const _ of adapter.run( + { + messages: threadMessages, + tools: createTestTools(), + threadId: "t1", + }, + { executeTool: vi.fn() }, + )) { + // drain + } + + const [, init] = (globalThis.fetch as any).mock.calls[0]; + const body = JSON.parse(init.body); + + expect(body.messages[1]).toEqual({ 
+ role: "assistant", + content: null, + tool_calls: [ + { + id: "call_1", + type: "function", + function: { + name: "analytics__query", + arguments: JSON.stringify({ query: "SELECT 1" }), + }, + }, + ], + }); + + expect(body.messages[2]).toEqual({ + role: "tool", + content: '{"rows":[]}', + tool_call_id: "call_1", + }); + }); + + test("throws when SSE line buffer exceeds maxSseLineChars", async () => { + globalThis.fetch = vi.fn().mockResolvedValue({ + ok: true, + body: createReadableStream(["no-newline-", "xxxxxxxxxx"]), + text: () => Promise.resolve(""), + }); + + const adapter = createAdapter({ maxSseLineChars: 12 }); + + await expect(async () => { + for await (const _ of adapter.run( + { messages: createTestMessages(), tools: [], threadId: "t1" }, + { executeTool: vi.fn() }, + )) { + // drain + } + }).rejects.toThrow(/SSE line buffer exceeds configured limit/); + }); + + test("throws when a complete SSE line exceeds maxSseLineChars", async () => { + const longPayload = `${"x".repeat(30)}\n`; + globalThis.fetch = vi.fn().mockResolvedValue({ + ok: true, + body: createReadableStream([longPayload]), + text: () => Promise.resolve(""), + }); + + const adapter = createAdapter({ maxSseLineChars: 20 }); + + await expect(async () => { + for await (const _ of adapter.run( + { messages: createTestMessages(), tools: [], threadId: "t1" }, + { executeTool: vi.fn() }, + )) { + // drain + } + }).rejects.toThrow(/SSE line exceeds configured limit/); + }); + + test("throws when streamed assistant text exceeds maxStreamTextChars", async () => { + globalThis.fetch = mockFetch([ + textDelta("abcde"), + textDelta("f"), + sseChunk("[DONE]"), + ]); + + const adapter = createAdapter({ maxStreamTextChars: 5 }); + + await expect(async () => { + for await (const _ of adapter.run( + { messages: createTestMessages(), tools: [], threadId: "t1" }, + { executeTool: vi.fn() }, + )) { + // drain + } + }).rejects.toThrow(/streamed assistant text exceeds configured limit/); + }); + + test("throws 
when streamed tool arguments exceed maxToolArgumentsChars", async () => { + globalThis.fetch = vi.fn().mockResolvedValue({ + ok: true, + body: createReadableStream([ + toolCallDelta(0, "c1", "t", '{"a":"'), + toolCallDelta(0, undefined, undefined, 'xxxx"}'), + sseChunk("[DONE]"), + ]), + text: () => Promise.resolve(""), + }); + + const adapter = createAdapter({ maxToolArgumentsChars: 8 }); + + await expect(async () => { + for await (const _ of adapter.run( + { + messages: createTestMessages(), + tools: [ + { + name: "t", + description: "x", + parameters: { type: "object", properties: {} }, + }, + ], + threadId: "t1", + }, + { executeTool: vi.fn().mockResolvedValue("ok") }, + )) { + // drain + } + }).rejects.toThrow(/tool call arguments exceed/); + }); + + test("throws on non-ok response", async () => { + globalThis.fetch = vi.fn().mockResolvedValue({ + ok: false, + status: 401, + text: () => Promise.resolve("Unauthorized"), + }); + + const adapter = createAdapter(); + + await expect(async () => { + for await (const _ of adapter.run( + { messages: createTestMessages(), tools: [], threadId: "t1" }, + { executeTool: vi.fn() }, + )) { + // drain + } + }).rejects.toThrow("Databricks API error (401): Unauthorized"); + }); + + test("yields error status then throws when injected streamBody fails", async () => { + const adapter = new DatabricksAdapter({ + streamBody: async () => Promise.reject(new Error("serving_unreachable")), + maxSteps: 1, + }); + + const events: AgentEvent[] = []; + await expect(async () => { + for await (const ev of adapter.run( + { messages: createTestMessages(), tools: [], threadId: "t1" }, + { executeTool: vi.fn() }, + )) { + events.push(ev); + } + }).rejects.toThrow("serving_unreachable"); + + expect(events[0]).toEqual({ type: "status", status: "running" }); + expect(events[1]).toEqual({ + type: "status", + status: "error", + error: "serving_unreachable", + }); + }); + + test("yields tool_result with error when executeTool rejects", async () => { + 
const executeTool = vi.fn().mockRejectedValue(new Error("tool_denied")); + + globalThis.fetch = vi.fn().mockResolvedValue({ + ok: true, + body: createReadableStream([ + toolCallDelta( + 0, + "call_fail", + "analytics__query", + '{"query":"SELECT 2"}', + ), + sseChunk("[DONE]"), + ]), + text: () => Promise.resolve(""), + }); + + const adapter = createAdapter({ maxSteps: 1 }); + const events: AgentEvent[] = []; + + for await (const ev of adapter.run( + { + messages: createTestMessages(), + tools: createTestTools(), + threadId: "t1", + }, + { executeTool }, + )) { + events.push(ev); + } + + expect(events).toContainEqual({ + type: "tool_call", + callId: "call_fail", + name: "analytics.query", + args: { query: "SELECT 2" }, + }); + + expect(events).toContainEqual({ + type: "tool_result", + callId: "call_fail", + result: null, + error: "tool_denied", + }); + + expect(executeTool).toHaveBeenCalledWith("analytics.query", { + query: "SELECT 2", + }); + }); + + test("uses AbortSignal.timeout for raw fetch when context has no signal", async () => { + globalThis.fetch = mockFetch([textDelta("Hello"), sseChunk("[DONE]")]); + + const ac = new AbortController(); + const timeoutSpy = vi + .spyOn(AbortSignal, "timeout") + .mockReturnValue(ac.signal); + + const adapter = createAdapter(); + + for await (const _ of adapter.run( + { + messages: createTestMessages(), + tools: createTestTools(), + threadId: "t1", + }, + { executeTool: vi.fn(), signal: undefined }, + )) { + // drain + } + + expect(timeoutSpy).toHaveBeenCalledWith(120_000); + timeoutSpy.mockRestore(); + }); + + test("logs and skips malformed JSON in SSE lines", async () => { + const debugSpy = vi.spyOn(console, "debug").mockImplementation(() => {}); + globalThis.fetch = mockFetch([ + sseChunk("{not-json-truncated"), + textDelta("ok"), + sseChunk("[DONE]"), + ]); + + const adapter = createAdapter(); + const events: AgentEvent[] = []; + + for await (const ev of adapter.run( + { + messages: createTestMessages(), + tools: 
createTestTools(), + threadId: "t1", + }, + { executeTool: vi.fn() }, + )) { + events.push(ev); + } + + expect( + debugSpy.mock.calls.some(([msg]) => { + return typeof msg === "string" && msg.includes("malformed SSE"); + }), + ).toBe(true); + expect( + events.some((e) => e.type === "message_delta" && e.content === "ok"), + ).toBe(true); + debugSpy.mockRestore(); + }); +}); + +describe("DatabricksAdapter.fromServingEndpoint", () => { + test("routes tool-free chat through apiClient.request with a streaming payload", async () => { + const apiClient = { + request: vi.fn().mockResolvedValue({ + contents: createReadableStream([textDelta("Hi"), sseChunk("[DONE]")]), + }), + }; + + const adapter = await DatabricksAdapter.fromServingEndpoint({ + workspaceClient: { apiClient }, + endpointName: "my-model", + }); + + for await (const _ of adapter.run( + { messages: createTestMessages(), tools: [], threadId: "t1" }, + { executeTool: vi.fn() }, + )) { + // drain + } + + expect(apiClient.request).toHaveBeenCalledTimes(1); + const [requestArgs] = apiClient.request.mock.calls[0]; + expect(requestArgs.path).toBe("/serving-endpoints/my-model/invocations"); + expect(requestArgs.method).toBe("POST"); + expect(requestArgs.raw).toBe(true); + expect(requestArgs.payload.stream).toBe(true); + // Auth + url encoding are the connector's (and the SDK's) concerns — the + // adapter no longer reaches into the workspace config. 
+ }); + + test("URL-encodes endpoint names with special characters", async () => { + const apiClient = { + request: vi.fn().mockResolvedValue({ + contents: createReadableStream([textDelta("Hi"), sseChunk("[DONE]")]), + }), + }; + + const adapter = await DatabricksAdapter.fromServingEndpoint({ + workspaceClient: { apiClient }, + endpointName: "my model/with spaces", + }); + + for await (const _ of adapter.run( + { messages: createTestMessages(), tools: [], threadId: "t1" }, + { executeTool: vi.fn() }, + )) { + // drain + } + + const [requestArgs] = apiClient.request.mock.calls[0]; + expect(requestArgs.path).toBe( + "/serving-endpoints/my%20model%2Fwith%20spaces/invocations", + ); + }); +}); + +describe("DatabricksAdapter.fromModelServing", () => { + const originalEnv = process.env; + + beforeEach(() => { + process.env = { ...originalEnv }; + }); + + afterEach(() => { + process.env = originalEnv; + }); + + test("reads endpoint from DATABRICKS_SERVING_ENDPOINT_NAME env var", async () => { + process.env.DATABRICKS_SERVING_ENDPOINT_NAME = "my-model"; + + vi.mock("@databricks/sdk-experimental", () => ({ + WorkspaceClient: vi.fn().mockImplementation(() => ({ + apiClient: { request: vi.fn() }, + })), + })); + + const adapter = await DatabricksAdapter.fromModelServing(); + expect(adapter).toBeInstanceOf(DatabricksAdapter); + }); + + test("throws when no endpoint name and no env var", async () => { + delete process.env.DATABRICKS_SERVING_ENDPOINT_NAME; + + await expect(DatabricksAdapter.fromModelServing()).rejects.toThrow( + "No endpoint name provided", + ); + }); + + test("explicit endpoint name takes precedence over env var", async () => { + process.env.DATABRICKS_SERVING_ENDPOINT_NAME = "env-model"; + + const apiClient = { + request: vi.fn().mockResolvedValue({ + contents: createReadableStream([textDelta("Hi"), sseChunk("[DONE]")]), + }), + }; + + const adapter = await DatabricksAdapter.fromModelServing("explicit-model", { + workspaceClient: { apiClient }, + }); + + 
expect(adapter).toBeInstanceOf(DatabricksAdapter); + + for await (const _ of adapter.run( + { messages: createTestMessages(), tools: [], threadId: "t1" }, + { executeTool: vi.fn() }, + )) { + // drain + } + + const [requestArgs] = apiClient.request.mock.calls[0]; + expect(requestArgs.path).toBe( + "/serving-endpoints/explicit-model/invocations", + ); + }); +}); + +describe("parseTextToolCalls", () => { + test("parses Llama JSON format", () => { + const text = + '[{"name": "analytics.query", "parameters": {"query": "SELECT 1"}}]'; + const result = parseTextToolCalls(text); + + expect(result).toEqual([ + { name: "analytics.query", args: { query: "SELECT 1" } }, + ]); + }); + + test("parses multiple Llama JSON tool calls", () => { + const text = + '[{"name": "analytics.query", "parameters": {"query": "SELECT 1"}}, {"name": "files.uploads.list", "parameters": {}}]'; + const result = parseTextToolCalls(text); + + expect(result).toHaveLength(2); + expect(result[0].name).toBe("analytics.query"); + expect(result[1].name).toBe("files.uploads.list"); + }); + + test("parses Python-style tool calls", () => { + const text = + "[analytics.query(query='SELECT * FROM trips ORDER BY date DESC LIMIT 10')]"; + const result = parseTextToolCalls(text); + + expect(result).toEqual([ + { + name: "analytics.query", + args: { + query: "SELECT * FROM trips ORDER BY date DESC LIMIT 10", + }, + }, + ]); + }); + + test("parses Python-style with multiple args", () => { + const text = + "[files.uploads.read(path='/data/file.csv', encoding='utf-8')]"; + const result = parseTextToolCalls(text); + + expect(result).toEqual([ + { + name: "files.uploads.read", + args: { path: "/data/file.csv", encoding: "utf-8" }, + }, + ]); + }); + + test("returns empty array for plain text", () => { + expect(parseTextToolCalls("Hello, how can I help?")).toEqual([]); + expect(parseTextToolCalls("")).toEqual([]); + expect(parseTextToolCalls("The answer is 42")).toEqual([]); + }); + + test("handles Llama format with 
'arguments' key", () => { + const text = + '[{"name": "lakebase.query", "arguments": {"text": "SELECT 1"}}]'; + const result = parseTextToolCalls(text); + + expect(result).toEqual([ + { name: "lakebase.query", args: { text: "SELECT 1" } }, + ]); + }); + + test("returns empty when Python-style fallback text exceeds size cap", () => { + const cap = 64 * 1024; + const filler = "x".repeat(cap); + const suffix = "[analytics.query(query='SELECT 1')]"; + expect(parseTextToolCalls(`${filler}${suffix}`)).toEqual([]); + }); +}); diff --git a/packages/appkit/src/beta.ts b/packages/appkit/src/beta.ts index 57db86362..04e893bf3 100644 --- a/packages/appkit/src/beta.ts +++ b/packages/appkit/src/beta.ts @@ -4,4 +4,5 @@ // // The exports below are auto-generated from each plugin's manifest.json // "stability" field. See tools/generate-plugin-entries.ts. +export { DatabricksAdapter, parseTextToolCalls } from "./agents/databricks"; export * from "./plugins/beta-exports.generated"; diff --git a/packages/appkit/src/connectors/serving/client.ts b/packages/appkit/src/connectors/serving/client.ts index 886d2bb3f..83f065e69 100644 --- a/packages/appkit/src/connectors/serving/client.ts +++ b/packages/appkit/src/connectors/serving/client.ts @@ -1,8 +1,46 @@ -import type { serving, WorkspaceClient } from "@databricks/sdk-experimental"; +import type { + CancellationToken, + serving, + WorkspaceClient, +} from "@databricks/sdk-experimental"; +import { Context } from "@databricks/sdk-experimental"; import { createLogger } from "../../logging/logger"; const logger = createLogger("connectors:serving"); +/** + * Bridges {@link AbortSignal} to the SDK's {@link CancellationToken} so + * `apiClient.request` can abort the outbound HTTP request (and stop pulling + * the SSE body) when the agent run is cancelled. 
+ */
+function cancellationTokenFromAbortSignal(
+  signal: AbortSignal,
+): CancellationToken {
+  const listeners = new Set<() => void>();
+  const fire = () => {
+    for (const cb of listeners) {
+      try {
+        cb();
+      } catch {
+        // ignore listener failures — abort must stay best-effort
+      }
+    }
+  };
+  signal.addEventListener("abort", fire, { once: true });
+
+  return {
+    get isCancellationRequested() {
+      return signal.aborted;
+    },
+    onCancellationRequested(callback: (e?: unknown) => unknown) {
+      listeners.add(callback as () => void);
+      if (signal.aborted) {
+        void callback();
+      }
+    },
+  };
+}
+
 /**
  * Invokes a serving endpoint using the SDK's high-level query API.
  * Returns a typed QueryEndpointResponse.
@@ -35,21 +73,31 @@ export async function stream(
   client: WorkspaceClient,
   endpointName: string,
   body: Record<string, unknown>,
+  signal?: AbortSignal,
 ): Promise<ReadableStream<Uint8Array>> {
   const { stream: _stream, ...cleanBody } = body;
 
   logger.debug("Streaming from endpoint %s", endpointName);
 
-  const response = (await client.apiClient.request({
-    path: `/serving-endpoints/${encodeURIComponent(endpointName)}/invocations`,
-    method: "POST",
-    headers: new Headers({
-      "Content-Type": "application/json",
-      Accept: "text/event-stream",
-    }),
-    payload: { ...cleanBody, stream: true },
-    raw: true,
-  })) as { contents: ReadableStream<Uint8Array> };
+  const context = signal
+    ? 
new Context({ + cancellationToken: cancellationTokenFromAbortSignal(signal), + }) + : undefined; + + const response = (await client.apiClient.request( + { + path: `/serving-endpoints/${encodeURIComponent(endpointName)}/invocations`, + method: "POST", + headers: new Headers({ + "Content-Type": "application/json", + Accept: "text/event-stream", + }), + payload: { ...cleanBody, stream: true }, + raw: true, + }, + context, + )) as { contents: ReadableStream }; if (!response.contents) { throw new Error("Response body is null — streaming not supported"); diff --git a/packages/appkit/src/connectors/serving/tests/client.test.ts b/packages/appkit/src/connectors/serving/tests/client.test.ts index 389585b04..d243621e0 100644 --- a/packages/appkit/src/connectors/serving/tests/client.test.ts +++ b/packages/appkit/src/connectors/serving/tests/client.test.ts @@ -1,3 +1,4 @@ +import { Context } from "@databricks/sdk-experimental"; import { afterEach, describe, expect, test, vi } from "vitest"; import { invoke, stream } from "../client"; @@ -109,6 +110,24 @@ describe("Serving Connector", () => { raw: true, payload: expect.objectContaining({ stream: true }), }), + undefined, + ); + }); + + test("passes SDK Context when AbortSignal is provided", async () => { + const client = createMockClient(); + client.apiClient.request.mockResolvedValue({ + contents: new ReadableStream(), + }); + + const controller = new AbortController(); + await stream(client, "my-endpoint", { messages: [] }, controller.signal); + + expect(client.apiClient.request).toHaveBeenCalledWith( + expect.objectContaining({ + path: "/serving-endpoints/my-endpoint/invocations", + }), + expect.any(Context), ); }); diff --git a/packages/shared/src/agent.ts b/packages/shared/src/agent.ts new file mode 100644 index 000000000..ef532c7c7 --- /dev/null +++ b/packages/shared/src/agent.ts @@ -0,0 +1,213 @@ +import type { JSONSchema7 } from "json-schema"; + +// 
---------------------------------------------------------------------------
+// Tool definitions
+// ---------------------------------------------------------------------------
+
+export interface ToolAnnotations {
+  readOnly?: boolean;
+  destructive?: boolean;
+  idempotent?: boolean;
+  requiresUserContext?: boolean;
+}
+
+export interface AgentToolDefinition {
+  name: string;
+  description: string;
+  parameters: JSONSchema7;
+  annotations?: ToolAnnotations;
+}
+
+export interface ToolProvider {
+  getAgentTools(): AgentToolDefinition[];
+  executeAgentTool(
+    name: string,
+    args: unknown,
+    signal?: AbortSignal,
+  ): Promise<unknown>;
+}
+
+// ---------------------------------------------------------------------------
+// Messages & threads
+// ---------------------------------------------------------------------------
+
+export interface Message {
+  id: string;
+  role: "user" | "assistant" | "system" | "tool";
+  content: string;
+  toolCallId?: string;
+  toolCalls?: ToolCall[];
+  createdAt: Date;
+}
+
+export interface ToolCall {
+  id: string;
+  name: string;
+  args: unknown;
+}
+
+export interface Thread {
+  id: string;
+  userId: string;
+  messages: Message[];
+  createdAt: Date;
+  updatedAt: Date;
+}
+
+// ---------------------------------------------------------------------------
+// Thread store
+// ---------------------------------------------------------------------------
+
+export interface ThreadStore {
+  create(userId: string): Promise<Thread>;
+  get(threadId: string, userId: string): Promise<Thread | null>;
+  list(userId: string): Promise<Thread[]>;
+  addMessage(threadId: string, userId: string, message: Message): Promise<void>;
+  delete(threadId: string, userId: string): Promise<void>;
+}
+
+// ---------------------------------------------------------------------------
+// Agent events (SSE protocol)
+// ---------------------------------------------------------------------------
+
+export type AgentEvent =
+  | { type: "message_delta"; content: string }
+  | { type: "message"; content: string }
+  | { type: 
"tool_call"; callId: string; name: string; args: unknown }
+  | {
+      type: "tool_result";
+      callId: string;
+      result: unknown;
+      error?: string;
+    }
+  | { type: "thinking"; content: string }
+  | {
+      type: "status";
+      status: "running" | "waiting" | "complete" | "error";
+      error?: string;
+    }
+  | { type: "metadata"; data: Record<string, unknown> };
+
+// ---------------------------------------------------------------------------
+// Responses API types (OpenAI-compatible wire format for HTTP boundary)
+// Self-contained — no openai package dependency.
+// ---------------------------------------------------------------------------
+
+export interface OutputTextContent {
+  type: "output_text";
+  text: string;
+}
+
+export interface ResponseOutputMessage {
+  type: "message";
+  id: string;
+  status: "in_progress" | "completed";
+  role: "assistant";
+  content: OutputTextContent[];
+}
+
+export interface ResponseFunctionToolCall {
+  type: "function_call";
+  id: string;
+  call_id: string;
+  name: string;
+  arguments: string;
+}
+
+export interface ResponseFunctionCallOutput {
+  type: "function_call_output";
+  id: string;
+  call_id: string;
+  output: string;
+}
+
+export type ResponseOutputItem =
+  | ResponseOutputMessage
+  | ResponseFunctionToolCall
+  | ResponseFunctionCallOutput;
+
+export interface ResponseOutputItemAddedEvent {
+  type: "response.output_item.added";
+  output_index: number;
+  item: ResponseOutputItem;
+  sequence_number: number;
+}
+
+export interface ResponseOutputItemDoneEvent {
+  type: "response.output_item.done";
+  output_index: number;
+  item: ResponseOutputItem;
+  sequence_number: number;
+}
+
+export interface ResponseTextDeltaEvent {
+  type: "response.output_text.delta";
+  item_id: string;
+  output_index: number;
+  content_index: number;
+  delta: string;
+  sequence_number: number;
+}
+
+export interface ResponseCompletedEvent {
+  type: "response.completed";
+  sequence_number: number;
+  response: Record<string, unknown>;
+}
+
+export interface ResponseErrorEvent {
+  type: "error";
+  error: string;
+  sequence_number: number;
+}
+
+export interface ResponseFailedEvent {
+  type: "response.failed";
+  sequence_number: number;
+}
+
+export interface AppKitThinkingEvent {
+  type: "appkit.thinking";
+  content: string;
+  sequence_number: number;
+}
+
+export interface AppKitMetadataEvent {
+  type: "appkit.metadata";
+  data: Record<string, unknown>;
+  sequence_number: number;
+}
+
+export type ResponseStreamEvent =
+  | ResponseOutputItemAddedEvent
+  | ResponseOutputItemDoneEvent
+  | ResponseTextDeltaEvent
+  | ResponseCompletedEvent
+  | ResponseErrorEvent
+  | ResponseFailedEvent
+  | AppKitThinkingEvent
+  | AppKitMetadataEvent;
+
+// ---------------------------------------------------------------------------
+// Adapter contract
+// ---------------------------------------------------------------------------
+
+export interface AgentInput {
+  messages: Message[];
+  tools: AgentToolDefinition[];
+  threadId: string;
+  signal?: AbortSignal;
+}
+
+export interface AgentRunContext {
+  /** Tool implementations should sanitize failure text — errors become `tool_result.error` and can flow back into the LLM transcript. 
*/
+  executeTool: (name: string, args: unknown) => Promise<unknown>;
+  signal?: AbortSignal;
+}
+
+export interface AgentAdapter {
+  run(
+    input: AgentInput,
+    context: AgentRunContext,
+  ): AsyncGenerator<AgentEvent>;
+}
diff --git a/packages/shared/src/index.ts b/packages/shared/src/index.ts
index 627d70d6c..9829729a7 100644
--- a/packages/shared/src/index.ts
+++ b/packages/shared/src/index.ts
@@ -1,3 +1,4 @@
+export * from "./agent";
 export * from "./cache";
 export * from "./execute";
 export * from "./genie";
diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml
index 684f6e2e4..c1d8f247e 100644
--- a/pnpm-lock.yaml
+++ b/pnpm-lock.yaml
@@ -5551,7 +5551,7 @@ packages:
   basic-ftp@5.0.5:
     resolution: {integrity: sha512-4Bcg1P8xhUuqcii/S0Z9wiHIrQVPMermM1any+MX5GeGD7faD3/msQUDGLol9wOcz4/jbg/WJnGqoJF6LiBdtg==}
     engines: {node: '>=10.0.0'}
-    deprecated: Security vulnerability fixed in 5.2.1, please upgrade
+    deprecated: Security vulnerability fixed in 5.2.1, please upgrade
 
   batch@0.6.1:
     resolution: {integrity: sha512-x+VAiMRL6UPkx+kudNvxTl6hB2XNNCG2r+7wixVfIYwu/2HKRXimwQyaumLjMveWvT2Hkd/cAJw+QBMfJ/EKVw==}
@@ -6665,7 +6665,6 @@
 
   dottie@2.0.6:
     resolution: {integrity: sha512-iGCHkfUc5kFekGiqhe8B/mdaurD+lakO9txNnTvKtA6PISrw86LgqHvRzWYPyoE2Ph5aMIrCw9/uko6XHTKCwA==}
-    deprecated: Package no longer supported. Contact Support at https://www.npmjs.com/support for more info.
 
   drizzle-orm@0.45.1:
    resolution: {integrity: sha512-Te0FOdKIistGNPMq2jscdqngBRfBpC8uMFVwqjf6gtTVJHIQ/dosgV/CLBU2N4ZJBsXL5savCba9b0YJskKdcA==}