Handle chat error (#2959)
Handle errors when creating a prompter or assistant message.
notmd authored Apr 28, 2023
1 parent c7219dc commit e82a717
Showing 4 changed files with 55 additions and 16 deletions.
inference/server/oasst_inference_server/routes/chats.py (4 additions, 0 deletions)
@@ -110,6 +110,8 @@ async def create_prompter_message(
             chat_id=chat_id, parent_id=request.parent_id, content=request.content
         )
         return prompter_message.to_read()
+    except fastapi.HTTPException:
+        raise
     except Exception:
         logger.exception("Error adding prompter message")
         return fastapi.Response(status_code=500)
@@ -154,6 +156,8 @@ async def create_assistant_message(
                 status_code=fastapi.status.HTTP_503_SERVICE_UNAVAILABLE,
                 detail="The server is currently busy. Please try again later.",
             )
+    except fastapi.HTTPException:
+        raise
     except Exception:
         logger.exception("Error adding prompter message")
         return fastapi.Response(status_code=500)
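The ordering of the new clauses is what makes the fix work: Python evaluates except clauses top to bottom, so without the explicit re-raise the bare except Exception would also swallow fastapi.HTTPException and turn deliberate responses, such as the 503 above, into opaque 500s. A minimal self-contained sketch of the rule (an illustration, not the project's actual route):

import fastapi

app = fastapi.FastAPI()

@app.post("/demo")
async def demo():
    try:
        # a deliberate HTTP error, standing in for the "server is busy" 503
        raise fastapi.HTTPException(status_code=503, detail="busy")
    except fastapi.HTTPException:
        raise  # re-raise so FastAPI renders the intended status code
    except Exception:
        # only truly unexpected failures fall through to a generic 500
        return fastapi.Response(status_code=500)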
website/src/components/Chat/ChatConversation.tsx (33 additions, 8 deletions)
@@ -6,6 +6,7 @@ import SimpleBar from "simplebar-react";
 import { useMessageVote } from "src/hooks/chat/useMessageVote";
 import { get, post } from "src/lib/api";
 import { handleChatEventStream, QueueInfo } from "src/lib/chat_stream";
+import { OasstError } from "src/lib/oasst_api_client";
 import { API_ROUTES } from "src/lib/routes";
 import {
   ChatConfigFormData,
@@ -34,8 +35,8 @@ export const ChatConversation
   const [streamedResponse, setResponse] = useState<string | null>(null);
   const [queueInfo, setQueueInfo] = useState<QueueInfo | null>(null);
   const [isSending, setIsSending] = useBoolean();
-
   const toast = useToast();
+
   const { isLoading: isLoadingMessages } = useSWR<ChatItem>(chatId ? API_ROUTES.GET_CHAT(chatId) : null, get, {
     onSuccess(data) {
       setMessages(data.messages.sort((a, b) => Date.parse(a.created_at) - Date.parse(b.created_at)));
@@ -58,9 +59,21 @@
       sampling_parameters,
     };

-    const assistant_message: InferenceMessage = await post(API_ROUTES.CREATE_ASSISTANT_MESSAGE, {
-      arg: assistant_arg,
-    });
+    let assistant_message: InferenceMessage;
+    try {
+      assistant_message = await post(API_ROUTES.CREATE_ASSISTANT_MESSAGE, {
+        arg: assistant_arg,
+      });
+    } catch (e) {
+      if (e instanceof OasstError) {
+        toast({
+          title: e.message,
+          status: "error",
+        });
+      }
+      setIsSending.off();
+      return;
+    }

     // we have to do this manually since we want to stream the chunks
     // there is also EventSource, but it is callback based
Expand Down Expand Up @@ -89,7 +102,7 @@ export const ChatConversation = memo(function ChatConversation({ chatId, getConf
setResponse(null);
setIsSending.off();
},
[getConfigValues, setIsSending]
[getConfigValues, setIsSending, toast]
);
const sendPrompterMessage = useCallback(async () => {
const content = inputRef.current?.value.trim();
Expand Down Expand Up @@ -120,7 +133,19 @@ export const ChatConversation = memo(function ChatConversation({ chatId, getConf
parent_id: parentId,
};

const prompter_message: InferenceMessage = await post(API_ROUTES.CREATE_PROMPTER_MESSAGE, { arg: prompter_arg });
let prompter_message: InferenceMessage;
try {
prompter_message = await post(API_ROUTES.CREATE_PROMPTER_MESSAGE, { arg: prompter_arg });
} catch (e) {
if (e instanceof OasstError) {
toast({
title: e.message,
status: "error",
});
}
setIsSending.off();
return;
}
if (messages.length === 0) {
// revalidate chat list after creating the first prompter message to make sure the message already has title
mutate(API_ROUTES.LIST_CHAT);
Expand Down Expand Up @@ -259,7 +284,7 @@ export const ChatConversation = memo(function ChatConversation({ chatId, getConf
);
});

const useAutoScroll = (messages: InferenceMessage[], streamedResponse: string) => {
const useAutoScroll = (messages: InferenceMessage[], streamedResponse: string | null) => {
const enableAutoScroll = useRef(true);
const messagesEndRef = useRef<HTMLDivElement>(null);
const chatContainerRef = useRef<HTMLDivElement>(null);
@@ -278,7 +303,7 @@ const useAutoScroll
       return;
     }

-    messagesEndRef.current.scrollIntoView({ behavior: "smooth" });
+    messagesEndRef.current?.scrollIntoView({ behavior: "smooth" });
   }, [messages, streamedResponse]);

   const scrollableNodeProps = useMemo(
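Both send paths in the component now share one guard shape: await the POST inside try/catch, toast only for known OasstError failures, clear the sending flag, and bail out. A standalone sketch of that shape (the names and signatures here are illustrative, not the project's API):

class OasstError extends Error {}

async function sendWithGuard<T>(
  doPost: () => Promise<T>,
  notify: (title: string) => void,
  setSendingOff: () => void
): Promise<T | undefined> {
  try {
    return await doPost();
  } catch (e) {
    if (e instanceof OasstError) {
      notify(e.message); // surface known API errors to the user
    }
    setSendingOff(); // never leave the UI stuck in the "sending" state
    return undefined; // callers treat undefined as "request aborted"
  }
}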
website/src/pages/api/chat/assistant_message.ts (9 additions, 4 deletions)
@@ -1,3 +1,4 @@
+import { AxiosError } from "axios";
 import { withoutRole } from "src/lib/auth";
 import { isChatEnable } from "src/lib/isChatEnable";
 import { createInferenceClient } from "src/lib/oasst_inference_client";
@@ -9,12 +10,16 @@ const handler = withoutRole("banned", async (req, res, token) => {
   }
   const client = createInferenceClient(token);

-  const data = await client.post_assistant_message(req.body as InferencePostAssistantMessageParams);
-
-  if (data) {
+  try {
+    const data = await client.post_assistant_message(req.body as InferencePostAssistantMessageParams);
     return res.status(200).json(data);
+  } catch (e) {
+    if (!(e instanceof AxiosError)) {
+      return res.status(500).end();
+    }
+    console.log(e);
+    return res.status(e.response?.status ?? 500).json({ message: e.response?.data.detail ?? "Something went wrong" });
   }
-  res.status(400).end();
 });

 export default handler;
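The rewritten handler proxies failures instead of flattening them: upstream HTTP errors keep their status code and detail message, while anything that is not an AxiosError becomes a plain 500. A trimmed sketch of that mapping (the Res interface is a stand-in for Next.js's response type, for illustration only):

import { AxiosError } from "axios";

// Stand-in for the Next.js response object (illustrative only).
interface Res {
  status(code: number): Res;
  json(body: unknown): void;
  end(): void;
}

// Forward the upstream status and detail when axios reports an HTTP failure;
// otherwise return an opaque 500.
function forwardUpstreamError(e: unknown, res: Res): void {
  if (!(e instanceof AxiosError)) {
    res.status(500).end();
    return;
  }
  res
    .status(e.response?.status ?? 500)
    .json({ message: e.response?.data.detail ?? "Something went wrong" });
}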
website/src/pages/api/chat/prompter_message.ts (9 additions, 4 deletions)
@@ -1,3 +1,4 @@
+import { AxiosError } from "axios";
 import { withoutRole } from "src/lib/auth";
 import { isChatEnable } from "src/lib/isChatEnable";
 import { createInferenceClient } from "src/lib/oasst_inference_client";
@@ -9,12 +10,16 @@ const handler = withoutRole("banned", async (req, res, token) => {
   }
   const client = createInferenceClient(token);

-  const data = await client.post_prompter_message(req.body as InferencePostPrompterMessageParams);
-
-  if (data) {
+  try {
+    const data = await client.post_prompter_message(req.body as InferencePostPrompterMessageParams);
     return res.status(200).json(data);
+  } catch (e) {
+    if (!(e instanceof AxiosError)) {
+      return res.status(500).end();
+    }
+    console.log(e);
+    return res.status(e.response?.status ?? 500).json({ message: e.response?.data.detail ?? "Something went wrong" });
   }
-  res.status(400).end();
 });

 export default handler;
