diff --git a/templates/types/streaming/fastapi/app/api/routers/vercel_response.py b/templates/types/streaming/fastapi/app/api/routers/vercel_response.py index 339ca995..83952b21 100644 --- a/templates/types/streaming/fastapi/app/api/routers/vercel_response.py +++ b/templates/types/streaming/fastapi/app/api/routers/vercel_response.py @@ -51,9 +51,7 @@ async def content_generator(self): event_response = event.to_response() yield self.convert_data(event_response) else: - yield self.convert_data( - {"type": "agent", "data": event.model_dump()} - ) + yield self.convert_data(event.model_dump()) except asyncio.CancelledError: logger.warning("Client cancelled the request!") diff --git a/templates/types/streaming/nextjs/app/components/ui/chat/tools/chat-tools.tsx b/templates/types/streaming/nextjs/app/components/ui/chat/tools/chat-tools.tsx index a7e4eb21..4826274b 100644 --- a/templates/types/streaming/nextjs/app/components/ui/chat/tools/chat-tools.tsx +++ b/templates/types/streaming/nextjs/app/components/ui/chat/tools/chat-tools.tsx @@ -1,69 +1,128 @@ import { Message, - MessageAnnotation, getChatUIAnnotation, + getCustomAnnotation, useChatMessage, useChatUI, } from "@llamaindex/chat-ui"; +import { ChatEvents } from "@llamaindex/chat-ui/widgets"; import { JSONValue } from "ai"; import { useMemo } from "react"; +import { z } from "zod"; import { Artifact, CodeArtifact } from "./artifact"; import { WeatherCard, WeatherData } from "./weather-card"; +const ToolCallSchema = z.object({ + tool_name: z.string(), + tool_kwargs: z.record(z.unknown()), + tool_id: z.string(), + tool_output: z.optional( + z + .object({ + content: z.string(), + tool_name: z.string(), + raw_input: z.object({ + args: z.array(z.unknown()), + kwargs: z.record(z.unknown()), + }), + raw_output: z.record(z.unknown()), + is_error: z.boolean().optional(), + }) + .optional(), + ), + return_direct: z.boolean().optional(), +}); + +type ToolCallEvent = z.infer; + +type GroupedToolCall = { + initial: ToolCallEvent; + output?: ToolCallEvent; +}; + export function ToolAnnotations() { - // TODO: This is a bit of a hack to get the artifact version. better to generate the version in the tool call and - // store it in CodeArtifact const { messages } = useChatUI(); const { message } = useChatMessage(); const artifactVersion = useMemo( () => getArtifactVersion(messages, message), [messages, message], ); - // Get the tool data from the message annotations - const annotations = message.annotations as MessageAnnotation[] | undefined; - const toolData = annotations - ? (getChatUIAnnotation(annotations, "tools") as unknown as ToolData[]) - : null; - return toolData?.[0] ? ( - - ) : null; -} -// TODO: Used to render outputs of tools. If needed, add more renderers here. -function ChatTools({ - data, - artifactVersion, -}: { - data: ToolData; - artifactVersion: number | undefined; -}) { - if (!data) return null; - const { toolCall, toolOutput } = data; + const toolCallEvents = getCustomAnnotation( + message.annotations, + (annotation) => { + const result = ToolCallSchema.safeParse(annotation); + return result.success; + }, + ); + + // Group tool calls by tool_id - we just need to take the latest event for each tool_id + const groupedToolCalls = useMemo(() => { + const groups = new Map(); - if (toolOutput.isError) { - return ( -
- There was an error when calling the tool {toolCall.name} with input:{" "} -
- {JSON.stringify(toolCall.input)} -
- ); - } + toolCallEvents?.forEach((event) => { + groups.set(event.tool_id, { initial: event }); + }); - switch (toolCall.name) { - case "get_weather_information": - const weatherData = toolOutput.output as unknown as WeatherData; - return ; - case "artifact": - return ( - - ); - default: - return null; - } + return Array.from(groups.values()); + }, [toolCallEvents]); + + return ( +
+ {groupedToolCalls.map(({ initial }) => { + switch (initial.tool_name) { + case "query_index": { + const query = initial.tool_kwargs.input; + const eventData = [ + { + title: initial.tool_output + ? `Got ${JSON.stringify((initial.tool_output?.raw_output as any).source_nodes?.length ?? 0)} sources for query: ${query}` + : `Searching information for query: ${query}`, + }, + ]; + + return ( + + ); + } + case "get_weather_information": { + if (!initial.tool_output) + return ( + + ); + const weatherData = initial.tool_output + ?.raw_output as unknown as WeatherData; + return ; + } + case "artifact": { + const artifact = initial.tool_output + ?.content as unknown as CodeArtifact; + return ( + + ); + } + default: + return null; + } + })} +
+ ); } type ToolData = { diff --git a/templates/types/streaming/nextjs/package.json b/templates/types/streaming/nextjs/package.json index 48c43c12..355fd4df 100644 --- a/templates/types/streaming/nextjs/package.json +++ b/templates/types/streaming/nextjs/package.json @@ -17,7 +17,7 @@ "@radix-ui/react-select": "^2.1.1", "@radix-ui/react-slot": "^1.0.2", "@radix-ui/react-tabs": "^1.1.0", - "@llamaindex/chat-ui": "0.0.14", + "@llamaindex/chat-ui": "0.1.0", "ai": "^4.0.3", "ajv": "^8.12.0", "class-variance-authority": "^0.7.1",