Skip to content

Commit

Permalink
bump chat-ui, update chat tools render
Browse files Browse the repository at this point in the history
  • Loading branch information
leehuwuj committed Feb 18, 2025
1 parent 8d3db71 commit ecfd352
Show file tree
Hide file tree
Showing 3 changed files with 106 additions and 49 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,7 @@ async def content_generator(self):
event_response = event.to_response()
yield self.convert_data(event_response)
else:
yield self.convert_data(
{"type": "agent", "data": event.model_dump()}
)
yield self.convert_data(event.model_dump())

except asyncio.CancelledError:
logger.warning("Client cancelled the request!")
Expand Down
Original file line number Diff line number Diff line change
@@ -1,69 +1,128 @@
import {
Message,
MessageAnnotation,
getChatUIAnnotation,
getCustomAnnotation,
useChatMessage,
useChatUI,
} from "@llamaindex/chat-ui";
import { ChatEvents } from "@llamaindex/chat-ui/widgets";
import { JSONValue } from "ai";
import { useMemo } from "react";
import { z } from "zod";
import { Artifact, CodeArtifact } from "./artifact";
import { WeatherCard, WeatherData } from "./weather-card";

const ToolCallSchema = z.object({
tool_name: z.string(),
tool_kwargs: z.record(z.unknown()),
tool_id: z.string(),
tool_output: z.optional(
z
.object({
content: z.string(),
tool_name: z.string(),
raw_input: z.object({
args: z.array(z.unknown()),
kwargs: z.record(z.unknown()),
}),
raw_output: z.record(z.unknown()),
is_error: z.boolean().optional(),
})
.optional(),
),
return_direct: z.boolean().optional(),
});

type ToolCallEvent = z.infer<typeof ToolCallSchema>;

type GroupedToolCall = {
initial: ToolCallEvent;
output?: ToolCallEvent;
};

export function ToolAnnotations() {
// TODO: This is a bit of a hack to get the artifact version. better to generate the version in the tool call and
// store it in CodeArtifact
const { messages } = useChatUI();
const { message } = useChatMessage();
const artifactVersion = useMemo(
() => getArtifactVersion(messages, message),
[messages, message],
);
// Get the tool data from the message annotations
const annotations = message.annotations as MessageAnnotation[] | undefined;
const toolData = annotations
? (getChatUIAnnotation(annotations, "tools") as unknown as ToolData[])
: null;
return toolData?.[0] ? (
<ChatTools data={toolData[0]} artifactVersion={artifactVersion} />
) : null;
}

// TODO: Used to render outputs of tools. If needed, add more renderers here.
function ChatTools({
data,
artifactVersion,
}: {
data: ToolData;
artifactVersion: number | undefined;
}) {
if (!data) return null;
const { toolCall, toolOutput } = data;
const toolCallEvents = getCustomAnnotation<ToolCallEvent>(
message.annotations,
(annotation) => {
const result = ToolCallSchema.safeParse(annotation);
return result.success;
},
);

// Group tool calls by tool_id - we just need to take the latest event for each tool_id
const groupedToolCalls = useMemo(() => {
const groups = new Map<string, GroupedToolCall>();

if (toolOutput.isError) {
return (
<div className="border-l-2 border-red-400 pl-2">
There was an error when calling the tool {toolCall.name} with input:{" "}
<br />
{JSON.stringify(toolCall.input)}
</div>
);
}
toolCallEvents?.forEach((event) => {
groups.set(event.tool_id, { initial: event });
});

switch (toolCall.name) {
case "get_weather_information":
const weatherData = toolOutput.output as unknown as WeatherData;
return <WeatherCard data={weatherData} />;
case "artifact":
return (
<Artifact
artifact={toolOutput.output as CodeArtifact}
version={artifactVersion}
/>
);
default:
return null;
}
return Array.from(groups.values());
}, [toolCallEvents]);

return (
<div className="space-y-4">
{groupedToolCalls.map(({ initial }) => {
switch (initial.tool_name) {
case "query_index": {
const query = initial.tool_kwargs.input;
const eventData = [
{
title: initial.tool_output
? `Got ${JSON.stringify((initial.tool_output?.raw_output as any).source_nodes?.length ?? 0)} sources for query: ${query}`
: `Searching information for query: ${query}`,
},
];

return (
<ChatEvents
key={initial.tool_id}
data={eventData}
showLoading={!initial.tool_output}
/>
);
}
case "get_weather_information": {
if (!initial.tool_output)
return (
<ChatEvents
key={initial.tool_id}
data={[
{
title: `Getting weather information for ${initial.tool_kwargs.location}`,
},
]}
showLoading={false}
/>
);
const weatherData = initial.tool_output
?.raw_output as unknown as WeatherData;
return <WeatherCard key={initial.tool_id} data={weatherData} />;
}
case "artifact": {
const artifact = initial.tool_output
?.content as unknown as CodeArtifact;
return (
<Artifact
key={initial.tool_id}
artifact={artifact}
version={artifactVersion}
/>
);
}
default:
return null;
}
})}
</div>
);
}

type ToolData = {
Expand Down
2 changes: 1 addition & 1 deletion templates/types/streaming/nextjs/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
"@radix-ui/react-select": "^2.1.1",
"@radix-ui/react-slot": "^1.0.2",
"@radix-ui/react-tabs": "^1.1.0",
"@llamaindex/chat-ui": "0.0.14",
"@llamaindex/chat-ui": "0.1.0",
"ai": "^4.0.3",
"ajv": "^8.12.0",
"class-variance-authority": "^0.7.1",
Expand Down

0 comments on commit ecfd352

Please sign in to comment.