Skip to content

Commit

Permalink
load messages on the client side (#2942)
Browse files Browse the repository at this point in the history
  • Loading branch information
notmd authored Apr 28, 2023
1 parent 268ea95 commit c7219dc
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 23 deletions.
7 changes: 3 additions & 4 deletions website/src/components/Chat/ChatContext.tsx
Original file line number Diff line number Diff line change
@@ -1,17 +1,16 @@
import { createContext, PropsWithChildren, useContext, useMemo } from "react";
import { InferenceMessage, ModelInfo } from "src/types/Chat";
import { ModelInfo } from "src/types/Chat";

export type ChatContext = {
modelInfos: ModelInfo[];
messages: InferenceMessage[];
};

const chatContext = createContext<ChatContext>({} as ChatContext);

export const useChatContext = () => useContext(chatContext);

export const ChatContextProvider = ({ children, modelInfos, messages }: PropsWithChildren<ChatContext>) => {
const value = useMemo(() => ({ modelInfos, messages }), [messages, modelInfos]);
export const ChatContextProvider = ({ children, modelInfos }: PropsWithChildren<ChatContext>) => {
const value = useMemo(() => ({ modelInfos }), [modelInfos]);

return <chatContext.Provider value={value}>{children}</chatContext.Provider>;
};
41 changes: 29 additions & 12 deletions website/src/components/Chat/ChatConversation.tsx
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/* eslint-disable @typescript-eslint/no-explicit-any */
import { Box, useBoolean, useToast } from "@chakra-ui/react";
import { Box, CircularProgress, useBoolean, useToast } from "@chakra-ui/react";
import { KeyboardEvent, memo, useCallback, useEffect, useMemo, useRef, useState } from "react";
import { UseFormGetValues } from "react-hook-form";
import SimpleBar from "simplebar-react";
Expand All @@ -9,13 +9,14 @@ import { handleChatEventStream, QueueInfo } from "src/lib/chat_stream";
import { API_ROUTES } from "src/lib/routes";
import {
ChatConfigFormData,
ChatItem,
InferenceMessage,
InferencePostAssistantMessageParams,
InferencePostPrompterMessageParams,
} from "src/types/Chat";
import { mutate } from "swr";
import useSWR from "swr";

import { useChatContext } from "./ChatContext";
import { ChatConversationTree, LAST_ASSISTANT_MESSAGE_ID } from "./ChatConversationTree";
import { ChatForm } from "./ChatForm";
import { ChatMessageEntryProps, EditPromptParams, PendingMessageEntry } from "./ChatMessageEntry";
Expand All @@ -28,12 +29,25 @@ interface ChatConversationProps {

export const ChatConversation = memo(function ChatConversation({ chatId, getConfigValues }: ChatConversationProps) {
const inputRef = useRef<HTMLTextAreaElement>(null);
const [messages, setMessages] = useState<InferenceMessage[]>(useChatContext().messages);
const [messages, setMessages] = useState<InferenceMessage[]>([]);

const [streamedResponse, setResponse] = useState<string | null>(null);
const [queueInfo, setQueueInfo] = useState<QueueInfo | null>(null);
const [isSending, setIsSending] = useBoolean();

const toast = useToast();
const { isLoading: isLoadingMessages } = useSWR<ChatItem>(chatId ? API_ROUTES.GET_CHAT(chatId) : null, get, {
onSuccess(data) {
setMessages(data.messages.sort((a, b) => Date.parse(a.created_at) - Date.parse(b.created_at)));
},
onError: () => {
toast({
title: "Failed to load chat",
status: "error",
});
},
});

const createAndFetchAssistantMessage = useCallback(
async ({ parentId, chatId }: { parentId: string; chatId: string }) => {
const { model_config_name, ...sampling_parameters } = getConfigValues();
Expand Down Expand Up @@ -77,7 +91,6 @@ export const ChatConversation = memo(function ChatConversation({ chatId, getConf
},
[getConfigValues, setIsSending]
);
const toast = useToast();
const sendPrompterMessage = useCallback(async () => {
const content = inputRef.current?.value.trim();
if (!content || isSending) {
Expand Down Expand Up @@ -201,7 +214,7 @@ export const ChatConversation = memo(function ChatConversation({ chatId, getConf
[createAndFetchAssistantMessage, isSending, setIsSending]
);

const { messagesEndRef, scrollableNodeProps } = useAutoScroll(messages, streamedResponse);
const { messagesEndRef, scrollableNodeProps, updateEnableAutoScroll } = useAutoScroll(messages, streamedResponse);

return (
<Box
Expand All @@ -220,7 +233,9 @@ export const ChatConversation = memo(function ChatConversation({ chatId, getConf
bg: "blackAlpha.300",
}}
>
{isLoadingMessages && <CircularProgress isIndeterminate size="20px" mx="auto" />}
<SimpleBar
onMouseDown={updateEnableAutoScroll}
scrollableNodeProps={scrollableNodeProps}
style={{ maxHeight: "100%", height: "100%", minHeight: "0" }}
classNames={{
Expand Down Expand Up @@ -248,13 +263,13 @@ const useAutoScroll = (messages: InferenceMessage[], streamedResponse: string) =
const enableAutoScroll = useRef(true);
const messagesEndRef = useRef<HTMLDivElement>(null);
const chatContainerRef = useRef<HTMLDivElement>(null);
const handleOnScroll = useCallback(() => {
const updateEnableAutoScroll = useCallback(() => {
const container = chatContainerRef.current;
if (!container) {
return;
}

const isEnable = Math.abs(container.scrollHeight - container.scrollTop - container.clientHeight) < 10;
const isEnable = container.scrollHeight - container.scrollTop - container.clientHeight < 10;
enableAutoScroll.current = isEnable;
}, []);

Expand All @@ -269,16 +284,18 @@ const useAutoScroll = (messages: InferenceMessage[], streamedResponse: string) =
const scrollableNodeProps = useMemo(
() => ({
ref: chatContainerRef,
onWheel: handleOnScroll,
onWheel: updateEnableAutoScroll,
// onScroll: handleOnScroll,
onKeyDown: (e: KeyboardEvent) => {
if (e.key === "ArrowUp" || e.key === "ArrowDown") {
handleOnScroll();
updateEnableAutoScroll();
}
},
onTouchMove: handleOnScroll,
onTouchMove: updateEnableAutoScroll,
onMouseDown: updateEnableAutoScroll(),
}),
[handleOnScroll]
[updateEnableAutoScroll]
);

return { messagesEndRef, scrollableNodeProps };
return { messagesEndRef, scrollableNodeProps, updateEnableAutoScroll };
};
11 changes: 4 additions & 7 deletions website/src/pages/chat/[id].tsx
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,14 @@ import { ChatContextProvider } from "src/components/Chat/ChatContext";
import { ChatSection } from "src/components/Chat/ChatSection";
import { getChatLayout } from "src/components/Layout/ChatLayout";
import { createInferenceClient } from "src/lib/oasst_inference_client";
import { InferenceMessage, ModelInfo } from "src/types/Chat";
import { ModelInfo } from "src/types/Chat";

interface ChatProps {
id: string;
modelInfos: ModelInfo[];
messages: InferenceMessage[];
}

const Chat = ({ id, modelInfos, messages }: ChatProps) => {
const Chat = ({ id, modelInfos }: ChatProps) => {
const { t } = useTranslation(["common", "chat"]);

return (
Expand All @@ -25,7 +24,7 @@ const Chat = ({ id, modelInfos, messages }: ChatProps) => {
<title>{t("chat")}</title>
</Head>

<ChatContextProvider modelInfos={modelInfos} messages={messages}>
<ChatContextProvider modelInfos={modelInfos}>
<ChatSection chatId={id} />
</ChatContextProvider>
</>
Expand All @@ -38,7 +37,6 @@ export const getServerSideProps: GetServerSideProps<ChatProps, { id: string }> =
locale = "en",
params,
req,
query,
}) => {
if (!boolean(process.env.ENABLE_CHAT)) {
return {
Expand All @@ -48,13 +46,12 @@ export const getServerSideProps: GetServerSideProps<ChatProps, { id: string }> =

const token = await getToken({ req });
const client = createInferenceClient(token!);
const [modelInfos, chat] = await Promise.all([client.get_models(), client.get_chat(query.id as string)]);
const modelInfos = await client.get_models();

return {
props: {
id: params!.id,
modelInfos,
messages: chat.messages.sort((a, b) => Date.parse(a.created_at) - Date.parse(b.created_at)),
...(await serverSideTranslations(locale)),
},
};
Expand Down

0 comments on commit c7219dc

Please sign in to comment.