Skip to content

Commit

Permalink
do not ask confirmation for new chat if no interaction happened (#605)
Browse files Browse the repository at this point in the history
* do not ask confirmation for new chat if no interaction happened

* fix test
  • Loading branch information
willydouhard authored Dec 19, 2023
1 parent 8dc1a76 commit 5f8c1b6
Show file tree
Hide file tree
Showing 11 changed files with 83 additions and 45 deletions.
2 changes: 1 addition & 1 deletion backend/chainlit/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def decorator(method):
async def wrapper(self, *args, **kwargs):
if (
isinstance(context.session, WebsocketSession)
and not context.session.has_user_message
and not context.session.has_first_interaction
):
# Queue the method invocation waiting for the first user message
queues = context.session.thread_queues
Expand Down
35 changes: 18 additions & 17 deletions backend/chainlit/emitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,7 @@ async def clear_ask(self):
"""Stub method to clear the prompt from the UI."""
pass

async def init_thread(self, step_dict: StepDict):
"""Signal the UI that a new thread (with a user message) exists"""
async def init_thread(self, interaction: str):
pass

async def process_user_message(self, payload: UIMessagePayload) -> Message:
Expand Down Expand Up @@ -167,7 +166,7 @@ def clear_ask(self):

return self.emit("clear_ask", {})

async def flush_thread_queues(self, name: str):
async def flush_thread_queues(self, interaction: str):
if data_layer := get_data_layer():
if isinstance(self.session.user, PersistedUser):
user_id = self.session.user.id
Expand All @@ -176,14 +175,13 @@ async def flush_thread_queues(self, name: str):
await data_layer.update_thread(
thread_id=self.session.thread_id,
user_id=user_id,
metadata={"name": name},
metadata={"name": interaction},
)
await self.session.flush_method_queue()

async def init_thread(self, step: StepDict):
"""Signal the UI that a new thread (with a user message) exists"""
await self.flush_thread_queues(name=step["output"])
await self.emit("init_thread", step)
async def init_thread(self, interaction: str):
await self.flush_thread_queues(interaction)
await self.emit("first_interaction", interaction)

async def process_user_message(self, payload: UIMessagePayload):
step_dict = payload["message"]
Expand All @@ -197,9 +195,9 @@ async def process_user_message(self, payload: UIMessagePayload):

asyncio.create_task(message._create())

if not self.session.has_user_message:
self.session.has_user_message = True
asyncio.create_task(self.init_thread(message.to_dict()))
if not self.session.has_first_interaction:
self.session.has_first_interaction = True
asyncio.create_task(self.init_thread(message.content))

if file_refs:
files = [
Expand Down Expand Up @@ -239,11 +237,13 @@ async def send_ask_user(
] = None

if user_res:
interaction = None
if spec.type == "text":
message_dict_res = cast(StepDict, user_res)
await self.process_user_message(
{"message": message_dict_res, "fileReferences": None}
)
interaction = message_dict_res["output"]
final_res = message_dict_res
elif spec.type == "file":
file_refs = cast(List[FileReference], user_res)
Expand All @@ -253,12 +253,7 @@ async def send_ask_user(
if file["id"] in self.session.files
]
final_res = files
if not self.session.has_user_message:
self.session.has_user_message = True
await self.flush_thread_queues(
name=",".join([file["name"] for file in files])
)

interaction = ",".join([file["name"] for file in files])
if get_data_layer():
coros = [
File(
Expand All @@ -274,6 +269,12 @@ async def send_ask_user(
elif spec.type == "action":
action_res = cast(AskActionResponse, user_res)
final_res = action_res
interaction = action_res["value"]

if not self.session.has_first_interaction and interaction:
self.session.has_first_interaction = True
await self.init_thread(interaction=interaction)

await self.clear_ask()
return final_res
except TimeoutError as e:
Expand Down
2 changes: 1 addition & 1 deletion backend/chainlit/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def __init__(
self.user = user
self.token = token
self.root_message = root_message
self.has_user_message = False
self.has_first_interaction = False
self.user_env = user_env or {}
self.chat_profile = chat_profile
self.active_steps = []
Expand Down
4 changes: 2 additions & 2 deletions backend/chainlit/socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ async def connection_successful(sid):
if context.session.thread_id_to_resume and config.code.on_chat_resume:
thread = await resume_thread(context.session)
if thread:
context.session.has_user_message = True
context.session.has_first_interaction = True
await context.emitter.clear_ask()
await context.emitter.resume_thread(thread)
await config.code.on_chat_resume(thread)
Expand Down Expand Up @@ -173,7 +173,7 @@ async def disconnect(sid):
if config.code.on_chat_end and session:
await config.code.on_chat_end()

if session and session.thread_id and session.has_user_message:
if session and session.thread_id and session.has_first_interaction:
await persist_user_session(session.thread_id, session.to_persistable())

async def disconnect_on_timeout(sid):
Expand Down
23 changes: 19 additions & 4 deletions cypress/e2e/chat_profiles/.chainlit/config.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,28 @@ cache = false
# Show the prompt playground
prompt_playground = true

# Process and display HTML in messages. This can be a security risk (see https://stackoverflow.com/questions/19603097/why-is-it-dangerous-to-render-user-generated-html-or-javascript)
unsafe_allow_html = false

# Process and display mathematical expressions. This can clash with "$" characters in messages.
latex = false

# Authorize users to upload files with messages
multi_modal = true

# Allows user to use speech to text
[features.speech_to_text]
enabled = false
# See all languages here https://github.com/JamesBrill/react-speech-recognition/blob/HEAD/docs/API.md#language-string
# language = "en-US"

[UI]
# Name of the app and chatbot.
name = "Chatbot"

# Show the readme while the thread is empty.
show_readme_as_default = false

# Description of the app and chatbot. This is used for HTML tags.
# description = ""

Expand All @@ -41,9 +59,6 @@ hide_cot = false
# The CSS file can be served from the public directory or via an external link.
# custom_css = "/public/test.css"

# If the app is served behind a reverse proxy (like cloud run) we need to know the base url for oauth
# base_url = "https://mydomain.com"

# Override default MUI light theme. (Check theme.ts)
[UI.theme.light]
#background = "#FAFAFA"
Expand All @@ -66,4 +81,4 @@ hide_cot = false


[meta]
generated_by = "0.7.1"
generated_by = "1.0.0rc2"
16 changes: 14 additions & 2 deletions cypress/e2e/chat_profiles/spec.cy.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { runTestServer } from '../../support/testUtils';
import { runTestServer, submitMessage } from '../../support/testUtils';

describe('Chat profiles', () => {
before(() => {
Expand Down Expand Up @@ -28,7 +28,6 @@ describe('Chat profiles', () => {
// Change chat profile

cy.get('[data-test="chat-profile:GPT-4"]').click();
cy.get('#confirm').click();

cy.get('.step')
.should('have.length', 1)
Expand All @@ -48,5 +47,18 @@ describe('Chat profiles', () => {
'contain',
'starting chat with admin using the GPT-4 chat profile'
);

submitMessage('hello');
cy.get('.step').should('have.length', 2).eq(1).should('contain', 'hello');
cy.get('[data-test="chat-profile:GPT-5"]').click();
cy.get('#confirm').click();

cy.get('.step')
.should('have.length', 1)
.eq(0)
.should(
'contain',
'starting chat with admin using the GPT-5 chat profile'
);
});
});
22 changes: 16 additions & 6 deletions frontend/src/components/molecules/chatProfiles.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,11 @@ import { useRecoilValue } from 'recoil';

import { Box, Popover, Tab, Tabs } from '@mui/material';

import { useChatInteract, useChatSession } from '@chainlit/react-client';
import {
useChatInteract,
useChatMessages,
useChatSession
} from '@chainlit/react-client';
import {
InputStateHandler,
Markdown,
Expand All @@ -19,6 +23,7 @@ import NewChatDialog from './newChatDialog';
export default function ChatProfiles() {
const pSettings = useRecoilValue(projectSettingsState);
const { chatProfile, setChatProfile } = useChatSession();
const { firstInteraction } = useChatMessages();
const [anchorEl, setAnchorEl] = useState<HTMLElement | null>(null);
const [chatProfileDescription, setChatProfileDescription] = useState('');
const { clear } = useChatInteract();
Expand All @@ -31,12 +36,13 @@ export default function ChatProfiles() {
setNewChatProfile(null);
};

const handleConfirm = () => {
if (!newChatProfile) {
const handleConfirm = (newChatProfileWithoutConfirm?: string) => {
const chatProfile = newChatProfileWithoutConfirm || newChatProfile;
if (!chatProfile) {
// Should never happen
throw new Error('Retry clicking on a profile before starting a new chat');
}
setChatProfile(newChatProfile);
setChatProfile(chatProfile);
setNewChatProfile(null);
clear();
handleClose();
Expand Down Expand Up @@ -70,7 +76,11 @@ export default function ChatProfiles() {
value={chatProfile || ''}
onChange={(event: React.SyntheticEvent, newValue: string) => {
setNewChatProfile(newValue);
setOpenDialog(true);
if (firstInteraction) {
setOpenDialog(true);
} else {
handleConfirm(newValue);
}
}}
variant="scrollable"
sx={{
Expand Down Expand Up @@ -178,7 +188,7 @@ export default function ChatProfiles() {
<NewChatDialog
open={openDialog}
handleClose={handleClose}
handleConfirm={handleConfirm}
handleConfirm={() => handleConfirm()}
/>
</Box>
);
Expand Down
4 changes: 2 additions & 2 deletions libs/react-client/src/state.ts
Original file line number Diff line number Diff line change
Expand Up @@ -112,8 +112,8 @@ export const tasklistState = atom<ITasklistElement[]>({
default: []
});

export const firstUserMessageState = atom<IStep | undefined>({
key: 'FirstUserMessage',
export const firstUserInteraction = atom<string | undefined>({
key: 'FirstUserInteraction',
default: undefined
});

Expand Down
6 changes: 3 additions & 3 deletions libs/react-client/src/useChatInteract.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import {
chatSettingsInputsState,
chatSettingsValueState,
elementState,
firstUserMessageState,
firstUserInteraction,
loadingState,
messagesState,
sessionIdState,
Expand All @@ -32,7 +32,7 @@ const useChatInteract = () => {
const resetSessionId = useResetRecoilState(sessionIdState);
const resetChatSettingsValue = useResetRecoilState(chatSettingsValueState);

const setFirstUserMessage = useSetRecoilState(firstUserMessageState);
const setFirstUserInteraction = useSetRecoilState(firstUserInteraction);
const setLoading = useSetRecoilState(loadingState);
const setMessages = useSetRecoilState(messagesState);
const setElements = useSetRecoilState(elementState);
Expand All @@ -47,7 +47,7 @@ const useChatInteract = () => {
session?.socket.disconnect();
setIdToResume(undefined);
resetSessionId();
setFirstUserMessage(undefined);
setFirstUserInteraction(undefined);
setMessages([]);
setElements([]);
setAvatars([]);
Expand Down
6 changes: 3 additions & 3 deletions libs/react-client/src/useChatMessages.ts
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
import { useRecoilValue } from 'recoil';

import { firstUserMessageState, messagesState } from './state';
import { firstUserInteraction, messagesState } from './state';

const useChatMessages = () => {
const messages = useRecoilValue(messagesState);
const firstUserMessage = useRecoilValue(firstUserMessageState);
const firstInteraction = useRecoilValue(firstUserInteraction);

return {
messages,
firstUserMessage
firstInteraction
};
};

Expand Down
8 changes: 4 additions & 4 deletions libs/react-client/src/useChatSession.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ import {
chatSettingsInputsState,
chatSettingsValueState,
elementState,
firstUserMessageState,
firstUserInteraction,
loadingState,
messagesState,
sessionIdState,
Expand Down Expand Up @@ -49,7 +49,7 @@ const useChatSession = () => {
const [session, setSession] = useRecoilState(sessionState);

const resetChatSettingsValue = useResetRecoilState(chatSettingsValueState);
const setFirstUserMessage = useSetRecoilState(firstUserMessageState);
const setFirstUserInteraction = useSetRecoilState(firstUserInteraction);
const setLoading = useSetRecoilState(loadingState);
const setMessages = useSetRecoilState(messagesState);
const setAskUser = useSetRecoilState(askUserState);
Expand Down Expand Up @@ -139,8 +139,8 @@ const useChatSession = () => {
setMessages((oldMessages) => addMessage(oldMessages, message));
});

socket.on('init_thread', (message: IStep) => {
setFirstUserMessage(message);
socket.on('first_interaction', (interaction: string) => {
setFirstUserInteraction(interaction);
});

socket.on('update_message', (message: IStep) => {
Expand Down

0 comments on commit 5f8c1b6

Please sign in to comment.