Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

do not ask confirmation for new chat if no interaction happened #605

Merged
merged 2 commits into from
Dec 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion backend/chainlit/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def decorator(method):
async def wrapper(self, *args, **kwargs):
if (
isinstance(context.session, WebsocketSession)
and not context.session.has_user_message
and not context.session.has_first_interaction
):
# Queue the method invocation waiting for the first user message
queues = context.session.thread_queues
Expand Down
35 changes: 18 additions & 17 deletions backend/chainlit/emitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,7 @@ async def clear_ask(self):
"""Stub method to clear the prompt from the UI."""
pass

async def init_thread(self, step_dict: StepDict):
"""Signal the UI that a new thread (with a user message) exists"""
async def init_thread(self, interaction: str):
pass

async def process_user_message(self, payload: UIMessagePayload) -> Message:
Expand Down Expand Up @@ -167,7 +166,7 @@ def clear_ask(self):

return self.emit("clear_ask", {})

async def flush_thread_queues(self, name: str):
async def flush_thread_queues(self, interaction: str):
if data_layer := get_data_layer():
if isinstance(self.session.user, PersistedUser):
user_id = self.session.user.id
Expand All @@ -176,14 +175,13 @@ async def flush_thread_queues(self, name: str):
await data_layer.update_thread(
thread_id=self.session.thread_id,
user_id=user_id,
metadata={"name": name},
metadata={"name": interaction},
)
await self.session.flush_method_queue()

async def init_thread(self, step: StepDict):
"""Signal the UI that a new thread (with a user message) exists"""
await self.flush_thread_queues(name=step["output"])
await self.emit("init_thread", step)
async def init_thread(self, interaction: str):
await self.flush_thread_queues(interaction)
await self.emit("first_interaction", interaction)

async def process_user_message(self, payload: UIMessagePayload):
step_dict = payload["message"]
Expand All @@ -197,9 +195,9 @@ async def process_user_message(self, payload: UIMessagePayload):

asyncio.create_task(message._create())

if not self.session.has_user_message:
self.session.has_user_message = True
asyncio.create_task(self.init_thread(message.to_dict()))
if not self.session.has_first_interaction:
self.session.has_first_interaction = True
asyncio.create_task(self.init_thread(message.content))

if file_refs:
files = [
Expand Down Expand Up @@ -239,11 +237,13 @@ async def send_ask_user(
] = None

if user_res:
interaction = None
if spec.type == "text":
message_dict_res = cast(StepDict, user_res)
await self.process_user_message(
{"message": message_dict_res, "fileReferences": None}
)
interaction = message_dict_res["output"]
final_res = message_dict_res
elif spec.type == "file":
file_refs = cast(List[FileReference], user_res)
Expand All @@ -253,12 +253,7 @@ async def send_ask_user(
if file["id"] in self.session.files
]
final_res = files
if not self.session.has_user_message:
self.session.has_user_message = True
await self.flush_thread_queues(
name=",".join([file["name"] for file in files])
)

interaction = ",".join([file["name"] for file in files])
if get_data_layer():
coros = [
File(
Expand All @@ -274,6 +269,12 @@ async def send_ask_user(
elif spec.type == "action":
action_res = cast(AskActionResponse, user_res)
final_res = action_res
interaction = action_res["value"]

if not self.session.has_first_interaction and interaction:
self.session.has_first_interaction = True
await self.init_thread(interaction=interaction)

await self.clear_ask()
return final_res
except TimeoutError as e:
Expand Down
2 changes: 1 addition & 1 deletion backend/chainlit/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def __init__(
self.user = user
self.token = token
self.root_message = root_message
self.has_user_message = False
self.has_first_interaction = False
self.user_env = user_env or {}
self.chat_profile = chat_profile
self.active_steps = []
Expand Down
4 changes: 2 additions & 2 deletions backend/chainlit/socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ async def connection_successful(sid):
if context.session.thread_id_to_resume and config.code.on_chat_resume:
thread = await resume_thread(context.session)
if thread:
context.session.has_user_message = True
context.session.has_first_interaction = True
await context.emitter.clear_ask()
await context.emitter.resume_thread(thread)
await config.code.on_chat_resume(thread)
Expand Down Expand Up @@ -173,7 +173,7 @@ async def disconnect(sid):
if config.code.on_chat_end and session:
await config.code.on_chat_end()

if session and session.thread_id and session.has_user_message:
if session and session.thread_id and session.has_first_interaction:
await persist_user_session(session.thread_id, session.to_persistable())

async def disconnect_on_timeout(sid):
Expand Down
23 changes: 19 additions & 4 deletions cypress/e2e/chat_profiles/.chainlit/config.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,28 @@ cache = false
# Show the prompt playground
prompt_playground = true

# Process and display HTML in messages. This can be a security risk (see https://stackoverflow.com/questions/19603097/why-is-it-dangerous-to-render-user-generated-html-or-javascript)
unsafe_allow_html = false

# Process and display mathematical expressions. This can clash with "$" characters in messages.
latex = false

# Authorize users to upload files with messages
multi_modal = true

# Allows user to use speech to text
[features.speech_to_text]
enabled = false
# See all languages here https://github.com/JamesBrill/react-speech-recognition/blob/HEAD/docs/API.md#language-string
# language = "en-US"

[UI]
# Name of the app and chatbot.
name = "Chatbot"

# Show the readme while the thread is empty.
show_readme_as_default = false

# Description of the app and chatbot. This is used for HTML tags.
# description = ""

Expand All @@ -41,9 +59,6 @@ hide_cot = false
# The CSS file can be served from the public directory or via an external link.
# custom_css = "/public/test.css"

# If the app is served behind a reverse proxy (like cloud run) we need to know the base url for oauth
# base_url = "https://mydomain.com"

# Override default MUI light theme. (Check theme.ts)
[UI.theme.light]
#background = "#FAFAFA"
Expand All @@ -66,4 +81,4 @@ hide_cot = false


[meta]
generated_by = "0.7.1"
generated_by = "1.0.0rc2"
16 changes: 14 additions & 2 deletions cypress/e2e/chat_profiles/spec.cy.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { runTestServer } from '../../support/testUtils';
import { runTestServer, submitMessage } from '../../support/testUtils';

describe('Chat profiles', () => {
before(() => {
Expand Down Expand Up @@ -28,7 +28,6 @@ describe('Chat profiles', () => {
// Change chat profile

cy.get('[data-test="chat-profile:GPT-4"]').click();
cy.get('#confirm').click();

cy.get('.step')
.should('have.length', 1)
Expand All @@ -48,5 +47,18 @@ describe('Chat profiles', () => {
'contain',
'starting chat with admin using the GPT-4 chat profile'
);

submitMessage('hello');
cy.get('.step').should('have.length', 2).eq(1).should('contain', 'hello');
cy.get('[data-test="chat-profile:GPT-5"]').click();
cy.get('#confirm').click();

cy.get('.step')
.should('have.length', 1)
.eq(0)
.should(
'contain',
'starting chat with admin using the GPT-5 chat profile'
);
});
});
22 changes: 16 additions & 6 deletions frontend/src/components/molecules/chatProfiles.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,11 @@ import { useRecoilValue } from 'recoil';

import { Box, Popover, Tab, Tabs } from '@mui/material';

import { useChatInteract, useChatSession } from '@chainlit/react-client';
import {
useChatInteract,
useChatMessages,
useChatSession
} from '@chainlit/react-client';
import {
InputStateHandler,
Markdown,
Expand All @@ -19,6 +23,7 @@ import NewChatDialog from './newChatDialog';
export default function ChatProfiles() {
const pSettings = useRecoilValue(projectSettingsState);
const { chatProfile, setChatProfile } = useChatSession();
const { firstInteraction } = useChatMessages();
const [anchorEl, setAnchorEl] = useState<HTMLElement | null>(null);
const [chatProfileDescription, setChatProfileDescription] = useState('');
const { clear } = useChatInteract();
Expand All @@ -31,12 +36,13 @@ export default function ChatProfiles() {
setNewChatProfile(null);
};

const handleConfirm = () => {
if (!newChatProfile) {
const handleConfirm = (newChatProfileWithoutConfirm?: string) => {
const chatProfile = newChatProfileWithoutConfirm || newChatProfile;
if (!chatProfile) {
// Should never happen
throw new Error('Retry clicking on a profile before starting a new chat');
}
setChatProfile(newChatProfile);
setChatProfile(chatProfile);
setNewChatProfile(null);
clear();
handleClose();
Expand Down Expand Up @@ -70,7 +76,11 @@ export default function ChatProfiles() {
value={chatProfile || ''}
onChange={(event: React.SyntheticEvent, newValue: string) => {
setNewChatProfile(newValue);
setOpenDialog(true);
if (firstInteraction) {
setOpenDialog(true);
} else {
handleConfirm(newValue);
}
}}
variant="scrollable"
sx={{
Expand Down Expand Up @@ -178,7 +188,7 @@ export default function ChatProfiles() {
<NewChatDialog
open={openDialog}
handleClose={handleClose}
handleConfirm={handleConfirm}
handleConfirm={() => handleConfirm()}
/>
</Box>
);
Expand Down
4 changes: 2 additions & 2 deletions libs/react-client/src/state.ts
Original file line number Diff line number Diff line change
Expand Up @@ -112,8 +112,8 @@ export const tasklistState = atom<ITasklistElement[]>({
default: []
});

export const firstUserMessageState = atom<IStep | undefined>({
key: 'FirstUserMessage',
export const firstUserInteraction = atom<string | undefined>({
key: 'FirstUserInteraction',
default: undefined
});

Expand Down
6 changes: 3 additions & 3 deletions libs/react-client/src/useChatInteract.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import {
chatSettingsInputsState,
chatSettingsValueState,
elementState,
firstUserMessageState,
firstUserInteraction,
loadingState,
messagesState,
sessionIdState,
Expand All @@ -32,7 +32,7 @@ const useChatInteract = () => {
const resetSessionId = useResetRecoilState(sessionIdState);
const resetChatSettingsValue = useResetRecoilState(chatSettingsValueState);

const setFirstUserMessage = useSetRecoilState(firstUserMessageState);
const setFirstUserInteraction = useSetRecoilState(firstUserInteraction);
const setLoading = useSetRecoilState(loadingState);
const setMessages = useSetRecoilState(messagesState);
const setElements = useSetRecoilState(elementState);
Expand All @@ -47,7 +47,7 @@ const useChatInteract = () => {
session?.socket.disconnect();
setIdToResume(undefined);
resetSessionId();
setFirstUserMessage(undefined);
setFirstUserInteraction(undefined);
setMessages([]);
setElements([]);
setAvatars([]);
Expand Down
6 changes: 3 additions & 3 deletions libs/react-client/src/useChatMessages.ts
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
import { useRecoilValue } from 'recoil';

import { firstUserMessageState, messagesState } from './state';
import { firstUserInteraction, messagesState } from './state';

const useChatMessages = () => {
const messages = useRecoilValue(messagesState);
const firstUserMessage = useRecoilValue(firstUserMessageState);
const firstInteraction = useRecoilValue(firstUserInteraction);

return {
messages,
firstUserMessage
firstInteraction
};
};

Expand Down
8 changes: 4 additions & 4 deletions libs/react-client/src/useChatSession.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ import {
chatSettingsInputsState,
chatSettingsValueState,
elementState,
firstUserMessageState,
firstUserInteraction,
loadingState,
messagesState,
sessionIdState,
Expand Down Expand Up @@ -49,7 +49,7 @@ const useChatSession = () => {
const [session, setSession] = useRecoilState(sessionState);

const resetChatSettingsValue = useResetRecoilState(chatSettingsValueState);
const setFirstUserMessage = useSetRecoilState(firstUserMessageState);
const setFirstUserInteraction = useSetRecoilState(firstUserInteraction);
const setLoading = useSetRecoilState(loadingState);
const setMessages = useSetRecoilState(messagesState);
const setAskUser = useSetRecoilState(askUserState);
Expand Down Expand Up @@ -139,8 +139,8 @@ const useChatSession = () => {
setMessages((oldMessages) => addMessage(oldMessages, message));
});

socket.on('init_thread', (message: IStep) => {
setFirstUserMessage(message);
socket.on('first_interaction', (interaction: string) => {
setFirstUserInteraction(interaction);
});

socket.on('update_message', (message: IStep) => {
Expand Down