Skip to content

feat(chat): add cancel button in chat to stop generation #36

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/common/chat/cloudChat.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import { configuration } from "../utils/configuration";
type Parameters = {
temperature: number;
n_predict: number;
controller?: AbortController;
};

export const sendChatRequestCloud = async (
Expand Down Expand Up @@ -45,6 +46,9 @@ export const sendChatRequestCloud = async (
});

const stream = await model.pipe(parser).stream(messages, {
configurable: {
signal: parameters.controller,
},
maxConcurrency: 1,
});

Expand Down
2 changes: 2 additions & 0 deletions src/common/chat/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ export async function* chat(
history: ChatMessage[],
config?: {
provideHighlightedText?: boolean;
abortController: AbortController;
}
) {
const loggerCompletion = logCompletion();
Expand All @@ -36,6 +37,7 @@ export async function* chat(
n_predict: 4096,
stop: [],
temperature: 0.7,
controller: config?.abortController,
};

const { stopTask } = statusBar.startTask();
Expand Down
5 changes: 4 additions & 1 deletion src/common/chat/localChat.ts
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,10 @@ export async function* sendChatRequestLocal(
const startTime = performance.now();

let timings;
for await (const chunk of llama(prompt, parametersForCompletion, { url })) {
for await (const chunk of llama(prompt, parametersForCompletion, {
url,
controller: parameters.controller,
})) {
// @ts-ignore
if (chunk.data) {
// @ts-ignore
Expand Down
72 changes: 54 additions & 18 deletions src/common/panel/chat.ts
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import { Disposable, Webview, window, Uri } from "vscode";
import { Disposable, Webview, Uri } from "vscode";
import * as vscode from "vscode";
import { getUri } from "../utils/getUri";
import { getNonce } from "../utils/getNonce";
import { chat } from "../chat";
import { ChatMessage } from "../prompt/promptChat";

export type MessageType =
| {
Expand All @@ -13,15 +14,16 @@ export type MessageType =
}
| {
type: "e2w-response";
id: string;
command: string;
messageId: string;
done: boolean;
data: any;
};

export class ChatPanel implements vscode.WebviewViewProvider {
private disposables: Disposable[] = [];
private webview: Webview | undefined;
private messageCallback: Record<string, any> = {};

constructor(private readonly extensionUri: vscode.Uri) {}

Expand Down Expand Up @@ -94,26 +96,19 @@ export class ChatPanel implements vscode.WebviewViewProvider {
private setWebviewMessageListener(webview: Webview) {
webview.onDidReceiveMessage(
async (message: any) => {
const sendResponse = (messageToResponse: any, done: boolean) => {
this.postMessage({
type: "e2w-response",
command: message.type,
messageId: message.messageId,
data: messageToResponse,
done: done,
});
};
if (message.type in this.messageCallback) {
this.messageCallback[message.type]();
return;
}
const type = message.type;
const data = message.data;

switch (type) {
case "sendMessage":
for await (const message of chat(data, {
provideHighlightedText: true,
})) {
sendResponse(message, false);
}
sendResponse("", true);
await this.handleStartGeneration({
chatMessage: message.data,
messageId: message.messageId,
messageType: message.type,
});
return;
}
},
Expand All @@ -122,6 +117,47 @@ export class ChatPanel implements vscode.WebviewViewProvider {
);
}

private addMessageListener(
commandOrMessageId: string,
callback: (message: any) => void
) {
this.messageCallback[commandOrMessageId] = callback;
}

private async handleStartGeneration({
messageId,
messageType,
chatMessage,
}: {
messageId: string;
messageType: string;
chatMessage: ChatMessage[];
}) {
const sendResponse = (messageToResponse: any, done: boolean) => {
this.postMessage({
type: "e2w-response",
id: messageId,
command: messageType,
data: messageToResponse,
done: done,
});
};
const abortController = new AbortController();

this.addMessageListener("abort-generate", () => {
abortController.abort();
});

for await (const message of chat(chatMessage, {
provideHighlightedText: true,
abortController,
})) {
sendResponse(message, false);
}

sendResponse("", true);
}

public async sendMessageToWebview(
command: MessageType["command"],
data: MessageType["data"]
Expand Down
8 changes: 3 additions & 5 deletions webviews/src/App.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ export const App = () => {
input,
setInput,
startNewChat,
stop,
} = useChat();

useMessageListener("startNewChat", () => {
Expand Down Expand Up @@ -50,14 +51,11 @@ export const App = () => {
buttonEnd={
<VSCodeButton
appearance="icon"
disabled={isLoading}
onClick={handleSubmit}
onClick={isLoading ? stop : handleSubmit}
>
<span
className={`codicon ${
isLoading
? "codicon-loading codicon-modifier-spin codicon-modifier-disabled"
: "codicon-send"
isLoading ? "codicon-debug-stop" : "codicon-send"
}`}
></span>
</VSCodeButton>
Expand Down
3 changes: 1 addition & 2 deletions webviews/src/components/TextArea/index.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ const TextArea = ({
onSubmit();
}

event.preventDefault(); // Prevents the addition of a new line in the text field
event.preventDefault();
}
}}
></textarea>
Expand All @@ -63,5 +63,4 @@ const TextArea = ({
);
};

// 24 42 61
export default TextArea;
87 changes: 45 additions & 42 deletions webviews/src/hooks/useChat.ts
Original file line number Diff line number Diff line change
@@ -1,26 +1,58 @@
import { useCallback, useState } from "react";
import { useCallback, useRef, useState } from "react";
import { randomMessageId } from "../utilities/messageId";
import { vscode } from "../utilities/vscode";

export type ChatMessage = {
role: string;
content: string;
chatMessageId: string;
};

export const useChat = () => {
const [chatMessages, setChatMessages] = useState<
{
role: string;
content: string;
chatMessageId: string;
}[]
>([]);
const [chatMessages, setChatMessages] = useState<ChatMessage[]>([]);

const [input, setInput] = useState("");
const [isLoading, setIsLoading] = useState(false);

const abortController = useRef(new AbortController());

const sendMessage = async (chatHistoryLocal: ChatMessage[]) => {
const messageId = randomMessageId();
for await (const newMessage of vscode.startGeneration(chatHistoryLocal, {
signal: abortController.current.signal,
})) {
setChatMessages((chatHistoryLocal) => {
const messages = chatHistoryLocal.filter(
(message) => message.chatMessageId !== messageId
);

const currentChatMessage = chatHistoryLocal.find(
(message) => message.chatMessageId === messageId
);

return [
...messages,
{
role: "ai",
content: (currentChatMessage?.content || "") + newMessage,
chatMessageId: messageId,
},
];
});
}
setIsLoading(false);
};

const handleSubmit = () => {
if (isLoading) {
return;
}
if (input === "") {
return;
}
if (abortController.current.signal.aborted) {
abortController.current = new AbortController();
}

setChatMessages((value) => {
const messageId = randomMessageId();
Expand All @@ -41,40 +73,10 @@ export const useChat = () => {
setInput("");
};

const sendMessage = async (chatHistoryLocal: any) => {
const messageId = randomMessageId();
await vscode.postMessageCallback(
{
type: "sendMessage",
data: chatHistoryLocal,
},
(newMessage) => {
setChatMessages((chatHistoryLocal) => {
const messages = chatHistoryLocal.filter(
(message) => message.chatMessageId !== messageId
);

const currentChatMessage = chatHistoryLocal.find(
(message) => message.chatMessageId === messageId
);

if (newMessage.done) {
setIsLoading(false);
return chatHistoryLocal;
}

return [
...messages,
{
role: "ai",
content: (currentChatMessage?.content || "") + newMessage.data,
chatMessageId: messageId,
},
];
});
}
);
};
const stop = useCallback(() => {
abortController.current.abort();
setIsLoading(false);
}, [abortController]);

const startNewChat = useCallback(() => {
setChatMessages([]);
Expand All @@ -87,5 +89,6 @@ export const useChat = () => {
setInput,
handleSubmit,
startNewChat,
stop,
};
};
30 changes: 30 additions & 0 deletions webviews/src/utilities/transformCallback2AsyncGenerator.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
export class Transform<T> {
private open = true;
private queue: T[] = [];
private resolve: (() => void) | undefined;

async *stream(): AsyncGenerator<T> {
this.open = true;

while (this.open) {
if (this.queue.length) {
yield this.queue.shift()!;
continue;
}

await new Promise<void>((resolveLocal) => {
this.resolve = resolveLocal;
});
}
}

push(data: T): void {
this.queue.push(data);
this.resolve?.();
}

close(): void {
this.open = false;
this.resolve?.();
}
}
Loading