From 3a4a7ceae5be618f5151b690fb220d9916f817a0 Mon Sep 17 00:00:00 2001 From: Willy Douhard Date: Tue, 12 Dec 2023 16:40:30 +0100 Subject: [PATCH] Release/1.0.0 (#587) * init * base data layer * add step to data layer * add queue until user message * remove data_persistence from config * upload askfilemessage response as file element * step context * step context * llama index integration + step elements * haystack integration + step error * langchain integration + error handling * feedback * feedback * refactor AppUser to User * migrate react-client * migrate react-components * migrate main ui * fix mypy * fix type import * fix step issues + langchain issues * token count * remove IMessage and MessageDict * wip fix tests * fix existing tests * add data layer test * action toast * remove seconds from message time * add support for action interruption * rename appuser to user * toast style * fix update thread * use http for file uploads * remove useless create_task * wip data layer * rename client step type * fix chainlit hello * wip data layer * fix test * wip data layer * add root param to step * fix llama index callback handler * add step show input * fix final answer streaming * update readme * step type lower case * chainlit_client * debug ci * debug ci * bump sdk version --- .gitignore | 2 + CHANGELOG.md | 13 + README.md | 42 +- backend/chainlit/__init__.py | 55 +- backend/chainlit/auth.py | 19 +- backend/chainlit/cli/__init__.py | 3 +- backend/chainlit/client/base.py | 169 ------ backend/chainlit/client/cloud.py | 502 ------------------ backend/chainlit/config.py | 25 +- backend/chainlit/context.py | 10 +- backend/chainlit/data/__init__.py | 384 +++++++++++++- backend/chainlit/data/acl.py | 11 +- backend/chainlit/element.py | 209 +++----- backend/chainlit/emitter.py | 167 ++++-- backend/chainlit/haystack/callbacks.py | 88 +-- backend/chainlit/hello.py | 2 +- backend/chainlit/langchain/callbacks.py | 252 ++++----- backend/chainlit/llama_index/callbacks.py | 116 ++-- backend/chainlit/message.py | 386 +++++++------- backend/chainlit/oauth_providers.py | 73 +-- backend/chainlit/playground/provider.py | 74 +-- .../playground/providers/anthropic.py | 8 +- .../playground/providers/huggingface.py | 4 +- .../playground/providers/langchain.py | 18 +- .../chainlit/playground/providers/openai.py | 32 +- backend/chainlit/prompt.py | 40 -- backend/chainlit/server.py | 254 +++++---- backend/chainlit/session.py | 149 ++++-- backend/chainlit/socket.py | 83 +-- backend/chainlit/step.py | 393 ++++++++++++++ backend/chainlit/types.py | 99 +++- backend/chainlit/user.py | 32 ++ backend/chainlit/user_session.py | 6 +- backend/pyproject.toml | 6 +- cypress/e2e/action/spec.cy.ts | 24 +- cypress/e2e/ask_file/main.py | 22 +- cypress/e2e/ask_file/spec.cy.ts | 6 +- .../ask_multiple_files/.chainlit/config.toml | 2 +- cypress/e2e/ask_multiple_files/spec.cy.ts | 2 +- cypress/e2e/ask_user/main.py | 2 +- cypress/e2e/ask_user/spec.cy.ts | 6 +- cypress/e2e/audio_element/spec.cy.ts | 4 +- cypress/e2e/author_rename/main.py | 6 + cypress/e2e/author_rename/spec.cy.ts | 10 +- .../e2e/{global_elements => avatar}/cat.jpeg | Bin cypress/e2e/avatar/main.py | 11 + cypress/e2e/avatar/spec.cy.ts | 9 +- cypress/e2e/chat_profiles/main.py | 17 +- cypress/e2e/chat_profiles/spec.cy.ts | 7 +- cypress/e2e/chat_settings/spec.cy.ts | 4 +- .../.chainlit/config.toml | 0 .../e2e/{sdk_availability => context}/main.py | 3 +- cypress/e2e/context/spec.cy.ts | 19 + cypress/e2e/conversations/main.py | 21 - cypress/e2e/conversations/spec.cy.ts | 125 ----- cypress/e2e/cot/main.py | 31 -- cypress/e2e/cot/spec.cy.ts | 22 - cypress/e2e/cot_mixed/.chainlit/config.toml | 75 --- cypress/e2e/cot_mixed/main.py | 21 - cypress/e2e/cot_mixed/spec.cy.ts | 19 - .../{cot => data_layer}/.chainlit/config.toml | 0 cypress/e2e/data_layer/main.py | 130 +++++ cypress/e2e/data_layer/spec.cy.ts | 66 +++ cypress/e2e/default_expand_cot/main.py | 31 +- cypress/e2e/default_expand_cot/spec.cy.ts | 7 +- .../.chainlit/config.toml | 0 .../{scoped_elements => elements}/cat.jpeg | Bin .../{scoped_elements => elements}/dummy.pdf | Bin cypress/e2e/elements/main.py | 52 ++ cypress/e2e/elements/spec.cy.ts | 52 ++ cypress/e2e/error_handling/spec.cy.ts | 12 +- cypress/e2e/file_element/spec.cy.ts | 4 +- cypress/e2e/global_elements/main.py | 26 - cypress/e2e/global_elements/spec.cy.ts | 46 -- cypress/e2e/header_auth/main.py | 8 +- cypress/e2e/header_auth/spec.cy.ts | 4 +- .../.chainlit/config.toml | 2 +- cypress/e2e/hide_prompt_playground/main.py | 19 +- cypress/e2e/hide_prompt_playground/spec.cy.ts | 9 +- .../.chainlit/config.toml | 0 .../main.py | 0 cypress/e2e/input_history/spec.cy.ts | 34 ++ cypress/e2e/llama_index_cb/main.py | 4 + cypress/e2e/llama_index_cb/spec.cy.ts | 6 +- cypress/e2e/message_history/spec.cy.ts | 34 -- cypress/e2e/on_chat_start/spec.cy.ts | 2 +- cypress/e2e/password_auth/main.py | 8 +- cypress/e2e/password_auth/spec.cy.ts | 4 +- cypress/e2e/plotly/spec.cy.ts | 4 +- cypress/e2e/prompt_playground/main.py | 56 +- cypress/e2e/prompt_playground/provider.py | 10 +- cypress/e2e/pyplot/spec.cy.ts | 4 +- cypress/e2e/remove_elements/main.py | 19 +- cypress/e2e/remove_elements/spec.cy.ts | 5 +- cypress/e2e/remove_message/spec.cy.ts | 31 -- .../.chainlit/config.toml | 0 .../{remove_message => remove_step}/main.py | 7 +- cypress/e2e/remove_step/spec.cy.ts | 31 ++ cypress/e2e/scoped_elements/main.py | 26 - cypress/e2e/scoped_elements/spec.cy.ts | 23 - cypress/e2e/sdk_availability/spec.cy.ts | 21 - .../.chainlit/config.toml | 0 cypress/e2e/step/main.py | 25 + cypress/e2e/step/main_async.py | 25 + cypress/e2e/step/spec.cy.ts | 30 ++ cypress/e2e/stop_task/main_async.py | 2 +- cypress/e2e/stop_task/main_sync.py | 2 +- cypress/e2e/stop_task/spec.cy.ts | 30 +- cypress/e2e/streaming/main.py | 18 +- cypress/e2e/streaming/spec.cy.ts | 32 +- cypress/e2e/tasklist/main.py | 2 + cypress/e2e/tasklist/spec.cy.ts | 2 +- .../e2e/update_message/.chainlit/config.toml | 62 --- cypress/e2e/update_message/spec.cy.ts | 13 - .../.chainlit/config.toml | 0 .../{update_message => update_step}/main.py | 9 +- cypress/e2e/update_step/spec.cy.ts | 18 + cypress/e2e/upload_attachments/spec.cy.ts | 16 +- cypress/e2e/user_env/spec.cy.ts | 4 +- cypress/e2e/user_session/spec.cy.ts | 16 +- cypress/e2e/video_element/spec.cy.ts | 4 +- frontend/package.json | 2 +- frontend/pnpm-lock.yaml | 242 ++++----- frontend/src/App.tsx | 17 +- frontend/src/api/index.ts | 2 +- .../atoms/buttons/progressIconButton.tsx | 32 ++ .../atoms/buttons/userButton/avatar.tsx | 4 +- .../atoms/buttons/userButton/menu.tsx | 2 +- .../src/components/molecules/attachments.tsx | 99 ++++ .../molecules/tasklist/TaskList.tsx | 22 +- .../organisms/chat/Messages/container.tsx | 86 ++- .../organisms/chat/Messages/index.tsx | 71 ++- .../organisms/chat/history/index.tsx | 46 +- .../src/components/organisms/chat/index.tsx | 94 +++- .../organisms/chat/inputBox/UploadButton.tsx | 12 +- .../organisms/chat/inputBox/index.tsx | 57 +- .../organisms/chat/inputBox/input.tsx | 65 +-- .../conversationsHistory/Conversation.tsx | 98 ---- frontend/src/components/organisms/header.tsx | 2 +- .../components/organisms/playground/index.tsx | 12 +- .../organisms/threadHistory/Thread.tsx | 153 ++++++ .../sidebar/DeleteThreadButton.tsx} | 40 +- .../sidebar/OpenThreadListButton.tsx} | 4 +- .../sidebar/ThreadList.tsx} | 65 ++- .../sidebar/filters/FeedbackSelect.tsx | 4 +- .../sidebar/filters/SearchBar.tsx | 4 +- .../sidebar/filters/index.tsx | 0 .../sidebar/index.tsx | 65 ++- frontend/src/hooks/localChatHistory.ts | 42 -- frontend/src/hooks/useLLMProviders.ts | 2 +- frontend/src/pages/Conversation.tsx | 63 --- frontend/src/pages/Element.tsx | 16 +- frontend/src/pages/Env.tsx | 2 +- frontend/src/pages/Page.tsx | 6 +- frontend/src/pages/ResumeButton.tsx | 16 +- frontend/src/pages/Thread.tsx | 53 ++ frontend/src/router.tsx | 6 +- frontend/src/state/chat.ts | 14 +- frontend/src/state/conversations.ts | 8 - frontend/src/state/project.ts | 6 +- frontend/src/state/threads.ts | 8 + .../{chatHistory.ts => userInputHistory.ts} | 14 +- libs/react-client/README.md | 10 +- libs/react-client/src/api/hooks/auth.ts | 18 +- libs/react-client/src/api/index.tsx | 111 +++- libs/react-client/src/state.ts | 87 ++- libs/react-client/src/types/chatHistory.ts | 15 - libs/react-client/src/types/conversation.ts | 12 - libs/react-client/src/types/element.ts | 28 +- libs/react-client/src/types/feedback.ts | 7 + libs/react-client/src/types/file.ts | 12 +- libs/react-client/src/types/generation.ts | 53 ++ libs/react-client/src/types/history.ts | 15 + libs/react-client/src/types/index.ts | 8 +- libs/react-client/src/types/message.ts | 59 -- libs/react-client/src/types/step.ts | 38 ++ libs/react-client/src/types/thread.ts | 12 + libs/react-client/src/types/user.ts | 17 +- libs/react-client/src/useChatData.ts | 5 - libs/react-client/src/useChatInteract.ts | 60 ++- libs/react-client/src/useChatSession.ts | 104 ++-- libs/react-client/src/utils/group.ts | 6 +- libs/react-client/src/utils/message.ts | 101 ++-- libs/react-components/hooks/useUpload.tsx | 44 +- libs/react-components/src/Attachment.tsx | 64 +++ libs/react-components/src/Attachments.tsx | 49 -- libs/react-components/src/elements/Audio.tsx | 14 +- libs/react-components/src/elements/Avatar.tsx | 25 +- libs/react-components/src/elements/File.tsx | 141 +---- libs/react-components/src/elements/Image.tsx | 10 +- .../src/elements/InlinedImageList.tsx | 5 +- .../src/elements/InlinedVideoList.tsx | 5 +- libs/react-components/src/elements/PDF.tsx | 11 +- libs/react-components/src/elements/Plotly.tsx | 12 +- libs/react-components/src/elements/Text.tsx | 9 +- libs/react-components/src/elements/Video.tsx | 9 +- libs/react-components/src/index.ts | 2 +- .../react-components/src/messages/Message.tsx | 23 +- .../src/messages/MessageContainer.tsx | 4 +- .../src/messages/Messages.tsx | 16 +- .../messages/components/AskUploadButton.tsx | 90 +++- .../src/messages/components/Author.tsx | 24 +- .../src/messages/components/DetailsButton.tsx | 22 +- .../messages/components/FeedbackButtons.tsx | 23 +- .../messages/components/MessageActions.tsx | 4 +- .../messages/components/MessageButtons.tsx | 19 +- .../messages/components/MessageContent.tsx | 79 ++- .../src/messages/components/MessageTime.tsx | 3 +- .../messages/components/PlaygroundButton.tsx | 8 +- .../react-components/src/playground/basic.tsx | 38 +- libs/react-components/src/playground/chat.tsx | 42 +- .../src/playground/editor/MessageWrapper.tsx | 32 +- .../src/playground/editor/formatted.tsx | 20 +- .../src/playground/editor/functionModal.tsx | 8 +- .../src/playground/editor/promptMessage.tsx | 11 +- .../src/playground/editor/template/index.tsx | 15 +- .../playground/editor/template/variable.tsx | 16 +- .../src/playground/editor/variableModal.tsx | 14 +- .../src/playground/functionInput.tsx | 2 +- .../src/playground/helpers/provider.ts | 2 +- .../react-components/src/playground/index.tsx | 34 +- .../src/playground/modelSettings.tsx | 22 +- .../src/playground/submitButton.tsx | 20 +- .../src/playground/variableInput.tsx | 2 +- .../src/types/messageContext.ts | 17 +- libs/react-components/src/types/playground.ts | 6 +- .../src/types/playgroundContext.ts | 4 +- libs/react-components/tests/content.spec.tsx | 52 +- libs/react-components/tests/message.spec.tsx | 24 +- libs/react-components/utils/message.ts | 22 +- 230 files changed, 4688 insertions(+), 4065 deletions(-) delete mode 100644 backend/chainlit/client/base.py delete mode 100644 backend/chainlit/client/cloud.py delete mode 100644 backend/chainlit/prompt.py create mode 100644 backend/chainlit/step.py create mode 100644 backend/chainlit/user.py rename cypress/e2e/{global_elements => avatar}/cat.jpeg (100%) rename cypress/e2e/{conversations => context}/.chainlit/config.toml (100%) rename cypress/e2e/{sdk_availability => context}/main.py (99%) create mode 100644 cypress/e2e/context/spec.cy.ts delete mode 100644 cypress/e2e/conversations/main.py delete mode 100644 cypress/e2e/conversations/spec.cy.ts delete mode 100644 cypress/e2e/cot/main.py delete mode 100644 cypress/e2e/cot/spec.cy.ts delete mode 100644 cypress/e2e/cot_mixed/.chainlit/config.toml delete mode 100644 cypress/e2e/cot_mixed/main.py delete mode 100644 cypress/e2e/cot_mixed/spec.cy.ts rename cypress/e2e/{cot => data_layer}/.chainlit/config.toml (100%) create mode 100644 cypress/e2e/data_layer/main.py create mode 100644 cypress/e2e/data_layer/spec.cy.ts rename cypress/e2e/{global_elements => elements}/.chainlit/config.toml (100%) rename cypress/e2e/{scoped_elements => elements}/cat.jpeg (100%) rename cypress/e2e/{scoped_elements => elements}/dummy.pdf (100%) create mode 100644 cypress/e2e/elements/main.py create mode 100644 cypress/e2e/elements/spec.cy.ts delete mode 100644 cypress/e2e/global_elements/main.py delete mode 100644 cypress/e2e/global_elements/spec.cy.ts rename cypress/e2e/{message_history => input_history}/.chainlit/config.toml (100%) rename cypress/e2e/{message_history => input_history}/main.py (100%) create mode 100644 cypress/e2e/input_history/spec.cy.ts delete mode 100644 cypress/e2e/message_history/spec.cy.ts delete mode 100644 cypress/e2e/remove_message/spec.cy.ts rename cypress/e2e/{remove_message => remove_step}/.chainlit/config.toml (100%) rename cypress/e2e/{remove_message => remove_step}/main.py (78%) create mode 100644 cypress/e2e/remove_step/spec.cy.ts delete mode 100644 cypress/e2e/scoped_elements/main.py delete mode 100644 cypress/e2e/scoped_elements/spec.cy.ts delete mode 100644 cypress/e2e/sdk_availability/spec.cy.ts rename cypress/e2e/{scoped_elements => step}/.chainlit/config.toml (100%) create mode 100644 cypress/e2e/step/main.py create mode 100644 cypress/e2e/step/main_async.py create mode 100644 cypress/e2e/step/spec.cy.ts delete mode 100644 cypress/e2e/update_message/.chainlit/config.toml delete mode 100644 cypress/e2e/update_message/spec.cy.ts rename cypress/e2e/{sdk_availability => update_step}/.chainlit/config.toml (100%) rename cypress/e2e/{update_message => update_step}/main.py (56%) create mode 100644 cypress/e2e/update_step/spec.cy.ts create mode 100644 frontend/src/components/atoms/buttons/progressIconButton.tsx create mode 100644 frontend/src/components/molecules/attachments.tsx delete mode 100644 frontend/src/components/organisms/conversationsHistory/Conversation.tsx create mode 100644 frontend/src/components/organisms/threadHistory/Thread.tsx rename frontend/src/components/organisms/{conversationsHistory/sidebar/DeleteConversationButton.tsx => threadHistory/sidebar/DeleteThreadButton.tsx} (73%) rename frontend/src/components/organisms/{conversationsHistory/sidebar/OpenChatHistoryButton.tsx => threadHistory/sidebar/OpenThreadListButton.tsx} (90%) rename frontend/src/components/organisms/{conversationsHistory/sidebar/ConversationsHistoryList.tsx => threadHistory/sidebar/ThreadList.tsx} (75%) rename frontend/src/components/organisms/{conversationsHistory => threadHistory}/sidebar/filters/FeedbackSelect.tsx (95%) rename frontend/src/components/organisms/{conversationsHistory => threadHistory}/sidebar/filters/SearchBar.tsx (93%) rename frontend/src/components/organisms/{conversationsHistory => threadHistory}/sidebar/filters/index.tsx (100%) rename frontend/src/components/organisms/{conversationsHistory => threadHistory}/sidebar/index.tsx (73%) delete mode 100644 frontend/src/hooks/localChatHistory.ts delete mode 100644 frontend/src/pages/Conversation.tsx create mode 100644 frontend/src/pages/Thread.tsx delete mode 100644 frontend/src/state/conversations.ts create mode 100644 frontend/src/state/threads.ts rename frontend/src/state/{chatHistory.ts => userInputHistory.ts} (66%) delete mode 100644 libs/react-client/src/types/chatHistory.ts delete mode 100644 libs/react-client/src/types/conversation.ts create mode 100644 libs/react-client/src/types/feedback.ts create mode 100644 libs/react-client/src/types/generation.ts create mode 100644 libs/react-client/src/types/history.ts delete mode 100644 libs/react-client/src/types/message.ts create mode 100644 libs/react-client/src/types/step.ts create mode 100644 libs/react-client/src/types/thread.ts create mode 100644 libs/react-components/src/Attachment.tsx delete mode 100644 libs/react-components/src/Attachments.tsx diff --git a/.gitignore b/.gitignore index fe2df75319..d6122d006e 100644 --- a/.gitignore +++ b/.gitignore @@ -5,6 +5,8 @@ dist .env +*.files + poetry.lock venv diff --git a/CHANGELOG.md b/CHANGELOG.md index f630d85903..108fdb03f3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,19 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/). Nothing is unreleased! +## [1.0.0rc0] - 2023-12-12 + +### Added + +- cl.Step + +### Changed + +- File upload uses HTTP instead of WS and no longer has size limitation +- `cl.AppUser` becomes `cl.User` +- `Prompt` has been split in `ChatGeneration` and `CompletionGeneration` +- `Action` now display a toaster in the UI while running + ## [0.7.700] - 2023-11-28 ### Added diff --git a/README.md b/README.md index 386cd6f409..e791adcfba 100644 --- a/README.md +++ b/README.md @@ -8,15 +8,7 @@ Chainlit is an open-source async Python framework that makes it incredibly fast to build Chat GPT like applications with your **own business logic and data**. -Contact us here for **Enterprise Support** and to get early access to the **Monitoring** product: https://forms.gle/BX3UNBLmTF75KgZVA - -Key features: - -- [πŸ’¬ Multi Modal chats](https://docs.chainlit.io/chat-experience/elements) -- [πŸ’­ Chain of Thought visualisation](https://docs.chainlit.io/observability-iteration/chain-of-thought) -- [πŸ’Ύ Data persistence + human feedback](https://docs.chainlit.io/chat-data/overview) -- [πŸ› In context Prompt Playground](https://docs.chainlit.io/observability-iteration/prompt-playground/overview) -- [πŸ‘€ Authentication](https://docs.chainlit.io/authentication/overview) +Contact us [here](https://forms.gle/BX3UNBLmTF75KgZVA) for **Enterprise Support** and to get early access to the **Analytics & Observability** product. https://github.com/Chainlit/chainlit/assets/13104895/8882af90-fdfa-4b24-8200-1ee96c6c7490 @@ -49,11 +41,16 @@ Create a new file `demo.py` with the following code: import chainlit as cl +@cl.step +def tool(): + return "Response from the tool!" + + @cl.on_message # this function will be called every time a user inputs a message in the UI async def main(message: cl.Message): """ This function is called every time a user inputs a message in the UI. - It sends back an intermediate response from Tool 1, followed by the final answer. + It sends back an intermediate response from the tool, followed by the final answer. Args: message: The user's message. @@ -62,15 +59,11 @@ async def main(message: cl.Message): None. """ - # Send an intermediate response from Tool 1. - await cl.Message( - author="Tool 1", - content=f"Response from tool1", - parent_id=message.id, - ).send() + # Call the tool + tool() # Send the final answer. - await cl.Message(content=f"This is the final answer").send() + await cl.Message(content="This is the final answer").send() ``` Now run it! @@ -90,7 +83,6 @@ Chainlit is compatible with all Python programs and libraries. That being said, - [OpenAI Assistant](https://github.com/Chainlit/cookbook/tree/main/openai-assistant) - [Llama Index](https://docs.chainlit.io/integrations/llama-index) - [Haystack](https://docs.chainlit.io/integrations/haystack) -- [Langflow](https://docs.chainlit.io/integrations/langflow) ## 🎨 Custom Frontend @@ -109,22 +101,8 @@ To build and connect your own frontend, check out our [Custom Frontend Cookbook] You can find various examples of Chainlit apps [here](https://github.com/Chainlit/cookbook) that leverage tools and services such as OpenAI, Anthropiс, LangChain, LlamaIndex, ChromaDB, Pinecone and more. -## πŸ›£ Roadmap - -- [x] Selectable chat profiles (at the beginning of a chat) -- [ ] One click chat sharing -- New clients: - - [x] Custom React app - - [ ] Slack - - [ ] Discord - - [ ] Website embbed - Tell us what you would like to see added in Chainlit using the Github issues or on [Discord](https://discord.gg/k73SQ3FyUh). -## 🏒 Enterprise support - -For entreprise grade features and self hosting, please visit this [page](https://docs.chainlit.io/cloud/persistence/enterprise) and fill the form. - ## πŸ’ Contributing As an open-source initiative in a rapidly evolving domain, we welcome contributions, be it through the addition of new features or the improvement of documentation. diff --git a/backend/chainlit/__init__.py b/backend/chainlit/__init__.py index 56c0c4bc19..c6796a755b 100644 --- a/backend/chainlit/__init__.py +++ b/backend/chainlit/__init__.py @@ -5,7 +5,7 @@ env_found = load_dotenv(dotenv_path=os.path.join(os.getcwd(), ".env")) import asyncio -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional from starlette.datastructures import Headers @@ -21,8 +21,8 @@ from chainlit.action import Action from chainlit.cache import cache from chainlit.chat_settings import ChatSettings -from chainlit.client.base import AppUser, ConversationDict, PersistedAppUser from chainlit.config import config +from chainlit.context import context from chainlit.element import ( Audio, Avatar, @@ -46,31 +46,34 @@ Message, ) from chainlit.oauth_providers import get_configured_oauth_providers +from chainlit.step import Step, step from chainlit.sync import make_async, run_sync from chainlit.telemetry import trace -from chainlit.types import ChatProfile, FileSpec +from chainlit.types import ChatProfile, ThreadDict +from chainlit.user import PersistedUser, User from chainlit.user_session import user_session from chainlit.utils import make_module_getattr, wrap_user_function from chainlit.version import __version__ +from chainlit_client import ChatGeneration, CompletionGeneration, GenerationMessage if env_found: logger.info("Loaded .env file") @trace -def password_auth_callback(func: Callable[[str, str], Optional[AppUser]]) -> Callable: +def password_auth_callback(func: Callable[[str, str], Optional[User]]) -> Callable: """ Framework agnostic decorator to authenticate the user. Args: - func (Callable[[str, str], Optional[AppUser]]): The authentication callback to execute. Takes the email and password as parameters. + func (Callable[[str, str], Optional[User]]): The authentication callback to execute. Takes the email and password as parameters. Example: @cl.password_auth_callback - async def password_auth_callback(username: str, password: str) -> Optional[AppUser]: + async def password_auth_callback(username: str, password: str) -> Optional[User]: Returns: - Callable[[str, str], Optional[AppUser]]: The decorated authentication callback. + Callable[[str, str], Optional[User]]: The decorated authentication callback. """ config.code.password_auth_callback = wrap_user_function(func) @@ -78,19 +81,19 @@ async def password_auth_callback(username: str, password: str) -> Optional[AppUs @trace -def header_auth_callback(func: Callable[[Headers], Optional[AppUser]]) -> Callable: +def header_auth_callback(func: Callable[[Headers], Optional[User]]) -> Callable: """ Framework agnostic decorator to authenticate the user via a header Args: - func (Callable[[Headers], Optional[AppUser]]): The authentication callback to execute. + func (Callable[[Headers], Optional[User]]): The authentication callback to execute. Example: @cl.header_auth_callback - async def header_auth_callback(headers: Headers) -> Optional[AppUser]: + async def header_auth_callback(headers: Headers) -> Optional[User]: Returns: - Callable[[Headers], Optional[AppUser]]: The decorated authentication callback. + Callable[[Headers], Optional[User]]: The decorated authentication callback. """ config.code.header_auth_callback = wrap_user_function(func) @@ -99,20 +102,20 @@ async def header_auth_callback(headers: Headers) -> Optional[AppUser]: @trace def oauth_callback( - func: Callable[[str, str, Dict[str, str], AppUser], Optional[AppUser]] + func: Callable[[str, str, Dict[str, str], User], Optional[User]] ) -> Callable: """ Framework agnostic decorator to authenticate the user via oauth Args: - func (Callable[[str, str, Dict[str, str], AppUser], Optional[AppUser]]): The authentication callback to execute. + func (Callable[[str, str, Dict[str, str], User], Optional[User]]): The authentication callback to execute. Example: @cl.oauth_callback - async def oauth_callback(provider_id: str, token: str, raw_user_data: Dict[str, str], default_app_user: AppUser) -> Optional[AppUser]: + async def oauth_callback(provider_id: str, token: str, raw_user_data: Dict[str, str], default_app_user: User) -> Optional[User]: Returns: - Callable[[str, str, Dict[str, str], AppUser], Optional[AppUser]]: The decorated authentication callback. + Callable[[str, str, Dict[str, str], User], Optional[User]]: The decorated authentication callback. """ if len(get_configured_oauth_providers()) == 0: @@ -158,7 +161,7 @@ def on_chat_start(func: Callable) -> Callable: @trace -def on_chat_resume(func: Callable[[ConversationDict], Any]) -> Callable: +def on_chat_resume(func: Callable[[ThreadDict], Any]) -> Callable: """ Hook to react to resume websocket connection event. @@ -175,16 +178,16 @@ def on_chat_resume(func: Callable[[ConversationDict], Any]) -> Callable: @trace def set_chat_profiles( - func: Callable[[Optional["AppUser"]], List["ChatProfile"]] + func: Callable[[Optional["User"]], List["ChatProfile"]] ) -> Callable: """ - Programmatic declaration of the available chat profiles (can depend on the AppUser from the session if authentication is setup). + Programmatic declaration of the available chat profiles (can depend on the User from the session if authentication is setup). Args: - func (Callable[[Optional["AppUser"]], List["ChatProfile"]]): The function declaring the chat profiles. + func (Callable[[Optional["User"]], List["ChatProfile"]]): The function declaring the chat profiles. Returns: - Callable[[Optional["AppUser"]], List["ChatProfile"]]: The decorated function. + Callable[[Optional["User"]], List["ChatProfile"]]: The decorated function. """ config.code.set_chat_profiles = wrap_user_function(func) @@ -225,7 +228,7 @@ def author_rename(func: Callable[[str], str]) -> Callable[[str], str]: @trace def on_stop(func: Callable) -> Callable: """ - Hook to react to the user stopping a conversation. + Hook to react to the user stopping a thread. Args: func (Callable[[], Any]): The stop hook to execute. @@ -291,8 +294,8 @@ def sleep(duration: int): __all__ = [ "user_session", "Action", - "AppUser", - "PersistedAppUser", + "User", + "PersistedUser", "Audio", "Pdf", "Plotly", @@ -312,6 +315,11 @@ def sleep(duration: int): "AskUserMessage", "AskActionMessage", "AskFileMessage", + "Step", + "step", + "ChatGeneration", + "CompletionGeneration", + "GenerationMessage", "on_chat_start", "on_chat_end", "on_chat_resume", @@ -325,6 +333,7 @@ def sleep(duration: int): "run_sync", "make_async", "cache", + "context", "LangchainCallbackHandler", "AsyncLangchainCallbackHandler", "LlamaIndexCallbackHandler", diff --git a/backend/chainlit/auth.py b/backend/chainlit/auth.py index 7ab56b3281..0812d5dba2 100644 --- a/backend/chainlit/auth.py +++ b/backend/chainlit/auth.py @@ -3,10 +3,10 @@ from typing import Any, Dict import jwt -from chainlit.client.cloud import AppUser from chainlit.config import config -from chainlit.data import chainlit_client +from chainlit.data import get_data_layer from chainlit.oauth_providers import get_configured_oauth_providers +from chainlit.user import User from fastapi import Depends, HTTPException from fastapi.security import OAuth2PasswordBearer @@ -47,7 +47,7 @@ def get_configuration(): } -def create_jwt(data: AppUser) -> str: +def create_jwt(data: User) -> str: to_encode = data.to_dict() # type: Dict[str, Any] to_encode.update( { @@ -67,21 +67,20 @@ async def authenticate_user(token: str = Depends(reuseable_oauth)): options={"verify_signature": True}, ) del dict["exp"] - app_user = AppUser(**dict) + user = User(**dict) except Exception as e: raise HTTPException(status_code=401, detail="Invalid authentication token") - - if chainlit_client: + if data_layer := get_data_layer(): try: - persisted_app_user = await chainlit_client.get_app_user(app_user.username) + persisted_user = await data_layer.get_user(user.identifier) except Exception as e: raise HTTPException(status_code=500, detail=str(e)) - if persisted_app_user == None: + if persisted_user == None: raise HTTPException(status_code=401, detail="User does not exist") - return persisted_app_user + return persisted_user else: - return app_user + return user async def get_current_user(token: str = Depends(reuseable_oauth)): diff --git a/backend/chainlit/cli/__init__.py b/backend/chainlit/cli/__init__.py index bab0b25079..a349d86f31 100644 --- a/backend/chainlit/cli/__init__.py +++ b/backend/chainlit/cli/__init__.py @@ -21,7 +21,7 @@ from chainlit.logger import logger from chainlit.markdown import init_markdown from chainlit.secret import random_secret -from chainlit.server import app, max_message_size, register_wildcard_route_handler +from chainlit.server import app, register_wildcard_route_handler from chainlit.telemetry import trace_event @@ -73,7 +73,6 @@ async def start(): host=host, port=port, log_level=log_level, - ws_max_size=max_message_size, ws_per_message_deflate=ws_per_message_deflate, ) server = uvicorn.Server(config) diff --git a/backend/chainlit/client/base.py b/backend/chainlit/client/base.py deleted file mode 100644 index 5432378c77..0000000000 --- a/backend/chainlit/client/base.py +++ /dev/null @@ -1,169 +0,0 @@ -from typing import ( - Any, - Dict, - Generic, - List, - Literal, - Mapping, - Optional, - TypedDict, - TypeVar, -) - -from chainlit.logger import logger -from chainlit.prompt import Prompt -from dataclasses_json import DataClassJsonMixin -from pydantic import BaseModel, Field -from pydantic.dataclasses import dataclass -from python_graphql_client import GraphqlClient - -ElementType = Literal[ - "image", "avatar", "text", "pdf", "tasklist", "audio", "video", "file", "plotly" -] -ElementDisplay = Literal["inline", "side", "page"] -ElementSize = Literal["small", "medium", "large"] - -Role = Literal["USER", "ADMIN", "OWNER", "ANONYMOUS"] -Provider = Literal[ - "credentials", "header", "github", "google", "azure-ad", "okta", "auth0", "descope" -] - - -class AppUserDict(TypedDict): - id: str - username: str - - -# Used when logging-in a user -@dataclass -class AppUser(DataClassJsonMixin): - username: str - role: Role = "USER" - tags: List[str] = Field(default_factory=list) - image: Optional[str] = None - provider: Optional[Provider] = None - - -@dataclass -class PersistedAppUserFields: - id: str - createdAt: int - - -@dataclass -class PersistedAppUser(AppUser, PersistedAppUserFields): - pass - - -class MessageDict(TypedDict): - conversationId: Optional[str] - id: str - createdAt: Optional[int] - content: str - author: str - prompt: Optional[Prompt] - language: Optional[str] - parentId: Optional[str] - indent: Optional[int] - authorIsUser: Optional[bool] - waitForAnswer: Optional[bool] - isError: Optional[bool] - humanFeedback: Optional[int] - disableHumanFeedback: Optional[bool] - - -class ElementDict(TypedDict): - id: str - conversationId: Optional[str] - type: ElementType - url: str - objectKey: Optional[str] - name: str - display: ElementDisplay - size: Optional[ElementSize] - language: Optional[str] - forIds: Optional[List[str]] - mime: Optional[str] - - -class ConversationDict(TypedDict): - id: Optional[str] - metadata: Optional[Dict] - createdAt: Optional[int] - appUser: Optional[AppUserDict] - messages: List[MessageDict] - elements: Optional[List[ElementDict]] - - -@dataclass -class PageInfo: - hasNextPage: bool - endCursor: Optional[str] - - -T = TypeVar("T") - - -@dataclass -class PaginatedResponse(DataClassJsonMixin, Generic[T]): - pageInfo: PageInfo - data: List[T] - - -class Pagination(BaseModel): - first: int - cursor: Optional[str] = None - - -class ConversationFilter(BaseModel): - feedback: Optional[Literal[-1, 0, 1]] = None - username: Optional[str] = None - search: Optional[str] = None - - -class ChainlitGraphQLClient: - def __init__(self, api_key: str, chainlit_server: str): - self.headers = {"content-type": "application/json"} - if api_key: - self.headers["x-api-key"] = api_key - else: - raise ValueError("Cannot instantiate Cloud Client without CHAINLIT_API_KEY") - - graphql_endpoint = f"{chainlit_server}/api/graphql" - self.graphql_client = GraphqlClient( - endpoint=graphql_endpoint, headers=self.headers - ) - - async def query(self, query: str, variables: Dict[str, Any] = {}) -> Dict[str, Any]: - """ - Execute a GraphQL query. - - :param query: The GraphQL query string. - :param variables: A dictionary of variables for the query. - :return: The response data as a dictionary. - """ - return await self.graphql_client.execute_async(query=query, variables=variables) - - def check_for_errors(self, response: Dict[str, Any], raise_error: bool = False): - if "errors" in response: - if raise_error: - raise Exception( - f"{response['errors'][0]['message']}. Path: {str(response['errors'][0]['path'])}" - ) - logger.error(response["errors"][0]) - return True - return False - - async def mutation( - self, mutation: str, variables: Mapping[str, Any] = {} - ) -> Dict[str, Any]: - """ - Execute a GraphQL mutation. - - :param mutation: The GraphQL mutation string. - :param variables: A dictionary of variables for the mutation. - :return: The response data as a dictionary. - """ - return await self.graphql_client.execute_async( - query=mutation, variables=variables - ) diff --git a/backend/chainlit/client/cloud.py b/backend/chainlit/client/cloud.py deleted file mode 100644 index 7090bf6619..0000000000 --- a/backend/chainlit/client/cloud.py +++ /dev/null @@ -1,502 +0,0 @@ -import uuid -from typing import Any, Dict, List, Optional, Tuple, Union - -import httpx -from chainlit.logger import logger - -from .base import ( - AppUser, - ChainlitGraphQLClient, - ConversationDict, - ConversationFilter, - ElementDict, - MessageDict, - PageInfo, - PaginatedResponse, - Pagination, - PersistedAppUser, -) - - -class ChainlitCloudClient(ChainlitGraphQLClient): - chainlit_server: str - - def __init__(self, api_key: str, chainlit_server="https://cloud.chainlit.io"): - # Remove trailing slash - chainlit_server = chainlit_server.rstrip("/") - super().__init__(api_key=api_key, chainlit_server=chainlit_server) - self.chainlit_server = chainlit_server - - async def create_app_user(self, app_user: AppUser) -> Optional[PersistedAppUser]: - mutation = """ - mutation ($username: String!, $role: Role!, $tags: [String!], $provider: String, $image: String) { - createAppUser(username: $username, role: $role, tags: $tags, provider: $provider, image: $image) { - id, - username, - role, - tags, - provider, - image, - createdAt - } - } - """ - variables = app_user.to_dict() - res = await self.mutation(mutation, variables) - - if self.check_for_errors(res): - logger.warning("Could not create app user.") - return None - - return PersistedAppUser.from_dict(res["data"]["createAppUser"]) - - async def update_app_user(self, app_user: AppUser) -> Optional[PersistedAppUser]: - mutation = """ - mutation ($username: String!, $role: Role!, $tags: [String!], $provider: String, $image: String) { - updateAppUser(username: $username, role: $role, tags: $tags, provider: $provider, image: $image) { - id, - username, - role, - tags, - provider, - image, - createdAt - } - } - """ - variables = app_user.to_dict() - res = await self.mutation(mutation, variables) - - if self.check_for_errors(res): - logger.warning("Could not update app user.") - return None - - return PersistedAppUser.from_dict(res["data"]["updateAppUser"]) - - async def get_app_user(self, username: str) -> Optional[PersistedAppUser]: - query = """ - query ($username: String!) { - getAppUser(username: $username) { - id, - username, - role, - tags, - provider, - image, - createdAt - } - } - """ - variables = {"username": username} - res = await self.query(query, variables) - - if self.check_for_errors(res): - logger.warning("Could not get app user.") - return None - - return PersistedAppUser.from_dict(res["data"]["getAppUser"]) - - async def delete_app_user(self, username: str) -> bool: - mutation = """ - mutation ($username: String!) { - deleteAppUser(username: $username) { - id, - } - } - """ - variables = {"username": username} - res = await self.mutation(mutation, variables) - - if self.check_for_errors(res): - logger.warning("Could not delete app user.") - return False - - return True - - async def create_conversation( - self, app_user_id: Optional[str], tags: Optional[List[str]] - ) -> Optional[str]: - mutation = """ - mutation ($appUserId: String, $tags: [String!]) { - createConversation (appUserId: $appUserId, tags: $tags) { - id - } - } - """ - variables = {} # type: Dict[str, Any] - if app_user_id is not None: - variables["appUserId"] = app_user_id - - if tags: - variables["tags"] = tags - - res = await self.mutation(mutation, variables) - - if self.check_for_errors(res): - logger.warning("Could not create conversation.") - return None - - return res["data"]["createConversation"]["id"] - - async def delete_conversation(self, conversation_id: str) -> bool: - mutation = """ - mutation ($id: ID!) { - deleteConversation(id: $id) { - id - } - } - """ - variables = {"id": conversation_id} - res = await self.mutation(mutation, variables) - self.check_for_errors(res, raise_error=True) - - return True - - async def get_conversation_author(self, conversation_id: str) -> Optional[str]: - query = """ - query ($id: ID!) { - conversation(id: $id) { - appUser { - username - } - } - } - """ - variables = { - "id": conversation_id, - } - res = await self.query(query, variables) - self.check_for_errors(res, raise_error=True) - data = res.get("data") - conversation = data.get("conversation") if data else None - if not conversation: - return None - return ( - conversation["appUser"].get("username") if conversation["appUser"] else None - ) - - async def get_conversation(self, conversation_id: str) -> ConversationDict: - query = """ - query ($id: ID!) { - conversation(id: $id) { - id - createdAt - tags - metadata - appUser { - id - username - } - messages { - id - isError - parentId - indent - author - content - waitForAnswer - humanFeedback - humanFeedbackComment - disableHumanFeedback - language - prompt - authorIsUser - createdAt - } - elements { - id - conversationId - type - name - mime - url - display - language - size - forIds - } - } - } - """ - variables = { - "id": conversation_id, - } - res = await self.query(query, variables) - self.check_for_errors(res, raise_error=True) - - return res["data"]["conversation"] - - async def update_conversation_metadata(self, conversation_id: str, metadata: Dict): - mutation = """mutation ($conversationId: ID!, $metadata: Json!) { - updateConversationMetadata(conversationId: $conversationId, metadata: $metadata) { - id - } - }""" - variables = { - "conversationId": conversation_id, - "metadata": metadata, - } - - res = await self.mutation(mutation, variables) - self.check_for_errors(res, raise_error=True) - - return True - - async def get_conversations( - self, pagination: Pagination, filter: ConversationFilter - ): - query = """query ( - $first: Int - $cursor: String - $withFeedback: Int - $username: String - $search: String - ) { - conversations( - first: $first - cursor: $cursor - withFeedback: $withFeedback - username: $username - search: $search - ) { - pageInfo { - endCursor - hasNextPage - } - edges { - cursor - node { - id - createdAt - tags - appUser { - username - } - messages { - content - } - } - } - } - }""" - - variables = { - "first": pagination.first, - "cursor": pagination.cursor, - "withFeedback": filter.feedback, - "username": filter.username, - "search": filter.search, - } - res = await self.query(query, variables) - self.check_for_errors(res, raise_error=True) - - conversations = [] - - for edge in res["data"]["conversations"]["edges"]: - node = edge["node"] - conversations.append(node) - - page_info = res["data"]["conversations"]["pageInfo"] - - return PaginatedResponse( - pageInfo=PageInfo( - hasNextPage=page_info["hasNextPage"], - endCursor=page_info["endCursor"], - ), - data=conversations, - ) - - async def set_human_feedback( - self, message_id: str, feedback: int, feedbackComment: Optional[str] - ) -> bool: - mutation = """mutation ($messageId: ID!, $humanFeedback: Int!, $humanFeedbackComment: String) { - setHumanFeedback(messageId: $messageId, humanFeedback: $humanFeedback, humanFeedbackComment: $humanFeedbackComment) { - id - humanFeedback - humanFeedbackComment - } - }""" - variables = { - "messageId": message_id, - "humanFeedback": feedback, - } - if feedbackComment: - variables["humanFeedbackComment"] = feedbackComment - res = await self.mutation(mutation, variables) - self.check_for_errors(res, raise_error=True) - - return True - - async def get_message(self): - raise NotImplementedError - - async def create_message(self, variables: MessageDict) -> Optional[str]: - mutation = """ - mutation ($id: ID!, $conversationId: ID!, $author: String!, $content: String!, $language: String, $prompt: Json, $isError: Boolean, $parentId: String, $indent: Int, $authorIsUser: Boolean, $disableHumanFeedback: Boolean, $waitForAnswer: Boolean, $createdAt: StringOrFloat) { - createMessage(id: $id, conversationId: $conversationId, author: $author, content: $content, language: $language, prompt: $prompt, isError: $isError, parentId: $parentId, indent: $indent, authorIsUser: $authorIsUser, disableHumanFeedback: $disableHumanFeedback, waitForAnswer: $waitForAnswer, createdAt: $createdAt) { - id - } - } - """ - res = await self.mutation(mutation, variables) - if self.check_for_errors(res): - logger.warning("Could not create message.") - return None - - return res["data"]["createMessage"]["id"] - - async def update_message(self, message_id: str, variables: MessageDict) -> bool: - mutation = """ - mutation ($messageId: ID!, $author: String!, $content: String!, $parentId: String, $language: String, $prompt: Json, $disableHumanFeedback: Boolean) { - updateMessage(messageId: $messageId, author: $author, content: $content, parentId: $parentId, language: $language, prompt: $prompt, disableHumanFeedback: $disableHumanFeedback) { - id - } - } - """ - res = await self.mutation(mutation, dict(messageId=message_id, **variables)) - - if self.check_for_errors(res): - logger.warning("Could not update message.") - return False - - return True - - async def delete_message(self, message_id: str) -> bool: - mutation = """ - mutation ($messageId: ID!) { - deleteMessage(messageId: $messageId) { - id - } - } - """ - res = await self.mutation(mutation, {"messageId": message_id}) - - if self.check_for_errors(res): - logger.warning("Could not delete message.") - return False - - return True - - async def get_element( - self, conversation_id: str, element_id: str - ) -> Optional[ElementDict]: - query = """query ( - $conversationId: ID! - $id: ID! - ) { - element( - conversationId: $conversationId, - id: $id - ) { - id - conversationId - type - name - mime - url - display - language - size - forIds - } - }""" - - variables = { - "conversationId": conversation_id, - "id": element_id, - } - res = await self.query(query, variables) - self.check_for_errors(res, raise_error=True) - - return res["data"]["element"] - - async def create_element(self, variables: ElementDict) -> Optional[ElementDict]: - mutation = """ - mutation ($conversationId: ID!, $type: String!, $name: String!, $display: String!, $forIds: [String!]!, $url: String, $objectKey: String, $size: String, $language: String, $mime: String) { - createElement(conversationId: $conversationId, type: $type, url: $url, objectKey: $objectKey, name: $name, display: $display, size: $size, language: $language, forIds: $forIds, mime: $mime) { - id, - type, - url, - objectKey, - name, - display, - size, - language, - forIds, - mime - } - } - """ - res = await self.mutation(mutation, variables) - - if self.check_for_errors(res): - logger.warning("Could not create element.") - return None - - return res["data"]["createElement"] - - async def update_element(self, variables: ElementDict) -> Optional[ElementDict]: - mutation = """ - mutation ($conversationId: ID!, $id: ID!, $forIds: [String!]!) { - updateElement(conversationId: $conversationId, id: $id, forIds: $forIds) { - id, - } - } - """ - - res = await self.mutation(mutation, variables) - - if self.check_for_errors(res): - logger.warning("Could not update element.") - return None - - return res["data"]["updateElement"] - - async def upload_element( - self, content: Union[bytes, str], mime: str, conversation_id: Optional[str] - ) -> Dict: - id = str(uuid.uuid4()) - body = {"fileName": id, "contentType": mime} - - if conversation_id: - body["conversationId"] = conversation_id - - path = "/api/upload/file" - - async with httpx.AsyncClient() as client: - response = await client.post( - f"{self.chainlit_server}{path}", - json=body, - headers=self.headers, - ) - if response.status_code != 200: - reason = response.text - logger.error(f"Failed to sign upload url: {reason}") - return {"object_key": None, "url": None} - json_res = response.json() - - upload_details = json_res["post"] - object_key = upload_details["fields"]["key"] - signed_url = json_res["signedUrl"] - - # Prepare form data - form_data = {} # type: Dict[str, Tuple[Union[str, None], Any]] - for field_name, field_value in upload_details["fields"].items(): - form_data[field_name] = (None, field_value) - - # Add file to the form_data - # Note: The content_type parameter is not needed here, as the correct MIME type should be set in the 'Content-Type' field from upload_details - form_data["file"] = (id, content) - - async with httpx.AsyncClient() as client: - upload_response = await client.post( - upload_details["url"], - files=form_data, - ) - try: - upload_response.raise_for_status() - url = f'{upload_details["url"]}/{object_key}' - return {"object_key": object_key, "url": signed_url} - except Exception as e: - logger.error(f"Failed to upload file: {str(e)}") - return {"object_key": None, "url": None} diff --git a/backend/chainlit/config.py b/backend/chainlit/config.py index 56d14f6e58..668c8ca063 100644 --- a/backend/chainlit/config.py +++ b/backend/chainlit/config.py @@ -1,6 +1,7 @@ import os import sys from importlib import util +from pathlib import Path from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional import tomli @@ -12,8 +13,8 @@ if TYPE_CHECKING: from chainlit.action import Action - from chainlit.client.base import AppUser, ConversationDict - from chainlit.types import ChatProfile + from chainlit.types import ChatProfile, ThreadDict + from chainlit.user import User BACKEND_ROOT = os.path.dirname(__file__) @@ -23,6 +24,10 @@ # Get the directory the script is running from APP_ROOT = os.getcwd() +# Create the directory to store the uploaded files +FILES_DIRECTORY = Path(APP_ROOT) / ".files" +FILES_DIRECTORY.mkdir(exist_ok=True) + config_dir = os.path.join(APP_ROOT, ".chainlit") config_file = os.path.join(config_dir, "config.toml") @@ -66,7 +71,7 @@ # Name of the app and chatbot. name = "Chatbot" -# Show the readme while the conversation is empty. +# Show the readme while the thread is empty. show_readme_as_default = true # Description of the app and chatbot. This is used for HTML tags. @@ -190,20 +195,20 @@ class CodeSettings: # Module object loaded from the module_name module: Any = None # Bunch of callbacks defined by the developer - password_auth_callback: Optional[Callable[[str, str], Optional["AppUser"]]] = None - header_auth_callback: Optional[Callable[[Headers], Optional["AppUser"]]] = None + password_auth_callback: Optional[Callable[[str, str], Optional["User"]]] = None + header_auth_callback: Optional[Callable[[Headers], Optional["User"]]] = None oauth_callback: Optional[ - Callable[[str, str, Dict[str, str], "AppUser"], Optional["AppUser"]] + Callable[[str, str, Dict[str, str], "User"], Optional["User"]] ] = None on_stop: Optional[Callable[[], Any]] = None on_chat_start: Optional[Callable[[], Any]] = None on_chat_end: Optional[Callable[[], Any]] = None - on_chat_resume: Optional[Callable[["ConversationDict"], Any]] = None + on_chat_resume: Optional[Callable[["ThreadDict"], Any]] = None on_message: Optional[Callable[[str], Any]] = None author_rename: Optional[Callable[[str], str]] = None on_settings_update: Optional[Callable[[Dict[str, Any]], Any]] = None set_chat_profiles: Optional[ - Callable[[Optional["AppUser"]], List["ChatProfile"]] + Callable[[Optional["User"]], List["ChatProfile"]] ] = None @@ -229,8 +234,6 @@ class ChainlitConfig: root = APP_ROOT # Chainlit server URL. Used only for cloud features chainlit_server: str - # Whether or not a chainlit api key has been provided - data_persistence: bool # The url of the deployed app. Only set if the app is deployed. chainlit_prod_url = chainlit_prod_url @@ -341,11 +344,9 @@ def load_config(): settings = load_settings() chainlit_server = os.environ.get("CHAINLIT_SERVER", "https://cloud.chainlit.io") - data_persistence = "CHAINLIT_API_KEY" in os.environ config = ChainlitConfig( chainlit_server=chainlit_server, - data_persistence=data_persistence, chainlit_prod_url=chainlit_prod_url, run=RunSettings(), **settings, diff --git a/backend/chainlit/context.py b/backend/chainlit/context.py index ff22afd747..0c5ab12a38 100644 --- a/backend/chainlit/context.py +++ b/backend/chainlit/context.py @@ -7,9 +7,8 @@ from lazify import LazyProxy if TYPE_CHECKING: - from chainlit.client.cloud import AppUser, PersistedAppUser from chainlit.emitter import BaseChainlitEmitter - from chainlit.message import Message + from chainlit.user import PersistedUser, User class ChainlitContextException(Exception): @@ -22,6 +21,11 @@ class ChainlitContext: emitter: "BaseChainlitEmitter" session: Union["HTTPSession", "WebsocketSession"] + @property + def current_step(self): + if self.session.active_steps: + return self.session.active_steps[-1] + def __init__(self, session: Union["HTTPSession", "WebsocketSession"]): from chainlit.emitter import BaseChainlitEmitter, ChainlitEmitter @@ -47,7 +51,7 @@ def init_ws_context(session_or_sid: Union[WebsocketSession, str]) -> ChainlitCon def init_http_context( - user: Optional[Union["AppUser", "PersistedAppUser"]] = None, + user: Optional[Union["User", "PersistedUser"]] = None, auth_token: Optional[str] = None, user_env: Optional[Dict[str, str]] = None, ) -> ChainlitContext: diff --git a/backend/chainlit/data/__init__.py b/backend/chainlit/data/__init__.py index 1b7951cd3f..ce42c77428 100644 --- a/backend/chainlit/data/__init__.py +++ b/backend/chainlit/data/__init__.py @@ -1,13 +1,379 @@ +import functools import os -from typing import Optional +from collections import deque +from typing import TYPE_CHECKING, Dict, List, Optional -from chainlit.client.cloud import ChainlitCloudClient -from chainlit.config import config +from chainlit.context import context +from chainlit.logger import logger +from chainlit.session import WebsocketSession +from chainlit.types import Feedback, Pagination, ThreadDict, ThreadFilter +from chainlit.user import PersistedUser, User, UserDict +from chainlit_client import Attachment +from chainlit_client import Feedback as ClientFeedback +from chainlit_client import PageInfo, PaginatedResponse +from chainlit_client import Step as ClientStep +from chainlit_client.thread import NumberListFilter, StringFilter, StringListFilter +from chainlit_client.thread import ThreadFilter as ClientThreadFilter -chainlit_client = None # type: Optional[ChainlitCloudClient] +if TYPE_CHECKING: + from chainlit.element import Element, ElementDict + from chainlit.step import FeedbackDict, StepDict -if config.data_persistence: - chainlit_client = ChainlitCloudClient( - api_key=os.environ.get("CHAINLIT_API_KEY", ""), - chainlit_server=config.chainlit_server, - ) +_data_layer = None + + +def queue_until_user_message(): + def decorator(method): + @functools.wraps(method) + async def wrapper(self, *args, **kwargs): + if ( + isinstance(context.session, WebsocketSession) + and not context.session.has_user_message + ): + # Queue the method invocation waiting for the first user message + queues = context.session.thread_queues + method_name = method.__name__ + if method_name not in queues: + queues[method_name] = deque() + queues[method_name].append((method, self, args, kwargs)) + + else: + # Otherwise, Execute the method immediately + return await method(self, *args, **kwargs) + + return wrapper + + return decorator + + +class BaseDataLayer: + """Base class for data persistence.""" + + async def get_user(self, identifier: str) -> Optional["PersistedUser"]: + return None + + async def create_user(self, user: "User") -> Optional["PersistedUser"]: + pass + + async def upsert_feedback( + self, + feedback: Feedback, + ) -> str: + return "" + + @queue_until_user_message() + async def create_element(self, element_dict: "ElementDict"): + pass + + async def get_element( + self, thread_id: str, element_id: str + ) -> Optional["ElementDict"]: + pass + + @queue_until_user_message() + async def delete_element(self, element_id: str): + pass + + @queue_until_user_message() + async def create_step(self, step_dict: "StepDict"): + pass + + @queue_until_user_message() + async def update_step(self, step_dict: "StepDict"): + pass + + @queue_until_user_message() + async def delete_step(self, step_id: str): + pass + + async def get_thread_author(self, thread_id: str) -> str: + return "" + + async def delete_thread(self, thread_id: str): + pass + + async def list_threads( + self, pagination: "Pagination", filters: "ThreadFilter" + ) -> "PaginatedResponse[ThreadDict]": + return PaginatedResponse( + data=[], pageInfo=PageInfo(hasNextPage=False, endCursor=None) + ) + + async def get_thread(self, thread_id: str) -> "Optional[ThreadDict]": + return None + + async def update_thread( + self, + thread_id: str, + user_id: Optional[str] = None, + metadata: Optional[Dict] = None, + tags: Optional[List[str]] = None, + ): + pass + + +class ChainlitDataLayer: + def __init__( + self, api_key: str, chainlit_server: Optional[str] = "https://cloud.chainlit.io" + ): + from chainlit_client import ChainlitClient + + self.client = ChainlitClient(api_key=api_key, url=chainlit_server) + logger.info("Chainlit data layer initialized") + + def attachment_to_element_dict(self, attachment: Attachment) -> "ElementDict": + metadata = attachment.metadata or {} + return { + "chainlitKey": None, + "display": metadata.get("display", "side"), + "language": metadata.get("language"), + "size": metadata.get("size"), + "type": metadata.get("type", "file"), + "forId": attachment.step_id, + "id": attachment.id or "", + "mime": attachment.mime, + "name": attachment.name or "", + "objectKey": attachment.object_key, + "url": attachment.url, + "threadId": attachment.thread_id, + } + + def feedback_to_feedback_dict( + self, feedback: Optional[ClientFeedback] + ) -> "Optional[FeedbackDict]": + if not feedback: + return None + return { + "id": feedback.id or "", + "forId": feedback.step_id or "", + "value": feedback.value or 0, # type: ignore + "comment": feedback.comment, + "strategy": "BINARY", + } + + def step_to_step_dict(self, step: ClientStep) -> "StepDict": + metadata = step.metadata or {} + return { + "createdAt": step.created_at, + "id": step.id or "", + "threadId": step.thread_id or "", + "parentId": step.parent_id, + "feedback": self.feedback_to_feedback_dict(step.feedback), + "start": step.start_time, + "end": step.end_time, + "type": step.type or "undefined", + "name": step.name or "", + "generation": step.generation.to_dict() if step.generation else None, + "input": step.input or "", + "output": step.output or "", + "showInput": metadata.get("showInput", False), + "disableFeedback": metadata.get("disableFeedback", False), + "indent": metadata.get("indent"), + "language": metadata.get("language"), + "isError": metadata.get("isError", False), + "waitForAnswer": metadata.get("waitForAnswer", False), + "feedback": self.feedback_to_feedback_dict(step.feedback), + } + + async def get_user(self, identifier: str) -> Optional[PersistedUser]: + user = await self.client.api.get_user(identifier=identifier) + if not user: + return None + return PersistedUser( + id=user.id or "", + identifier=user.identifier or "", + metadata=user.metadata, + createdAt=user.created_at or "", + ) + + async def create_user(self, user: User) -> Optional[PersistedUser]: + _user = await self.client.api.get_user(identifier=user.identifier) + if not _user: + _user = await self.client.api.create_user( + identifier=user.identifier, metadata=user.metadata + ) + return PersistedUser( + id=_user.id or "", + identifier=_user.identifier or "", + metadata=_user.metadata, + createdAt=_user.created_at or "", + ) + + async def upsert_feedback( + self, + feedback: Feedback, + ): + if feedback.id: + await self.client.api.update_feedback( + id=feedback.id, + update_params={ + "comment": feedback.comment, + "strategy": feedback.strategy, + "value": feedback.value, + }, + ) + return feedback.id + else: + created = await self.client.api.create_feedback( + step_id=feedback.forId, + value=feedback.value, + comment=feedback.comment, + strategy=feedback.strategy, + ) + return created.id or "" + + @queue_until_user_message() + async def create_element(self, element: "Element"): + metadata = { + "size": element.size, + "language": element.language, + "display": element.display, + "type": element.type, + } + + await self.client.api.create_attachment( + thread_id=element.thread_id, + step_id=element.for_id or "", + mime=element.mime, + name=element.name, + url=element.url, + content=element.content, + path=element.path, + metadata=metadata, + ) + + async def get_element( + self, thread_id: str, element_id: str + ) -> Optional["ElementDict"]: + attachment = await self.client.api.get_attachment(id=element_id) + if not attachment: + return None + return self.attachment_to_element_dict(attachment) + + @queue_until_user_message() + async def delete_element(self, element_id: str): + await self.client.api.delete_attachment(id=element_id) + + @queue_until_user_message() + async def create_step(self, step_dict: "StepDict"): + metadata = { + "disableFeedback": step_dict.get("disableFeedback"), + "isError": step_dict.get("isError"), + "waitForAnswer": step_dict.get("waitForAnswer"), + "language": step_dict.get("language"), + "showInput": step_dict.get("showInput"), + } + + await self.client.api.send_steps( + [ + { + "createdAt": step_dict.get("createdAt"), + "startTime": step_dict.get("start"), + "endTime": step_dict.get("end"), + "generation": step_dict.get("generation"), + "id": step_dict.get("id"), + "parentId": step_dict.get("parentId"), + "input": step_dict.get("input"), + "output": step_dict.get("output"), + "name": step_dict.get("name"), + "threadId": step_dict.get("threadId"), + "type": step_dict.get("type"), + "metadata": metadata, + } + ] + ) + + @queue_until_user_message() + async def update_step(self, step_dict: "StepDict"): + await self.create_step(step_dict) + + @queue_until_user_message() + async def delete_step(self, step_id: str): + await self.client.api.delete_step(id=step_id) + + async def get_thread_author(self, thread_id: str) -> str: + thread = await self.get_thread(thread_id) + if not thread: + return "" + user = thread.get("user") + if not user: + return "" + return user.get("identifier") or "" + + async def delete_thread(self, thread_id: str): + await self.client.api.delete_thread(id=thread_id) + + async def list_threads( + self, pagination: "Pagination", filters: "ThreadFilter" + ) -> "PaginatedResponse[ThreadDict]": + if not filters.userIdentifier: + raise ValueError("userIdentifier is required") + + client_filters = ClientThreadFilter( + participantsIdentifier=StringListFilter( + operator="in", value=[filters.userIdentifier] + ), + ) + if filters.search: + client_filters.search = StringFilter(operator="ilike", value=filters.search) + if filters.feedback: + client_filters.feedbacksValue = NumberListFilter( + operator="in", value=[filters.feedback] + ) + return await self.client.api.list_threads( + first=pagination.first, after=pagination.cursor, filters=client_filters + ) + + async def get_thread(self, thread_id: str) -> "Optional[ThreadDict]": + thread = await self.client.api.get_thread(id=thread_id) + if not thread: + return None + elements = [] # List[ElementDict] + steps = [] # List[StepDict] + if thread.steps: + for step in thread.steps: + for attachment in step.attachments: + elements.append(self.attachment_to_element_dict(attachment)) + steps.append(self.step_to_step_dict(step)) + + user = None # type: Optional["UserDict"] + + if thread.user: + user = { + "id": thread.user.id or "", + "identifier": thread.user.identifier or "", + "metadata": thread.user.metadata, + } + + return { + "createdAt": thread.created_at or "", + "id": thread.id, + "steps": steps, + "elements": elements, + "metadata": thread.metadata, + "user": user, + "tags": thread.tags, + } + + async def update_thread( + self, + thread_id: str, + user_id: Optional[str] = None, + metadata: Optional[Dict] = None, + tags: Optional[List[str]] = None, + ): + await self.client.api.upsert_thread( + thread_id=thread_id, + participant_id=user_id, + metadata=metadata, + tags=tags, + ) + + +if api_key := os.environ.get("CHAINLIT_API_KEY"): + chainlit_server = os.environ.get("CHAINLIT_SERVER") + _data_layer = ChainlitDataLayer(api_key=api_key, chainlit_server=chainlit_server) + + +def get_data_layer(): + return _data_layer diff --git a/backend/chainlit/data/acl.py b/backend/chainlit/data/acl.py index 0c1a05b4b6..fd3971faf4 100644 --- a/backend/chainlit/data/acl.py +++ b/backend/chainlit/data/acl.py @@ -1,14 +1,15 @@ -from chainlit.data import chainlit_client +from chainlit.data import get_data_layer from fastapi import HTTPException -async def is_conversation_author(username: str, conversation_id: str): - if not chainlit_client: +async def is_thread_author(username: str, thread_id: str): + data_layer = get_data_layer() + if not data_layer: raise HTTPException(status_code=401, detail="Unauthorized") - conversation_author = await chainlit_client.get_conversation_author(conversation_id) + thread_author = await data_layer.get_thread_author(thread_id) - if conversation_author != username: + if thread_author != username: raise HTTPException(status_code=401, detail="Unauthorized") else: return True diff --git a/backend/chainlit/element.py b/backend/chainlit/element.py index 1c6aaa1f4e..89e37dd4b9 100644 --- a/backend/chainlit/element.py +++ b/backend/chainlit/element.py @@ -2,16 +2,14 @@ import uuid from enum import Enum from io import BytesIO -from typing import Any, ClassVar, Dict, List, Optional, TypeVar, Union, cast +from typing import Any, ClassVar, List, Literal, Optional, TypedDict, TypeVar, Union -import aiofiles import filetype -from chainlit.client.base import ElementDict, ElementDisplay, ElementSize, ElementType -from chainlit.client.cloud import ChainlitCloudClient from chainlit.context import context -from chainlit.data import chainlit_client +from chainlit.data import get_data_layer from chainlit.logger import logger from chainlit.telemetry import trace_event +from chainlit.types import FileDict from pydantic.dataclasses import Field, dataclass from syncer import asyncio @@ -21,16 +19,38 @@ "plotly": "application/json", } +ElementType = Literal[ + "image", "avatar", "text", "pdf", "tasklist", "audio", "video", "file", "plotly" +] +ElementDisplay = Literal["inline", "side", "page"] +ElementSize = Literal["small", "medium", "large"] + + +class ElementDict(TypedDict): + id: str + threadId: Optional[str] + type: ElementType + chainlitKey: Optional[str] + url: Optional[str] + objectKey: Optional[str] + name: str + display: ElementDisplay + size: Optional[ElementSize] + language: Optional[str] + forId: Optional[str] + mime: Optional[str] + @dataclass class Element: # The type of the element. This will be used to determine how to display the element in the UI. type: ClassVar[ElementType] - + # Name of the element, this will be used to reference the element in the UI. + name: str # The ID of the element. This is set automatically when the element is sent to the UI. id: str = Field(default_factory=lambda: str(uuid.uuid4())) - # Name of the element, this will be used to reference the element in the UI. - name: Optional[str] = None + # The key of the element hosted on Chainlit. + chainlit_key: Optional[str] = None # The URL of the element if already hosted somehwere else. url: Optional[str] = None # The S3 object key. @@ -44,7 +64,7 @@ class Element: # Controls element size size: Optional[ElementSize] = None # The ID of the message this element is associated with. - for_ids: List[str] = Field(default_factory=list) + for_id: Optional[str] = None # The language, if relevant language: Optional[str] = None # Mime type, infered based on content if not provided @@ -53,6 +73,8 @@ class Element: def __post_init__(self) -> None: trace_event(f"init {self.__class__.__name__}") self.persisted = False + self.updatable = False + self.thread_id = context.session.thread_id if not self.url and not self.path and not self.content: raise ValueError("Must provide url, path or content to instantiate element") @@ -61,136 +83,92 @@ def to_dict(self) -> ElementDict: _dict = ElementDict( { "id": self.id, + "threadId": self.thread_id, "type": self.type, - "url": self.url or "", - "name": self.name or "", + "url": self.url, + "chainlitKey": self.chainlit_key, + "name": self.name, "display": self.display, "objectKey": getattr(self, "object_key", None), "size": getattr(self, "size", None), "language": getattr(self, "language", None), - "forIds": getattr(self, "for_ids", None), + "forId": getattr(self, "for_id", None), "mime": getattr(self, "mime", None), - "conversationId": None, } ) return _dict @classmethod - def from_dict(self, _dict: Dict): - if "image" in _dict.get("mime", ""): + def from_dict(self, _dict: FileDict): + type = _dict.get("type", "") + if "image" in type and "svg" not in type: return Image( id=_dict.get("id", str(uuid.uuid4())), - content=_dict.get("content"), - name=_dict.get("name"), - url=_dict.get("url"), - display=_dict.get("display", "inline"), - mime=_dict.get("mime"), + name=_dict.get("name", ""), + path=str(_dict.get("path")), + chainlit_key=_dict.get("id"), + display="inline", + mime=type, ) else: return File( id=_dict.get("id", str(uuid.uuid4())), - content=_dict.get("content"), - name=_dict.get("name"), - url=_dict.get("url"), - language=_dict.get("language"), - display=_dict.get("display", "inline"), - size=_dict.get("size"), - mime=_dict.get("mime"), + name=_dict.get("name", ""), + path=str(_dict.get("path")), + chainlit_key=_dict.get("id"), + display="inline", + mime=type, ) - async def with_conversation_id(self): - _dict = self.to_dict() - _dict["conversationId"] = await context.session.get_conversation_id() - return _dict - - async def preprocess_content(self): - pass - - async def load(self): - if self.path: - async with aiofiles.open(self.path, "rb") as f: - self.content = await f.read() - else: - raise ValueError("Must provide path or content to load element") - - async def persist(self, client: ChainlitCloudClient) -> Optional[ElementDict]: - if not self.url and self.content and not self.persisted: - conversation_id = await context.session.get_conversation_id() - upload_res = await client.upload_element( + async def _create(self) -> bool: + if (self.persisted or self.url) and not self.updatable: + return True + if data_layer := get_data_layer(): + try: + asyncio.create_task(data_layer.create_element(self)) + except Exception as e: + logger.error(f"Failed to create element: {str(e)}") + if not self.chainlit_key or self.updatable: + file_dict = await context.session.persist_file( + name=self.name, + path=self.path, content=self.content, mime=self.mime or "", - conversation_id=conversation_id, ) - self.url = upload_res["url"] - self.object_key = upload_res["object_key"] - element_dict = await self.with_conversation_id() + self.chainlit_key = file_dict["id"] - asyncio.create_task(self._persist(element_dict)) + self.persisted = True - return element_dict - - async def _persist(self, element: ElementDict): - if not chainlit_client: - return - - try: - if self.persisted: - await chainlit_client.update_element(element) - else: - await chainlit_client.create_element(element) - self.persisted = True - except Exception as e: - logger.error(f"Failed to persist element: {str(e)}") - - async def before_emit(self, element: Dict) -> Dict: - return element + return True async def remove(self): trace_event(f"remove {self.__class__.__name__}") + data_layer = get_data_layer() + if data_layer and self.persisted: + await data_layer.delete_element(self.id) await context.emitter.emit("remove_element", {"id": self.id}) - async def send(self, for_id: Optional[str] = None): - if not self.content and not self.url and self.path: - await self.load() + async def send(self, for_id: str): + if self.persisted and not self.updatable: + return - await self.preprocess_content() + self.for_id = for_id if not self.mime: # Only guess the mime type when the content is binary self.mime = ( mime_types[self.type] if self.type in mime_types - else filetype.guess_mime(self.content) + else filetype.guess_mime(self.path or self.content) ) - if for_id and for_id not in self.for_ids: - self.for_ids.append(for_id) + await self._create() - # We have a client, persist the element - if chainlit_client: - element_dict = await self.persist(chainlit_client) - if element_dict: - self.id = element_dict["id"] + if not self.url and not self.chainlit_key: + raise ValueError("Must provide url or chainlit key to send element") - elif not self.url and not self.content: - raise ValueError("Must provide url or content to send element") - - emit_dict = cast(Dict, self.to_dict()) - - # Adding this out of to_dict since the dict will be persisted in the DB - emit_dict["content"] = self.content - - # Element was already sent - if len(self.for_ids) > 1: - trace_event(f"update {self.__class__.__name__}") - await context.emitter.emit( - "update_element", - {"id": self.id, "forIds": self.for_ids}, - ) - else: - trace_event(f"send {self.__class__.__name__}") - emit_dict = await self.before_emit(emit_dict) - await context.emitter.emit("element", emit_dict) + trace_event(f"send {self.__class__.__name__}") + await context.emitter.emit("element", self.to_dict()) ElementBased = TypeVar("ElementBased", bound=Element) @@ -208,23 +186,7 @@ class Avatar(Element): type: ClassVar[ElementType] = "avatar" async def send(self): - element = None - - if not self.content and not self.url and self.path: - await self.load() - - if not self.url and not self.content: - raise ValueError("Must provide url or content to send element") - - element = self.to_dict() - - # Adding this out of to_dict since the dict will be persisted in the DB - element["content"] = self.content - - if element: - trace_event(f"send {self.__class__.__name__}") - element = await self.before_emit(element) - await context.emitter.emit("element", element) + await super().send(for_id="") @dataclass @@ -232,15 +194,8 @@ class Text(Element): """Useful to send a text (not a message) to the UI.""" type: ClassVar[ElementType] = "text" - - content: bytes = b"" language: Optional[str] = None - async def before_emit(self, text_element): - if "content" in text_element and isinstance(text_element["content"], bytes): - text_element["content"] = text_element["content"].decode("utf-8") - return text_element - @dataclass class Pdf(Element): @@ -312,12 +267,20 @@ class TaskList(Element): name: str = "tasklist" content: str = "dummy content to pass validation" + def __post_init__(self) -> None: + super().__post_init__() + self.updatable = True + async def add_task(self, task: Task): self.tasks.append(task) async def update(self): await self.send() + async def send(self): + await self.preprocess_content() + await super().send(for_id="") + async def preprocess_content(self): # serialize enum tasks = [ diff --git a/backend/chainlit/emitter.py b/backend/chainlit/emitter.py index 4a87541e9f..b7924ae6ba 100644 --- a/backend/chainlit/emitter.py +++ b/backend/chainlit/emitter.py @@ -1,12 +1,22 @@ import asyncio import uuid -from typing import Any, Dict, Optional +from datetime import datetime +from typing import Any, Dict, List, Optional, Union, cast -from chainlit.client.base import ConversationDict, MessageDict -from chainlit.element import Element +from chainlit.data import get_data_layer +from chainlit.element import Element, File from chainlit.message import Message from chainlit.session import BaseSession, WebsocketSession -from chainlit.types import AskSpec, UIMessagePayload +from chainlit.step import StepDict +from chainlit.types import ( + AskActionResponse, + AskSpec, + FileDict, + FileReference, + ThreadDict, + UIMessagePayload, +) +from chainlit.user import PersistedUser from socketio.exceptions import TimeoutError @@ -30,19 +40,19 @@ async def ask_user(self): """Stub method to get the 'ask_user' property from the session.""" pass - async def resume_conversation(self, conv_dict: ConversationDict): - """Stub method to resume a conversation.""" + async def resume_thread(self, thread_dict: ThreadDict): + """Stub method to resume a thread.""" pass - async def send_message(self, msg_dict: dict): + async def send_step(self, step_dict: StepDict): """Stub method to send a message to the UI.""" pass - async def update_message(self, msg_dict: dict): + async def update_step(self, step_dict: StepDict): """Stub method to update a message in the UI.""" pass - async def delete_message(self, msg_dict: dict): + async def delete_step(self, step_dict: StepDict): """Stub method to delete a message in the UI.""" pass @@ -54,15 +64,17 @@ async def clear_ask(self): """Stub method to clear the prompt from the UI.""" pass - async def init_conversation(self, msg_dict: MessageDict): - """Signal the UI that a new conversation (with a user message) exists""" + async def init_thread(self, step_dict: StepDict): + """Signal the UI that a new thread (with a user message) exists""" pass async def process_user_message(self, payload: UIMessagePayload) -> Message: """Stub method to process user message.""" return Message(content="") - async def send_ask_user(self, msg_dict: dict, spec, raise_on_timeout=False): + async def send_ask_user( + self, step_dict: StepDict, spec: AskSpec, raise_on_timeout=False + ) -> Optional[Union["StepDict", "AskActionResponse", List["FileDict"]]]: """Stub method to send a prompt to the UI and wait for a response.""" pass @@ -78,7 +90,7 @@ async def task_end(self): """Stub method to send a task end signal to the UI.""" pass - async def stream_start(self, msg_dict: dict): + async def stream_start(self, step_dict: StepDict): """Stub method to send a stream start signal to the UI.""" pass @@ -90,6 +102,12 @@ async def set_chat_settings(self, settings: dict): """Stub method to set chat settings.""" pass + async def send_action_response( + self, id: str, status: bool, response: Optional[str] = None + ): + """Send an action response to the UI.""" + pass + class ChainlitEmitter(BaseChainlitEmitter): """ @@ -123,23 +141,21 @@ def ask_user(self): """Get the 'ask_user' property from the session.""" return self._get_session_property("ask_user") - def resume_conversation(self, conv_dict: ConversationDict): - """Send a conversation to the UI to resume it""" - return self.emit("resume_conversation", conv_dict) + def resume_thread(self, thread_dict: ThreadDict): + """Send a thread to the UI to resume it""" + return self.emit("resume_thread", thread_dict) - def send_message(self, msg_dict: Dict): + def send_step(self, step_dict: StepDict): """Send a message to the UI.""" - return self.emit("new_message", msg_dict) + return self.emit("new_message", step_dict) - def update_message(self, msg_dict: Dict): + def update_step(self, step_dict: StepDict): """Update a message in the UI.""" + return self.emit("update_message", step_dict) - return self.emit("update_message", msg_dict) - - def delete_message(self, msg_dict): + def delete_step(self, step_dict: StepDict): """Delete a message in the UI.""" - - return self.emit("delete_message", msg_dict) + return self.emit("delete_message", step_dict) def send_ask_timeout(self): """Send a prompt timeout message to the UI.""" @@ -151,22 +167,44 @@ def clear_ask(self): return self.emit("clear_ask", {}) - def init_conversation(self, message: MessageDict): - """Signal the UI that a new conversation (with a user message) exists""" + async def init_thread(self, step: StepDict): + """Signal the UI that a new thread (with a user message) exists""" + if data_layer := get_data_layer(): + if isinstance(self.session.user, PersistedUser): + user_id = self.session.user.id + else: + user_id = None + await data_layer.update_thread( + thread_id=self.session.thread_id, + user_id=user_id, + metadata={"name": step["output"]}, + ) + await self.session.flush_method_queue() - return self.emit("init_conversation", message) + await self.emit("init_thread", step) async def process_user_message(self, payload: UIMessagePayload): - message_dict = payload["message"] - files = payload["files"] - # Temporary UUID generated by the frontend should use v4 - assert uuid.UUID(message_dict["id"]).version == 4 + step_dict = payload["message"] + file_refs = payload["fileReferences"] + # UUID generated by the frontend should use v4 + assert uuid.UUID(step_dict["id"]).version == 4 - message = Message.from_dict(message_dict) + message = Message.from_dict(step_dict) + # Overwrite the created_at timestamp with the current time + message.created_at = datetime.utcnow().isoformat() asyncio.create_task(message._create()) - if files: + if not self.session.has_user_message: + self.session.has_user_message = True + asyncio.create_task(self.init_thread(message.to_dict())) + + if file_refs: + files = [ + self.session.files[file["id"]] + for file in file_refs + if file["id"] in self.session.files + ] file_elements = [Element.from_dict(file) for file in files] message.elements = file_elements @@ -176,38 +214,60 @@ async def send_elements(): asyncio.create_task(send_elements()) - if not self.session.has_user_message: - self.session.has_user_message = True - await self.init_conversation(await message.with_conversation_id()) - self.session.root_message = message return message async def send_ask_user( - self, msg_dict: Dict, spec: AskSpec, raise_on_timeout=False + self, step_dict: StepDict, spec: AskSpec, raise_on_timeout=False ): """Send a prompt to the UI and wait for a response.""" try: # Send the prompt to the UI - res = await self.ask_user( - {"msg": msg_dict, "spec": spec.to_dict()}, spec.timeout - ) # type: Optional["MessageDict"] + user_res = await self.ask_user( + {"msg": step_dict, "spec": spec.to_dict()}, spec.timeout + ) # type: Optional[Union["StepDict", "AskActionResponse", List["FileReference"]]] # End the task temporarily so that the User can answer the prompt await self.task_end() - if res: - # If cloud is enabled, store the response in the database/S3 + final_res: Optional[ + Union["StepDict", "AskActionResponse", List["FileDict"]] + ] = None + + if user_res: if spec.type == "text": - await self.process_user_message({"message": res, "files": None}) + message_dict_res = cast(StepDict, user_res) + await self.process_user_message( + {"message": message_dict_res, "fileReferences": None} + ) + final_res = message_dict_res elif spec.type == "file": - # TODO: upload file to S3 - pass - + file_refs = cast(List[FileReference], user_res) + files = [ + self.session.files[file["id"]] + for file in file_refs + if file["id"] in self.session.files + ] + final_res = files + if get_data_layer(): + coros = [ + File( + name=file["name"], + path=str(file["path"]), + mime=file["type"], + chainlit_key=file["id"], + for_id=step_dict["id"], + )._create() + for file in files + ] + await asyncio.gather(*coros) + elif spec.type == "action": + action_res = cast(AskActionResponse, user_res) + final_res = action_res await self.clear_ask() - return res + return final_res except TimeoutError as e: await self.send_ask_timeout() @@ -231,11 +291,11 @@ def task_end(self): """Send a task end signal to the UI.""" return self.emit("task_end", {}) - def stream_start(self, msg_dict: Dict): + def stream_start(self, step_dict: StepDict): """Send a stream start signal to the UI.""" return self.emit( "stream_start", - msg_dict, + step_dict, ) def send_token(self, id: str, token: str, is_sequence=False): @@ -246,3 +306,10 @@ def send_token(self, id: str, token: str, is_sequence=False): def set_chat_settings(self, settings: Dict[str, Any]): self.session.chat_settings = settings + + def send_action_response( + self, id: str, status: bool, response: Optional[str] = None + ): + return self.emit( + "action_response", {"id": id, "status": status, "response": response} + ) diff --git a/backend/chainlit/haystack/callbacks.py b/backend/chainlit/haystack/callbacks.py index bbedf7e57d..ad7c0b1489 100644 --- a/backend/chainlit/haystack/callbacks.py +++ b/backend/chainlit/haystack/callbacks.py @@ -1,12 +1,12 @@ +from datetime import datetime from typing import Any, Generic, List, Optional, TypeVar -from chainlit.config import config from chainlit.context import context +from chainlit.step import Step +from chainlit.sync import run_sync from haystack.agents import Agent, Tool from haystack.agents.agent_step import AgentStep -import chainlit as cl - T = TypeVar("T") @@ -31,8 +31,8 @@ def clear(self) -> None: class HaystackAgentCallbackHandler: - stack: Stack[cl.Message] - latest_agent_message: Optional[cl.Message] + stack: Stack[Step] + last_step: Optional[Step] def __init__(self, agent: Agent): agent.callback_manager.on_agent_start += self.on_agent_start @@ -44,55 +44,53 @@ def __init__(self, agent: Agent): agent.tm.callback_manager.on_tool_finish += self.on_tool_finish agent.tm.callback_manager.on_tool_error += self.on_tool_error - def get_root_message(self): - if not context.session.root_message: - root_message = cl.Message(author=config.ui.name, content="") - cl.run_sync(root_message.send()) - - return context.session.root_message - def on_agent_start(self, **kwargs: Any) -> None: # Prepare agent step message for streaming self.agent_name = kwargs.get("name", "Agent") - self.stack = Stack[cl.Message]() - self.stack.push(self.get_root_message()) + self.stack = Stack[Step]() + root_message = context.session.root_message + parent_id = root_message.id if root_message else None + run_step = Step(name=self.agent_name, type="run", parent_id=parent_id) + run_step.start = datetime.utcnow().isoformat() + run_step.input = kwargs - agent_message = cl.Message( - author=self.agent_name, parent_id=self.stack.peek().id, content="" - ) - self.stack.push(agent_message) + run_sync(run_step.send()) + + self.stack.push(run_step) + + def on_agent_finish(self, agent_step: AgentStep, **kwargs: Any) -> None: + run_step = self.stack.pop() + run_step.end = datetime.utcnow().isoformat() + run_step.output = agent_step.prompt_node_response + run_sync(run_step.update()) # This method is called when a step has finished def on_agent_step(self, agent_step: AgentStep, **kwargs: Any) -> None: # Send previous agent step message - self.latest_agent_message = self.stack.pop() + self.last_step = self.stack.pop() # If token streaming is disabled - if self.latest_agent_message.content == "": - self.latest_agent_message.content = agent_step.prompt_node_response - - cl.run_sync(self.latest_agent_message.send()) + if self.last_step.output == "": + self.last_step.output = agent_step.prompt_node_response + self.last_step.end = datetime.utcnow().isoformat() + run_sync(self.last_step.update()) if not agent_step.is_last(): - # Prepare message for next agent step - agent_message = cl.Message( - author=self.agent_name, parent_id=self.stack.peek().id, content="" - ) - self.stack.push(agent_message) - - def on_agent_finish(self, agent_step: AgentStep, **kwargs: Any) -> None: - self.latest_agent_message = None - self.stack.clear() + # Prepare step for next agent step + step = Step(name=self.agent_name, parent_id=self.stack.peek().id) + self.stack.push(step) def on_new_token(self, token, **kwargs: Any) -> None: # Stream agent step tokens - cl.run_sync(self.stack.peek().stream_token(token)) + run_sync(self.stack.peek().stream_token(token)) def on_tool_start(self, tool_input: str, tool: Tool, **kwargs: Any) -> None: - # Tool started, create message - parent_id = self.latest_agent_message.id if self.latest_agent_message else None - tool_message = cl.Message(author=tool.name, parent_id=parent_id, content="") - self.stack.push(tool_message) + # Tool started, create step + parent_id = self.stack.items[0].id if self.stack.items[0] else None + tool_step = Step(name=tool.name, type="tool", parent_id=parent_id) + tool_step.input = tool_input + tool_step.start = datetime.utcnow().isoformat() + self.stack.push(tool_step) def on_tool_finish( self, @@ -101,12 +99,16 @@ def on_tool_finish( tool_input: Optional[str] = None, **kwargs: Any ) -> None: - # Tool finished, send message with tool_result - tool_message = self.stack.pop() - tool_message.content = tool_result - cl.run_sync(tool_message.send()) + # Tool finished, send step with tool_result + tool_step = self.stack.pop() + tool_step.output = tool_result + tool_step.end = datetime.utcnow().isoformat() + run_sync(tool_step.update()) def on_tool_error(self, exception: Exception, tool: Tool, **kwargs: Any) -> None: # Tool error, send error message - cl.run_sync(self.stack.pop().remove()) - cl.run_sync(cl.ErrorMessage(str(exception), author=tool.name).send()) + error_step = self.stack.pop() + error_step.is_error = True + error_step.output = str(exception) + error_step.end = datetime.utcnow().isoformat() + run_sync(error_step.update()) diff --git a/backend/chainlit/hello.py b/backend/chainlit/hello.py index a404152c77..b4b8192612 100644 --- a/backend/chainlit/hello.py +++ b/backend/chainlit/hello.py @@ -8,5 +8,5 @@ async def main(): res = await AskUserMessage(content="What is your name?", timeout=30).send() if res: await Message( - content=f"Your name is: {res['content']}.\nChainlit installation is working!\nYou can now start building your own chainlit apps!", + content=f"Your name is: {res['output']}.\nChainlit installation is working!\nYou can now start building your own chainlit apps!", ).send() diff --git a/backend/chainlit/langchain/callbacks.py b/backend/chainlit/langchain/callbacks.py index 72254e57fe..259399f36b 100644 --- a/backend/chainlit/langchain/callbacks.py +++ b/backend/chainlit/langchain/callbacks.py @@ -1,13 +1,15 @@ +from datetime import datetime from typing import Any, Dict, List, Optional, Union from uuid import UUID from chainlit.context import context_var from chainlit.message import Message from chainlit.playground.providers.openai import stringify_function_call -from chainlit.prompt import Prompt, PromptMessage +from chainlit.step import Step, TrueStepType +from chainlit_client import ChatGeneration, CompletionGeneration, GenerationMessage from langchain.callbacks.tracers.base import BaseTracer from langchain.callbacks.tracers.schemas import Run -from langchain.schema.messages import BaseMessage +from langchain.schema import BaseMessage from langchain.schema.output import ChatGenerationChunk, GenerationChunk DEFAULT_ANSWER_PREFIX_TOKENS = ["Final", "Answer", ":"] @@ -85,15 +87,15 @@ def _append_to_last_tokens(self, token: str) -> None: self.last_tokens_stripped.pop(0) -class PromptHelper: - prompt_sequence: List[Prompt] +class GenerationHelper: + generation_sequence: List[Union[ChatGeneration, CompletionGeneration]] def __init__(self) -> None: - self.prompt_sequence = [] + self.generation_sequence = [] @property - def current_prompt(self): - return self.prompt_sequence[-1] if self.prompt_sequence else None + def current_generation(self): + return self.generation_sequence[-1] if self.generation_sequence else None def _convert_message_role(self, role: str): if "human" in role.lower(): @@ -109,7 +111,7 @@ def _convert_message_dict( self, message: Dict, template: Optional[str] = None, - template_format: Optional[str] = None, + template_format: str = "f-string", ): class_name = message["id"][-1] kwargs = message.get("kwargs", {}) @@ -118,7 +120,7 @@ def _convert_message_dict( content = stringify_function_call(function_call) else: content = kwargs.get("content", "") - return PromptMessage( + return GenerationMessage( name=kwargs.get("name"), role=self._convert_message_role(class_name), template=template, @@ -130,7 +132,7 @@ def _convert_message( self, message: Union[Dict, BaseMessage], template: Optional[str] = None, - template_format: Optional[str] = None, + template_format: str = "f-string", ): if isinstance(message, dict): return self._convert_message_dict( @@ -141,7 +143,7 @@ def _convert_message( content = stringify_function_call(function_call) else: content = message.content - return PromptMessage( + return GenerationMessage( name=getattr(message, "name", None), role=self._convert_message_role(message.type), template=template, @@ -165,16 +167,16 @@ def _get_messages(self, serialized: Dict): return chain_messages - def _build_prompt(self, serialized: Dict, inputs: Dict): + def _build_generation(self, serialized: Dict, inputs: Dict): messages = self._get_messages(serialized) if messages: # If prompt is chat, the formatted values will be added in on_chat_model_start - self._build_chat_template_prompt(messages, inputs) + self._build_chat_template_generation(messages, inputs) else: # For completion prompt everything is done here - self._build_completion_prompt(serialized, inputs) + self._build_completion_generation(serialized, inputs) - def _build_completion_prompt(self, serialized: Dict, inputs: Dict): + def _build_completion_generation(self, serialized: Dict, inputs: Dict): if not serialized: return kwargs = serialized.get("kwargs", {}) @@ -185,15 +187,15 @@ def _build_completion_prompt(self, serialized: Dict, inputs: Dict): if not template: return - self.prompt_sequence.append( - Prompt( + self.generation_sequence.append( + CompletionGeneration( template=template, template_format=template_format, inputs=stringified_inputs, ) ) - def _build_default_prompt( + def _build_default_generation( self, run: Run, generation_type: str, @@ -203,12 +205,12 @@ def _build_default_prompt( ): """Build a prompt once an LLM has been executed if no current prompt exists (without template)""" if "chat" in generation_type.lower(): - return Prompt( + return ChatGeneration( provider=provider, settings=llm_settings, completion=completion, messages=[ - PromptMessage( + GenerationMessage( formatted=formatted_prompt, role=self._convert_message_role(formatted_prompt.split(":")[0]), ) @@ -216,16 +218,16 @@ def _build_default_prompt( ], ) else: - return Prompt( + return CompletionGeneration( provider=provider, settings=llm_settings, completion=completion, formatted=run.inputs.get("prompts", [])[0], ) - def _build_chat_template_prompt(self, lc_messages: List[Dict], inputs: Dict): - def build_template_messages() -> List[PromptMessage]: - template_messages = [] # type: List[PromptMessage] + def _build_chat_template_generation(self, lc_messages: List[Dict], inputs: Dict): + def build_template_messages() -> List[GenerationMessage]: + template_messages = [] # type: List[GenerationMessage] if not lc_messages: return template_messages @@ -249,13 +251,12 @@ def build_template_messages() -> List[PromptMessage]: if placeholder_size: template_messages += [ - PromptMessage(placeholder_size=placeholder_size) + GenerationMessage(placeholder_size=placeholder_size) ] else: template_messages += [ - PromptMessage( + GenerationMessage( template=template, - template_format=template_format, role=self._convert_message_role(class_name), ) ] @@ -267,18 +268,18 @@ def build_template_messages() -> List[PromptMessage]: return stringified_inputs = {k: str(v) for (k, v) in inputs.items()} - self.prompt_sequence.append( - Prompt(messages=template_messages, inputs=stringified_inputs) + self.generation_sequence.append( + ChatGeneration(messages=template_messages, inputs=stringified_inputs) ) - def _build_chat_formatted_prompt( + def _build_chat_formatted_generation( self, lc_messages: Union[List[BaseMessage], List[dict]] ): - if not self.current_prompt: + if not self.current_generation: return - formatted_messages = [] # type: List[PromptMessage] - if self.current_prompt.messages: + formatted_messages = [] # type: List[GenerationMessage] + if self.current_generation.messages: # This is needed to compute the correct message index to read placeholder_offset = 0 # The final list of messages @@ -286,7 +287,7 @@ def _build_chat_formatted_prompt( # Looping the messages built in build_prompt # They only contain the template for template_index, template_message in enumerate( - self.current_prompt.messages + self.current_generation.messages ): # If a message has a placeholder size, we need to replace it # With the N following messages, where N is the placeholder size @@ -322,7 +323,7 @@ def _build_chat_formatted_prompt( self._convert_message(lc_message) for lc_message in lc_messages ] - self.current_prompt.messages = formatted_messages + self.current_generation.messages = formatted_messages def _build_llm_settings( self, @@ -356,8 +357,8 @@ def _build_llm_settings( DEFAULT_TO_KEEP = ["retriever", "llm", "agent", "chain", "tool"] -class LangchainTracer(BaseTracer, PromptHelper, FinalStreamHelper): - llm_stream_message: Dict[str, Message] +class LangchainTracer(BaseTracer, GenerationHelper, FinalStreamHelper): + steps: Dict[str, Step] parent_id_map: Dict[str, str] ignored_runs: set @@ -376,7 +377,7 @@ def __init__( **kwargs: Any, ) -> None: BaseTracer.__init__(self, **kwargs) - PromptHelper.__init__(self) + GenerationHelper.__init__(self) FinalStreamHelper.__init__( self, answer_prefix_tokens=answer_prefix_tokens, @@ -384,7 +385,7 @@ def __init__( force_stream_final_answer=force_stream_final_answer, ) self.context = context_var.get() - self.llm_stream_message = {} + self.steps = {} self.parent_id_map = {} self.ignored_runs = set() self.root_parent_id = ( @@ -420,33 +421,37 @@ def _get_non_ignored_parent_id(self, current_parent_id: Optional[str] = None): return self.root_parent_id if current_parent_id not in self.parent_id_map: - return current_parent_id + return None while current_parent_id in self.parent_id_map: - current_parent_id = self.parent_id_map[current_parent_id] + # If the parent id is in the ignored runs, we need to get the parent id of the ignored run + if current_parent_id in self.ignored_runs: + current_parent_id = self.parent_id_map[current_parent_id] + else: + return current_parent_id - return current_parent_id + return self.root_parent_id def _should_ignore_run(self, run: Run): parent_id = self._get_run_parent_id(run) + if parent_id: + # Add the parent id of the ignored run in the mapping + # so we can re-attach a kept child to the right parent id + self.parent_id_map[str(run.id)] = parent_id + ignore_by_name = run.name in self.to_ignore ignore_by_parent = parent_id in self.ignored_runs ignore = ignore_by_name or ignore_by_parent - if ignore: - if parent_id: - # Add the parent id of the ignored run in the mapping - # so we can re-attach a kept child to the right parent id - self.parent_id_map[str(run.id)] = parent_id - # Tag the run as ignored - self.ignored_runs.add(str(run.id)) - # If the ignore cause is the parent being ignored, check if we should nonetheless keep the child if ignore_by_parent and not ignore_by_name and run.run_type in self.to_keep: - return False, self._get_non_ignored_parent_id(str(run.id)) + return False, self._get_non_ignored_parent_id(parent_id) else: + if ignore: + # Tag the run as ignored + self.ignored_runs.add(str(run.id)) return ignore, parent_id def _is_annotable(self, run: Run): @@ -477,12 +482,12 @@ def on_chat_model_start( ) -> Any: """Adding formatted content and new message to the previously built template prompt""" lc_messages = messages[0] - if not self.current_prompt: - self.prompt_sequence.append( - Prompt(messages=[self._convert_message(m) for m in lc_messages]) + if not self.current_generation: + self.generation_sequence.append( + ChatGeneration(messages=[self._convert_message(m) for m in lc_messages]) ) else: - self._build_chat_formatted_prompt(lc_messages) + self._build_chat_formatted_generation(lc_messages) super().on_chat_model_start( serialized, @@ -503,7 +508,7 @@ def on_llm_new_token( parent_run_id: Optional[UUID] = None, **kwargs: Any, ) -> Any: - msg = self.llm_stream_message.get(str(run_id), None) + msg = self.steps.get(str(run_id), None) if msg: self._run_sync(msg.stream_token(token)) @@ -513,6 +518,7 @@ def on_llm_new_token( if self.answer_reached: if not self.final_stream: self.final_stream = Message(content="") + self._run_sync(self.final_stream.send()) self._run_sync(self.final_stream.stream_token(token)) self.has_streamed_final_answer = True else: @@ -533,36 +539,41 @@ def _start_trace(self, run: Run) -> None: if run.run_type in ["chain", "prompt"]: # Prompt templates are contained in chains or prompts (lcel) - self._build_prompt(run.serialized or {}, run.inputs) + self._build_generation(run.serialized or {}, run.inputs) ignore, parent_id = self._should_ignore_run(run) if ignore: return - disable_human_feedback = not self._is_annotable(run) + step_type: TrueStepType = "undefined" - if run.run_type == "llm": - msg = Message( - id=run.id, - content="", - author=run.name, - parent_id=parent_id, - disable_human_feedback=disable_human_feedback, - ) - self.llm_stream_message[str(run.id)] = msg - self._run_sync(msg.send()) - return - - self._run_sync( - Message( - id=run.id, - content="", - author=run.name, - parent_id=parent_id, - disable_human_feedback=disable_human_feedback, - ).send() + if run.run_type in ["agent", "chain"]: + step_type = "run" + elif run.run_type == "llm": + step_type = "llm" + elif run.run_type == "retriever": + step_type = "retrieval" + elif run.run_type == "tool": + step_type = "tool" + elif run.run_type == "embedding": + step_type = "embedding" + + disable_feedback = not self._is_annotable(run) + + step = Step( + id=str(run.id), + name=run.name, + type=step_type, + parent_id=parent_id, + disable_feedback=disable_feedback, ) + step.start = datetime.utcnow().isoformat() + step.input = run.inputs + + self.steps[str(run.id)] = step + + self._run_sync(step.send()) def _on_run_update(self, run: Run) -> None: """Process a run upon update.""" @@ -573,74 +584,75 @@ def _on_run_update(self, run: Run) -> None: if ignore: return - disable_human_feedback = not self._is_annotable(run) + current_step = self.steps.get(str(run.id), None) if run.run_type in ["chain"]: - if self.prompt_sequence: - self.prompt_sequence.pop() + if self.generation_sequence: + self.generation_sequence.pop() if run.run_type == "llm": provider, llm_settings = self._build_llm_settings( (run.serialized or {}), (run.extra or {}).get("invocation_params") ) generations = (run.outputs or {}).get("generations", []) + llm_output = (run.outputs or {}).get("llm_output") completion, language = self._get_completion(generations[0][0]) - current_prompt = ( - self.prompt_sequence.pop() if self.prompt_sequence else None + current_generation = ( + self.generation_sequence.pop() if self.generation_sequence else None ) - if current_prompt: - current_prompt.provider = provider - current_prompt.settings = llm_settings - current_prompt.completion = completion + if current_generation: + current_generation.provider = provider + current_generation.settings = llm_settings + current_generation.completion = completion else: generation_type = generations[0][0].get("type", "") - current_prompt = self._build_default_prompt( + current_generation = self._build_default_generation( run, generation_type, provider, llm_settings, completion ) - msg = self.llm_stream_message.get(str(run.id), None) - if msg: - msg.content = completion - msg.language = language - msg.prompt = current_prompt - self._run_sync(msg.update()) + if llm_output and current_generation: + token_count = llm_output.get("token_usage", {}).get("total_tokens") + current_generation.token_count = token_count + + if current_step: + current_step.output = completion + current_step.language = language + current_step.end = datetime.utcnow().isoformat() + current_step.generation = current_generation + self._run_sync(current_step.update()) if self.final_stream and self.has_streamed_final_answer: self.final_stream.content = completion self.final_stream.language = language - self.final_stream.prompt = current_prompt - self._run_sync(self.final_stream.send()) + self._run_sync(self.final_stream.update()) + return outputs = run.outputs or {} output_keys = list(outputs.keys()) + output = outputs if output_keys: - content = outputs.get(output_keys[0], "") - else: - return + output = outputs.get(output_keys[0], outputs) - if run.run_type in ["agent", "chain"]: - pass - # # Add the response of the chain/tool - # self._run_sync( - # Message( - # content=content, - # author=run.name, - # parent_id=parent_id, - # disable_human_feedback=disable_human_feedback, - # ).send() - # ) - else: - self._run_sync( - Message( - id=run.id, - content=content, - author=run.name, - parent_id=parent_id, - disable_human_feedback=disable_human_feedback, - ).update() - ) + if current_step: + current_step.output = output + current_step.end = datetime.utcnow().isoformat() + self._run_sync(current_step.update()) + + def _on_error(self, error: BaseException, *, run_id: UUID, **kwargs: Any): + context_var.set(self.context) + + if current_step := self.steps.get(str(run_id), None): + current_step.is_error = True + current_step.output = str(error) + current_step.end = datetime.utcnow().isoformat() + self._run_sync(current_step.update()) + + on_llm_error = _on_error + on_chain_error = _on_error + on_tool_error = _on_error + on_retriever_error = _on_error LangchainCallbackHandler = LangchainTracer diff --git a/backend/chainlit/llama_index/callbacks.py b/backend/chainlit/llama_index/callbacks.py index 6c7d50d9a7..6187068e45 100644 --- a/backend/chainlit/llama_index/callbacks.py +++ b/backend/chainlit/llama_index/callbacks.py @@ -1,11 +1,11 @@ -import asyncio +from datetime import datetime from typing import Any, Dict, List, Optional from chainlit.context import context_var from chainlit.element import Text -from chainlit.message import Message -from chainlit.prompt import Prompt, PromptMessage -from llama_index.callbacks.base import BaseCallbackHandler +from chainlit.step import Step, StepType +from chainlit_client import ChatGeneration, CompletionGeneration, GenerationMessage +from llama_index.callbacks import TokenCountingHandler from llama_index.callbacks.schema import CBEventType, EventPayload from llama_index.llms.base import ChatMessage, ChatResponse, CompletionResponse @@ -19,18 +19,31 @@ ] -class LlamaIndexCallbackHandler(BaseCallbackHandler): +class LlamaIndexCallbackHandler(TokenCountingHandler): """Base callback handler that can be used to track event starts and ends.""" + steps: Dict[str, Step] + def __init__( self, event_starts_to_ignore: List[CBEventType] = DEFAULT_IGNORE, event_ends_to_ignore: List[CBEventType] = DEFAULT_IGNORE, ) -> None: """Initialize the base callback handler.""" + super().__init__( + event_starts_to_ignore=event_starts_to_ignore, + event_ends_to_ignore=event_ends_to_ignore, + ) self.context = context_var.get() - self.event_starts_to_ignore = tuple(event_starts_to_ignore) - self.event_ends_to_ignore = tuple(event_ends_to_ignore) + + self.steps = {} + + def _get_parent_id(self, event_parent_id: Optional[str] = None) -> Optional[str]: + if event_parent_id and event_parent_id in self.steps: + return event_parent_id + if root_message := self.context.session.root_message: + return root_message.id + return None def _restore_context(self) -> None: """Restore Chainlit context in the current thread @@ -45,12 +58,6 @@ def _restore_context(self) -> None: """ context_var.set(self.context) - def _get_parent_id(self) -> Optional[str]: - """Get the parent message id""" - if root_message := self.context.session.root_message: - return root_message.id - return None - def on_event_start( self, event_type: CBEventType, @@ -61,14 +68,25 @@ def on_event_start( ) -> str: """Run when an event starts and return id of event.""" self._restore_context() - asyncio.run( - Message( - content="", - author=event_type, - parent_id=self._get_parent_id(), - ).send() + step_type: StepType = "undefined" + if event_type == CBEventType.RETRIEVE: + step_type = "retrieval" + elif event_type == CBEventType.LLM: + step_type = "llm" + else: + return event_id + + step = Step( + name=event_type.value, + type=step_type, + parent_id=self._get_parent_id(parent_id), + id=event_id, + disable_feedback=False, ) - + self.steps[event_id] = step + step.start = datetime.utcnow().isoformat() + step.input = payload or {} + self.context.loop.create_task(step.send()) return event_id def on_event_end( @@ -79,31 +97,27 @@ def on_event_end( **kwargs: Any, ) -> None: """Run when an event ends.""" - if payload is None: + step = self.steps.get(event_id, None) + + if payload is None or step is None: return self._restore_context() + step.end = datetime.utcnow().isoformat() + if event_type == CBEventType.RETRIEVE: sources = payload.get(EventPayload.NODES) if sources: - elements = [ - Text(name=f"Source {idx}", content=source.node.get_text()) - for idx, source in enumerate(sources) - ] source_refs = "\, ".join( [f"Source {idx}" for idx, _ in enumerate(sources)] ) - content = f"Retrieved the following sources: {source_refs}" - - asyncio.run( - Message( - content=content, - author=event_type, - elements=elements, - parent_id=self._get_parent_id(), - ).send() - ) + step.elements = [ + Text(name=f"Source {idx}", content=source.node.get_text()) + for idx, source in enumerate(sources) + ] + step.output = f"Retrieved the following sources: {source_refs}" + self.context.loop.create_task(step.update()) if event_type == CBEventType.LLM: formatted_messages = payload.get( @@ -114,7 +128,7 @@ def on_event_end( if formatted_messages: messages = [ - PromptMessage(role=m.role.value, formatted=m.content) # type: ignore[arg-type] + GenerationMessage(role=m.role.value, formatted=m.content) # type: ignore[arg-type] for m in formatted_messages ] else: @@ -127,18 +141,24 @@ def on_event_end( else: content = "" - asyncio.run( - Message( - content=content, - author=event_type, - parent_id=self._get_parent_id(), - prompt=Prompt( - formatted=formatted_prompt, - messages=messages, - completion=content, - ), - ).send() - ) + step.output = content + + token_count = self.total_llm_token_count or None + + if messages: + step.generation = ChatGeneration( + messages=messages, completion=content, token_count=token_count + ) + elif formatted_prompt: + step.generation = CompletionGeneration( + formatted=formatted_prompt, + completion=content, + token_count=token_count, + ) + + self.context.loop.create_task(step.update()) + + self.steps.pop(event_id, None) def _noop(self, *args, **kwargs): pass diff --git a/backend/chainlit/message.py b/backend/chainlit/message.py index 632d167d3a..3ee3d43239 100644 --- a/backend/chainlit/message.py +++ b/backend/chainlit/message.py @@ -1,76 +1,87 @@ +import asyncio import json import uuid -from abc import ABC, abstractmethod -from datetime import datetime, timezone +from abc import ABC +from datetime import datetime from typing import Dict, List, Optional, Union, cast from chainlit.action import Action -from chainlit.client.base import MessageDict from chainlit.config import config from chainlit.context import context -from chainlit.data import chainlit_client +from chainlit.data import get_data_layer from chainlit.element import ElementBased from chainlit.logger import logger -from chainlit.prompt import Prompt +from chainlit.step import StepDict from chainlit.telemetry import trace_event from chainlit.types import ( AskActionResponse, AskActionSpec, AskFileResponse, AskFileSpec, - AskResponse, AskSpec, + FileDict, ) -from syncer import asyncio +from chainlit_client.step import MessageStepType class MessageBase(ABC): id: str + thread_id: str author: str content: str = "" + type: MessageStepType = "assistant_message" + disable_feedback = False streaming = False - created_at: Union[int, str, None] = None + created_at: Union[str, None] = None fail_on_persist_error: bool = False persisted = False + is_error = False + language: Optional[str] = None + wait_for_answer = False + indent: Optional[int] = None def __post_init__(self) -> None: trace_event(f"init {self.__class__.__name__}") + self.thread_id = context.session.thread_id + if not getattr(self, "id", None): self.id = str(uuid.uuid4()) - if not self.created_at: - self.created_at = datetime.now(timezone.utc).isoformat() - - @abstractmethod - def to_dict(self) -> Dict: - pass - - async def with_conversation_id(self): - _dict = self.to_dict() - _dict["conversationId"] = await context.session.get_conversation_id() - return _dict - - async def _create(self): - msg_dict = await self.with_conversation_id() - asyncio.create_task(self._persist_create(msg_dict)) - if not config.features.prompt_playground: - msg_dict.pop("prompt", None) - return msg_dict + @classmethod + def from_dict(self, _dict: StepDict): + type = _dict.get("type", "assistant_message") + message = Message( + id=_dict["id"], + created_at=_dict["createdAt"], + content=_dict["output"], + author=_dict.get("name", config.ui.name), + type=type, # type: ignore + disable_feedback=_dict.get("disableFeedback", False), + language=_dict.get("language"), + ) - async def _persist_create(self, message: MessageDict): - if not chainlit_client or self.persisted: - return + return message - try: - persisted_id = await chainlit_client.create_message(message) + def to_dict(self) -> StepDict: + _dict: StepDict = { + "id": self.id, + "threadId": self.thread_id, + "createdAt": self.created_at, + "start": self.created_at, + "end": self.created_at, + "output": self.content, + "name": self.author, + "type": self.type, + "createdAt": self.created_at, + "language": self.language, + "streaming": self.streaming, + "disableFeedback": self.disable_feedback, + "isError": self.is_error, + "waitForAnswer": self.wait_for_answer, + "indent": self.indent, + } - if persisted_id: - self.id = persisted_id - self.persisted = True - except Exception as e: - if self.fail_on_persist_error: - raise e - logger.error(f"Failed to persist message creation: {str(e)}") + return _dict async def update( self, @@ -79,51 +90,63 @@ async def update( Update a message already sent to the UI. """ trace_event("update_message") + if self.streaming: self.streaming = False - msg_dict = self.to_dict() - asyncio.create_task(self._persist_update(msg_dict)) - await context.emitter.update_message(msg_dict) - return True + step_dict = self.to_dict() + + data_layer = get_data_layer() + if data_layer: + try: + asyncio.create_task(data_layer.update_step(step_dict)) + except Exception as e: + if self.fail_on_persist_error: + raise e + logger.error(f"Failed to persist message update: {str(e)}") - async def _persist_update(self, message: MessageDict): - if not chainlit_client or not self.id: - return + await context.emitter.update_step(step_dict) - try: - await chainlit_client.update_message(self.id, message) - except Exception as e: - if self.fail_on_persist_error: - raise e - logger.error(f"Failed to persist message update: {str(e)}") + return True async def remove(self): """ Remove a message already sent to the UI. - This will not automatically remove potential nested messages and could lead to undesirable side effects in the UI. """ trace_event("remove_message") - if chainlit_client and self.id: - await chainlit_client.delete_message(self.id) + step_dict = self.to_dict() + data_layer = get_data_layer() + if data_layer: + try: + asyncio.create_task(data_layer.delete_step(step_dict["id"])) + except Exception as e: + if self.fail_on_persist_error: + raise e + logger.error(f"Failed to persist message deletion: {str(e)}") - await context.emitter.delete_message(self.to_dict()) + await context.emitter.delete_step(step_dict) return True - async def _persist_remove(self): - if not chainlit_client or not self.id: - return + async def _create(self): + step_dict = self.to_dict() + data_layer = get_data_layer() + if data_layer and not self.persisted: + try: + asyncio.create_task(data_layer.create_step(step_dict)) + self.persisted = True + except Exception as e: + if self.fail_on_persist_error: + raise e + logger.error(f"Failed to persist message creation: {str(e)}") - try: - await chainlit_client.delete_message(self.id) - except Exception as e: - if self.fail_on_persist_error: - raise e - logger.error(f"Failed to persist message deletion: {str(e)}") + return step_dict async def send(self): + if not self.created_at: + self.created_at = datetime.utcnow().isoformat() + if self.content is None: self.content = "" @@ -133,9 +156,9 @@ async def send(self): if self.streaming: self.streaming = False - msg_dict = await self._create() + step_dict = await self._create() - await context.emitter.send_message(msg_dict) + await context.emitter.send_step(step_dict) return self.id @@ -147,8 +170,8 @@ async def stream_token(self, token: str, is_sequence=False): if not self.streaming: self.streaming = True - msg_dict = self.to_dict() - await context.emitter.stream_start(msg_dict) + step_dict = self.to_dict() + await context.emitter.stream_start(step_dict) if is_sequence: self.content = token @@ -164,33 +187,27 @@ async def stream_token(self, token: str, is_sequence=False): class Message(MessageBase): """ Send a message to the UI - If a project ID is configured, the message will be persisted in the cloud. Args: content (Union[str, Dict]): The content of the message. author (str, optional): The author of the message, this will be used in the UI. Defaults to the chatbot name (see config). - prompt (Prompt, optional): The prompt used to generate the message. If provided, enables the prompt playground for this message. language (str, optional): Language of the code is the content is code. See https://react-code-blocks-rajinwonderland.vercel.app/?path=/story/codeblock--supported-languages for a list of supported languages. - parent_id (str, optional): If provided, the message will be nested inside the parent in the UI. - indent (int, optional): If positive, the message will be nested in the UI. (deprecated, use parent_id instead) actions (List[Action], optional): A list of actions to send with the message. elements (List[ElementBased], optional): A list of elements to send with the message. - disable_human_feedback (bool, optional): Hide the feedback buttons for this specific message + disable_feedback (bool, optional): Hide the feedback buttons for this specific message """ def __init__( self, content: Union[str, Dict], author: str = config.ui.name, - prompt: Optional[Prompt] = None, language: Optional[str] = None, - parent_id: Optional[str] = None, - indent: int = 0, actions: Optional[List[Action]] = None, elements: Optional[List[ElementBased]] = None, - disable_human_feedback: Optional[bool] = False, - author_is_user: Optional[bool] = False, - id: Optional[uuid.UUID] = None, + disable_feedback: bool = False, + type: MessageStepType = "assistant_message", + id: Optional[str] = None, + created_at: Union[str, None] = None, ): self.language = language @@ -200,87 +217,45 @@ def __init__( self.language = "json" except TypeError: self.content = str(content) - self.language = "python" + self.language = "text" elif isinstance(content, str): self.content = content else: self.content = str(content) - self.language = "python" + self.language = "text" if id: self.id = str(id) + if created_at: + self.created_at = created_at + self.author = author - self.author_is_user = author_is_user - self.prompt = prompt - self.parent_id = parent_id - self.indent = indent + self.type = type self.actions = actions if actions is not None else [] self.elements = elements if elements is not None else [] - self.disable_human_feedback = disable_human_feedback + self.disable_feedback = disable_feedback super().__post_init__() - @classmethod - def from_dict(self, _dict: MessageDict): - message = Message( - content=_dict["content"], - author=_dict.get("author", config.ui.name), - prompt=_dict.get("prompt"), - language=_dict.get("language"), - parent_id=_dict.get("parentId"), - indent=_dict.get("indent") or 0, - disable_human_feedback=_dict.get("disableHumanFeedback"), - author_is_user=_dict.get("authorIsUser"), - ) - - if _id := _dict.get("id"): - message.id = _id - if created_at := _dict.get("createdAt"): - message.created_at = created_at - - return message - - def to_dict(self): - _dict = { - "createdAt": self.created_at, - "content": self.content, - "author": self.author, - "authorIsUser": self.author_is_user, - "language": self.language, - "parentId": self.parent_id, - "indent": self.indent, - "streaming": self.streaming, - "disableHumanFeedback": self.disable_human_feedback, - } - - if self.prompt: - _dict["prompt"] = self.prompt.to_dict() - - if self.id: - _dict["id"] = self.id - - return _dict - async def send(self) -> str: """ Send the message to the UI and persist it in the cloud if a project ID is configured. Return the ID of the message. """ trace_event("send_message") - id = await super().send() + await super().send() - if not self.parent_id: - context.session.root_message = self + context.session.root_message = self # Create tasks for all actions and elements - tasks = [action.send(for_id=id) for action in self.actions] - tasks.extend(element.send(for_id=id) for element in self.elements) + tasks = [action.send(for_id=self.id) for action in self.actions] + tasks.extend(element.send(for_id=self.id) for element in self.elements) # Run all tasks concurrently await asyncio.gather(*tasks) - return id + return self.id async def update(self): """ @@ -290,15 +265,16 @@ async def update(self): trace_event("send_message") await super().update() - actions_to_update = [action for action in self.actions if action.forId is None] + # Update tasks for all actions and elements + tasks = [ + action.send(for_id=self.id) + for action in self.actions + if action.forId is None + ] + tasks.extend(element.send(for_id=self.id) for element in self.elements) - elements_to_update = [el for el in self.elements if self.id not in el.for_ids] - - for action in actions_to_update: - await action.send(for_id=self.id) - - for element in elements_to_update: - await element.send(for_id=self.id) + # Run all tasks concurrently + await asyncio.gather(*tasks) return True @@ -323,29 +299,16 @@ def __init__( self, content: str, author: str = config.ui.name, - parent_id: Optional[str] = None, - indent: int = 0, fail_on_persist_error: bool = False, ): self.content = content self.author = author - self.parent_id = parent_id - self.indent = indent + self.type = "system_message" + self.is_error = True self.fail_on_persist_error = fail_on_persist_error super().__post_init__() - def to_dict(self): - return { - "id": self.id, - "createdAt": self.created_at, - "content": self.content, - "author": self.author, - "parentId": self.parent_id, - "indent": self.indent, - "isError": True, - } - async def send(self): """ Send the error message to the UI and persist it in the cloud if a project ID is configured. @@ -371,7 +334,7 @@ class AskUserMessage(AskMessageBase): Args: content (str): The content of the prompt. author (str, optional): The author of the message, this will be used in the UI. Defaults to the chatbot name (see config). - disable_human_feedback (bool, optional): Hide the feedback buttons for this specific message + disable_feedback (bool, optional): Hide the feedback buttons for this specific message timeout (int, optional): The number of seconds to wait for an answer before raising a TimeoutError. raise_on_timeout (bool, optional): Whether to raise a socketio TimeoutError if the user does not answer in time. """ @@ -380,45 +343,46 @@ def __init__( self, content: str, author: str = config.ui.name, - disable_human_feedback: bool = False, + type: MessageStepType = "assistant_message", + disable_feedback: bool = False, timeout: int = 60, raise_on_timeout: bool = False, ): self.content = content self.author = author self.timeout = timeout - self.disable_human_feedback = disable_human_feedback + self.type = type + self.disable_feedback = disable_feedback self.raise_on_timeout = raise_on_timeout super().__post_init__() - def to_dict(self): - return { - "id": self.id, - "createdAt": self.created_at, - "content": self.content, - "author": self.author, - "waitForAnswer": True, - "disableHumanFeedback": self.disable_human_feedback, - } - - async def send(self) -> Union[AskResponse, None]: + async def send(self) -> Union[StepDict, None]: """ Sends the question to ask to the UI and waits for the reply. """ trace_event("send_ask_user") + if not self.created_at: + self.created_at = datetime.utcnow().isoformat() + + if config.code.author_rename: + self.author = await config.code.author_rename(self.author) if self.streaming: self.streaming = False - if config.code.author_rename: - self.author = await config.code.author_rename(self.author) + self.wait_for_answer = True - msg_dict = await self._create() + step_dict = await self._create() spec = AskSpec(type="text", timeout=self.timeout) - res = await context.emitter.send_ask_user(msg_dict, spec, self.raise_on_timeout) + res = cast( + Union[None, StepDict], + await context.emitter.send_ask_user(step_dict, spec, self.raise_on_timeout), + ) + + self.wait_for_answer = False return res @@ -435,7 +399,7 @@ class AskFileMessage(AskMessageBase): max_size_mb (int, optional): Maximum size per file in MB. Maximum value is 100. max_files (int, optional): Maximum number of files to upload. Maximum value is 10. author (str, optional): The author of the message, this will be used in the UI. Defaults to the chatbot name (see config). - disable_human_feedback (bool, optional): Hide the feedback buttons for this specific message + disable_feedback (bool, optional): Hide the feedback buttons for this specific message timeout (int, optional): The number of seconds to wait for an answer before raising a TimeoutError. raise_on_timeout (bool, optional): Whether to raise a socketio TimeoutError if the user does not answer in time. """ @@ -447,7 +411,8 @@ def __init__( max_size_mb=2, max_files=1, author=config.ui.name, - disable_human_feedback: bool = False, + type: MessageStepType = "assistant_message", + disable_feedback: bool = False, timeout=90, raise_on_timeout=False, ): @@ -455,36 +420,32 @@ def __init__( self.max_size_mb = max_size_mb self.max_files = max_files self.accept = accept + self.type = type self.author = author self.timeout = timeout self.raise_on_timeout = raise_on_timeout - self.disable_human_feedback = disable_human_feedback + self.disable_feedback = disable_feedback super().__post_init__() - def to_dict(self): - return { - "id": self.id, - "createdAt": self.created_at, - "content": self.content, - "author": self.author, - "waitForAnswer": True, - "disableHumanFeedback": self.disable_human_feedback, - } - async def send(self) -> Union[List[AskFileResponse], None]: """ Sends the message to request a file from the user to the UI and waits for the reply. """ trace_event("send_ask_file") + if not self.created_at: + self.created_at = datetime.utcnow().isoformat() + if self.streaming: self.streaming = False if config.code.author_rename: self.author = await config.code.author_rename(self.author) - msg_dict = await self._create() + self.wait_for_answer = True + + step_dict = await self._create() spec = AskFileSpec( type="file", @@ -494,10 +455,24 @@ async def send(self) -> Union[List[AskFileResponse], None]: timeout=self.timeout, ) - res = await context.emitter.send_ask_user(msg_dict, spec, self.raise_on_timeout) + res = cast( + Union[None, List[FileDict]], + await context.emitter.send_ask_user(step_dict, spec, self.raise_on_timeout), + ) + + self.wait_for_answer = False if res: - return [AskFileResponse(**r) for r in res] + return [ + AskFileResponse( + id=r["id"], + name=r["name"], + path=str(r["path"]), + size=r["size"], + type=r["type"], + ) + for r in res + ] else: return None @@ -513,55 +488,49 @@ def __init__( content: str, actions: List[Action], author=config.ui.name, - disable_human_feedback=False, + disable_feedback=False, timeout=90, raise_on_timeout=False, ): self.content = content self.actions = actions self.author = author - self.disable_human_feedback = disable_human_feedback + self.disable_feedback = disable_feedback self.timeout = timeout self.raise_on_timeout = raise_on_timeout super().__post_init__() - def to_dict(self): - return { - "id": self.id, - "createdAt": self.created_at, - "content": self.content, - "author": self.author, - "waitForAnswer": True, - "disableHumanFeedback": self.disable_human_feedback, - "timeout": self.timeout, - "raiseOnTimeout": self.raise_on_timeout, - } - async def send(self) -> Union[AskActionResponse, None]: """ Sends the question to ask to the UI and waits for the reply """ trace_event("send_ask_action") + if not self.created_at: + self.created_at = datetime.utcnow().isoformat() + if self.streaming: self.streaming = False if config.code.author_rename: self.author = await config.code.author_rename(self.author) - msg_dict = await self._create() + self.wait_for_answer = True + + step_dict = await self._create() + action_keys = [] for action in self.actions: action_keys.append(action.id) - await action.send(for_id=str(msg_dict["id"])) + await action.send(for_id=str(step_dict["id"])) spec = AskActionSpec(type="action", timeout=self.timeout, keys=action_keys) res = cast( Union[AskActionResponse, None], - await context.emitter.send_ask_user(msg_dict, spec, self.raise_on_timeout), + await context.emitter.send_ask_user(step_dict, spec, self.raise_on_timeout), ) for action in self.actions: @@ -570,6 +539,9 @@ async def send(self) -> Union[AskActionResponse, None]: self.content = "Timed out: no action was taken" else: self.content = f'**Selected action:** {res["label"]}' + + self.wait_for_answer = False + await self.update() return res diff --git a/backend/chainlit/oauth_providers.py b/backend/chainlit/oauth_providers.py index 9fea48c7df..6c9400a6b5 100644 --- a/backend/chainlit/oauth_providers.py +++ b/backend/chainlit/oauth_providers.py @@ -4,7 +4,7 @@ from typing import Dict, List, Optional, Tuple import httpx -from chainlit.client.base import AppUser +from chainlit.user import User from fastapi import HTTPException @@ -22,7 +22,7 @@ def is_configured(self): async def get_token(self, code: str, url: str) -> str: raise NotImplementedError() - async def get_user_info(self, token: str) -> Tuple[Dict[str, str], AppUser]: + async def get_user_info(self, token: str) -> Tuple[Dict[str, str], User]: raise NotImplementedError() @@ -65,7 +65,7 @@ async def get_user_info(self, token: str): headers={"Authorization": f"token {token}"}, ) user_response.raise_for_status() - user = user_response.json() + github_user = user_response.json() emails_response = await client.get( "https://api.github.com/user/emails", @@ -74,14 +74,12 @@ async def get_user_info(self, token: str): emails_response.raise_for_status() emails = emails_response.json() - user.update({"emails": emails}) - - app_user = AppUser( - username=user["login"], - image=user["avatar_url"], - provider="github", + github_user.update({"emails": emails}) + user = User( + identifier=github_user["login"], + metadata={"image": github_user["avatar_url"], "provider": "github"}, ) - return (user, app_user) + return (github_user, user) class GoogleOAuthProvider(OAuthProvider): @@ -129,12 +127,12 @@ async def get_user_info(self, token: str): headers={"Authorization": f"Bearer {token}"}, ) response.raise_for_status() - user = response.json() - - app_user = AppUser( - username=user["name"], image=user["picture"], provider="google" + google_user = response.json() + user = User( + identifier=google_user["email"], + metadata={"image": google_user["picture"], "provider": "google"}, ) - return (user, app_user) + return (google_user, user) class AzureADOAuthProvider(OAuthProvider): @@ -196,7 +194,7 @@ async def get_user_info(self, token: str): ) response.raise_for_status() - user = response.json() + azure_user = response.json() try: photo_response = await client.get( @@ -205,19 +203,18 @@ async def get_user_info(self, token: str): ) photo_data = await photo_response.aread() base64_image = base64.b64encode(photo_data) - user[ + azure_user[ "image" ] = f"data:{photo_response.headers['Content-Type']};base64,{base64_image.decode('utf-8')}" except Exception as e: # Ignore errors getting the photo pass - app_user = AppUser( - username=user["userPrincipalName"], - image=user.get("image", ""), - provider="azure-ad", + user = User( + identifier=azure_user["userPrincipalName"], + metadata={"image": azure_user.get("image"), "provider": "azure-ad"}, ) - return (user, app_user) + return (azure_user, user) class OktaOAuthProvider(OAuthProvider): @@ -284,10 +281,13 @@ async def get_user_info(self, token: str): headers={"Authorization": f"Bearer {token}"}, ) response.raise_for_status() - user = response.json() + okta_user = response.json() - app_user = AppUser(username=user.get("email"), image="", provider="okta") - return (user, app_user) + user = User( + identifier=okta_user.get("email"), + metadata={"image": "", "provider": "okta"}, + ) + return (okta_user, user) class Auth0OAuthProvider(OAuthProvider): @@ -342,13 +342,15 @@ async def get_user_info(self, token: str): headers={"Authorization": f"Bearer {token}"}, ) response.raise_for_status() - user = response.json() - app_user = AppUser( - username=user.get("email"), - image=user.get("picture", ""), - provider="auth0", + auth0_user = response.json() + user = User( + identifier=auth0_user.get("email"), + metadata={ + "image": auth0_user.get("picture", ""), + "provider": "auth0", + }, ) - return (user, app_user) + return (auth0_user, user) class DescopeOAuthProvider(OAuthProvider): @@ -398,10 +400,13 @@ async def get_user_info(self, token: str): f"{self.domain}/userinfo", headers={"Authorization": f"Bearer {token}"} ) response.raise_for_status() # This will raise an exception for 4xx/5xx responses - user = response.json() + descope_user = response.json() - app_user = AppUser(username=user.get("email"), image="", provider="descope") - return (user, app_user) + user = User( + identifier=descope_user.get("email"), + metadata={"image": "", "provider": "descope"}, + ) + return (descope_user, user) providers = [ diff --git a/backend/chainlit/playground/provider.py b/backend/chainlit/playground/provider.py index 234af441ae..82b71b5486 100644 --- a/backend/chainlit/playground/provider.py +++ b/backend/chainlit/playground/provider.py @@ -1,10 +1,10 @@ import os -from typing import Any, Dict, List, Union +from typing import Any, Dict, List, Optional, Union from chainlit.config import config -from chainlit.prompt import Prompt, PromptMessage from chainlit.telemetry import trace_event -from chainlit.types import CompletionRequest +from chainlit.types import GenerationRequest +from chainlit_client import BaseGeneration, ChatGeneration, GenerationMessage from fastapi import HTTPException from pydantic.dataclasses import dataclass @@ -20,67 +20,81 @@ class BaseProvider: is_chat: bool # Format the message based on the template provided - def format_message(self, message: PromptMessage, prompt: Prompt): + def format_message(self, message: GenerationMessage, inputs: Optional[Dict]): if message.template: - message.formatted = self._format_template(message.template, prompt) + message.formatted = self._format_template( + message.template, inputs, message.template_format + ) return message # Convert the message to string format - def message_to_string(self, message: PromptMessage): + def message_to_string(self, message: GenerationMessage): return message.formatted # Concatenate multiple messages with a joiner - def concatenate_messages(self, messages: List[PromptMessage], joiner="\n\n"): + def concatenate_messages(self, messages: List[GenerationMessage], joiner="\n\n"): return joiner.join([self.message_to_string(m) for m in messages]) # Format the template based on the prompt inputs - def _format_template(self, template: str, prompt: Prompt): - if prompt.template_format == "f-string": - return template.format(**(prompt.inputs or {})) - raise HTTPException( - status_code=422, detail=f"Unsupported format {prompt.template_format}" - ) + def _format_template( + self, template: str, inputs: Optional[Dict], format: str = "f-string" + ): + if format == "f-string": + return template.format(**(inputs or {})) + raise HTTPException(status_code=422, detail=f"Unsupported format {format}") # Create a prompt based on the request - def create_prompt(self, request: CompletionRequest): - prompt = request.prompt - if prompt.messages: - messages = [self.format_message(m, prompt=prompt) for m in prompt.messages] + def create_generation(self, request: GenerationRequest): + if request.chatGeneration and request.chatGeneration.messages: + messages = [ + self.format_message(m, request.chatGeneration.inputs) + for m in request.chatGeneration.messages + ] else: messages = None if self.is_chat: if messages: return messages - elif prompt.template or prompt.formatted: + elif request.completionGeneration and ( + request.completionGeneration.template + or request.completionGeneration.formatted + ): return [ self.format_message( - PromptMessage( - template=prompt.template, - formatted=prompt.formatted, + GenerationMessage( + template=request.completionGeneration.template, + formatted=request.completionGeneration.formatted, role="user", ), - prompt=prompt, + inputs=request.completionGeneration.inputs, ) ] else: - raise HTTPException(status_code=422, detail="Could not create prompt") + raise HTTPException( + status_code=422, detail="Could not create generation" + ) else: - if prompt.template: - return self._format_template(prompt.template, prompt=prompt) + if request.completionGeneration: + if request.completionGeneration.template: + return self._format_template( + request.completionGeneration.template, + request.completionGeneration.inputs, + request.completionGeneration.template_format, + ) + elif request.completionGeneration.formatted: + return request.completionGeneration.formatted elif messages: return self.concatenate_messages(messages) - elif prompt.formatted: - return prompt.formatted else: raise HTTPException(status_code=422, detail="Could not create prompt") # Create a completion event - async def create_completion(self, request: CompletionRequest): + async def create_completion(self, request: GenerationRequest): trace_event("completion") # Get the environment variable based on the request - def get_var(self, request: CompletionRequest, var: str) -> Union[str, None]: + def get_var(self, request: GenerationRequest, var: str) -> Union[str, None]: user_env = config.project.user_env or [] if var in user_env: @@ -101,7 +115,7 @@ def is_configured(self): return True # Validate the environment variables in the request - def validate_env(self, request: CompletionRequest): + def validate_env(self, request: GenerationRequest): return {k: self.get_var(request, v) for k, v in self.env_vars.items()} # Check if the required settings are present diff --git a/backend/chainlit/playground/providers/anthropic.py b/backend/chainlit/playground/providers/anthropic.py index 5dbcb2c298..c9b8a61c99 100644 --- a/backend/chainlit/playground/providers/anthropic.py +++ b/backend/chainlit/playground/providers/anthropic.py @@ -1,12 +1,12 @@ from chainlit.input_widget import Select, Slider, Tags from chainlit.playground.provider import BaseProvider -from chainlit.prompt import PromptMessage +from chainlit_client import GenerationMessage from fastapi import HTTPException from fastapi.responses import StreamingResponse class AnthropicProvider(BaseProvider): - def message_to_string(self, message: PromptMessage) -> str: + def message_to_string(self, message: GenerationMessage) -> str: import anthropic if message.role == "user": @@ -29,10 +29,10 @@ async def create_completion(self, request): env_settings = self.validate_env(request=request) - llm_settings = request.prompt.settings + llm_settings = request.generation.settings self.require_settings(llm_settings) - prompt = self.concatenate_messages(self.create_prompt(request), joiner="") + prompt = self.concatenate_messages(self.create_generation(request), joiner="") if not prompt.endswith(anthropic.AI_PROMPT): prompt += anthropic.AI_PROMPT diff --git a/backend/chainlit/playground/providers/huggingface.py b/backend/chainlit/playground/providers/huggingface.py index 62beef8102..d01f2b31dc 100644 --- a/backend/chainlit/playground/providers/huggingface.py +++ b/backend/chainlit/playground/providers/huggingface.py @@ -18,7 +18,7 @@ async def create_completion(self, request): from huggingface_hub.inference_api import InferenceApi env_settings = self.validate_env(request=request) - llm_settings = request.prompt.settings + llm_settings = request.generation.settings self.require_settings(llm_settings) client = InferenceApi( @@ -27,7 +27,7 @@ async def create_completion(self, request): task=self.task, ) - prompt = self.create_prompt(request) + prompt = self.create_generation(request) response = await make_async(client)(inputs=prompt, params=llm_settings) diff --git a/backend/chainlit/playground/providers/langchain.py b/backend/chainlit/playground/providers/langchain.py index 3eda1c84c7..522414b914 100644 --- a/backend/chainlit/playground/providers/langchain.py +++ b/backend/chainlit/playground/providers/langchain.py @@ -1,17 +1,15 @@ from typing import Union -from fastapi.responses import StreamingResponse - from chainlit.playground.provider import BaseProvider -from chainlit.prompt import PromptMessage from chainlit.sync import make_async - -from chainlit import input_widget +from chainlit_client import GenerationMessage +from fastapi.responses import StreamingResponse class LangchainGenericProvider(BaseProvider): from langchain.chat_models.base import BaseChatModel from langchain.llms.base import LLM + from langchain.schema import BaseMessage llm: Union[LLM, BaseChatModel] @@ -31,7 +29,7 @@ def __init__( ) self.llm = llm - def prompt_message_to_langchain_message(self, message: PromptMessage): + def prompt_message_to_langchain_message(self, message: GenerationMessage): from langchain.schema.messages import ( AIMessage, FunctionMessage, @@ -46,7 +44,7 @@ def prompt_message_to_langchain_message(self, message: PromptMessage): return AIMessage(content=content) elif message.role == "system": return SystemMessage(content=content) - elif message.role == "function": + elif message.role == "tool": return FunctionMessage( content=content, name=message.name if message.name else "function" ) @@ -57,15 +55,15 @@ def format_message(self, message, prompt): message = super().format_message(message, prompt) return self.prompt_message_to_langchain_message(message) - def message_to_string(self, message: PromptMessage) -> str: - return message.to_string() + def message_to_string(self, message: BaseMessage) -> str: # type: ignore[override] + return message.content async def create_completion(self, request): from langchain.schema.messages import BaseMessageChunk await super().create_completion(request) - messages = self.create_prompt(request) + messages = self.create_generation(request) stream = make_async(self.llm.stream) diff --git a/backend/chainlit/playground/providers/openai.py b/backend/chainlit/playground/providers/openai.py index 443d220221..2b11e519f6 100644 --- a/backend/chainlit/playground/providers/openai.py +++ b/backend/chainlit/playground/providers/openai.py @@ -127,11 +127,11 @@ async def create_completion(self, request): client = AsyncClient(api_key=env_settings["api_key"]) - llm_settings = request.prompt.settings + llm_settings = request.generation.settings self.require_settings(llm_settings) - messages = self.create_prompt(request) + messages = self.create_generation(request) if "stop" in llm_settings: stop = llm_settings["stop"] @@ -142,8 +142,8 @@ async def create_completion(self, request): llm_settings["stop"] = stop - if request.prompt.functions: - llm_settings["functions"] = request.prompt.functions + if request.generation.functions: + llm_settings["functions"] = request.generation.functions llm_settings["stream"] = False else: llm_settings["stream"] = True @@ -188,11 +188,11 @@ async def create_completion(self, request): client = AsyncClient(api_key=env_settings["api_key"]) - llm_settings = request.prompt.settings + llm_settings = request.generation.settings self.require_settings(llm_settings) - prompt = self.create_prompt(request) + prompt = self.create_generation(request) if "stop" in llm_settings: stop = llm_settings["stop"] @@ -240,11 +240,11 @@ async def create_completion(self, request): azure_ad_token_provider=self.get_var(request, "AZURE_AD_TOKEN_PROVIDER"), azure_deployment=self.get_var(request, "AZURE_DEPLOYMENT"), ) - llm_settings = request.prompt.settings + llm_settings = request.generation.settings self.require_settings(llm_settings) - prompt = self.create_prompt(request) + prompt = self.create_generation(request) if "stop" in llm_settings: stop = llm_settings["stop"] @@ -294,11 +294,11 @@ async def create_completion(self, request): azure_deployment=self.get_var(request, "AZURE_DEPLOYMENT"), ) - llm_settings = request.prompt.settings + llm_settings = request.generation.settings self.require_settings(llm_settings) - messages = self.create_prompt(request) + messages = self.create_generation(request) if "stop" in llm_settings: stop = llm_settings["stop"] @@ -311,8 +311,8 @@ async def create_completion(self, request): llm_settings["model"] = env_settings["deployment_name"] - if request.prompt.functions: - llm_settings["functions"] = request.prompt.functions + if request.generation.functions: + llm_settings["functions"] = request.generation.functions llm_settings["stream"] = False else: llm_settings["stream"] = True @@ -362,7 +362,13 @@ async def create_event_stream(): Select( id="model", label="Model", - values=["gpt-3.5-turbo", "gpt-3.5-turbo-16k", "gpt-4", "gpt-4-32k", "gpt-4-1106-preview"], + values=[ + "gpt-3.5-turbo", + "gpt-3.5-turbo-16k", + "gpt-4", + "gpt-4-32k", + "gpt-4-1106-preview", + ], initial_value="gpt-3.5-turbo", ), *openai_common_inputs, diff --git a/backend/chainlit/prompt.py b/backend/chainlit/prompt.py deleted file mode 100644 index 4aab7ff0c4..0000000000 --- a/backend/chainlit/prompt.py +++ /dev/null @@ -1,40 +0,0 @@ -from typing import Any, Dict, List, Literal, Optional - -from dataclasses_json import DataClassJsonMixin -from pydantic.dataclasses import dataclass - - -@dataclass -class BaseTemplate(DataClassJsonMixin): - template: Optional[str] = None - formatted: Optional[str] = None - template_format: Optional[str] = "f-string" - - -@dataclass -class PromptMessage(BaseTemplate): - # This is used for Langchain's MessagesPlaceholder - placeholder_size: Optional[int] = None - # This is used for OpenAI's function message - name: Optional[str] = None - role: Optional[Literal["system", "assistant", "user", "function"]] = None - - def to_openai(self): - msg_dict = {"role": self.role, "content": self.formatted} - if self.role == "function": - msg_dict["name"] = self.name or "" - return msg_dict - - def to_string(self): - return f"{self.role}: {self.formatted}" - - -@dataclass -class Prompt(BaseTemplate): - provider: Optional[str] = None - id: Optional[str] = None - inputs: Optional[Dict[str, str]] = None - completion: Optional[str] = None - settings: Optional[Dict[str, Any]] = None - messages: Optional[List[PromptMessage]] = None - functions: Optional[List[Dict]] = None diff --git a/backend/chainlit/server.py b/backend/chainlit/server.py index 599ae1bb97..b2814d34f0 100644 --- a/backend/chainlit/server.py +++ b/backend/chainlit/server.py @@ -1,6 +1,7 @@ import glob import json import mimetypes +import shutil import urllib.parse from typing import Optional, Union @@ -17,30 +18,31 @@ from pathlib import Path from chainlit.auth import create_jwt, get_configuration, get_current_user -from chainlit.client.cloud import AppUser, PersistedAppUser from chainlit.config import ( APP_ROOT, BACKEND_ROOT, DEFAULT_HOST, + FILES_DIRECTORY, PACKAGE_ROOT, config, load_module, reload_config, ) -from chainlit.data import chainlit_client -from chainlit.data.acl import is_conversation_author +from chainlit.data import get_data_layer +from chainlit.data.acl import is_thread_author from chainlit.logger import logger from chainlit.markdown import get_markdown_str from chainlit.playground.config import get_llm_providers from chainlit.telemetry import trace_event from chainlit.types import ( - CompletionRequest, - DeleteConversationRequest, - GetConversationsRequest, + DeleteThreadRequest, + GenerationRequest, + GetThreadsRequest, Theme, UpdateFeedbackRequest, ) -from fastapi import Depends, FastAPI, HTTPException, Query, Request, status +from chainlit.user import PersistedUser, User +from fastapi import Depends, FastAPI, HTTPException, Query, Request, UploadFile, status from fastapi.responses import FileResponse, HTMLResponse, JSONResponse, RedirectResponse from fastapi.security import OAuth2PasswordRequestForm from fastapi.staticfiles import StaticFiles @@ -117,6 +119,9 @@ async def watch_files_for_changes(): except asyncio.exceptions.CancelledError: pass + if FILES_DIRECTORY.is_dir(): + shutil.rmtree(FILES_DIRECTORY) + # Force exit the process to avoid potential AnyIO threads still running os._exit(0) @@ -134,6 +139,7 @@ def get_build_dir(): build_dir = get_build_dir() + app = FastAPI(lifespan=lifespan) app.mount("/public", StaticFiles(directory="public", check_dir=False), name="public") @@ -156,14 +162,10 @@ def get_build_dir(): ) -# Define max HTTP data size to 100 MB -max_message_size = 100 * 1024 * 1024 - socket = SocketManager( app, cors_allowed_origins=[], async_mode="asgi", - max_http_buffer_size=max_message_size, ) @@ -244,18 +246,18 @@ async def login(form_data: OAuth2PasswordRequestForm = Depends()): status_code=status.HTTP_400_BAD_REQUEST, detail="No auth_callback defined" ) - app_user = await config.code.password_auth_callback( + user = await config.code.password_auth_callback( form_data.username, form_data.password ) - if not app_user: + if not user: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="credentialssignin", ) - access_token = create_jwt(app_user) - if chainlit_client: - await chainlit_client.create_app_user(app_user=app_user) + access_token = create_jwt(user) + if data_layer := get_data_layer(): + await data_layer.create_user(user) return { "access_token": access_token, "token_type": "bearer", @@ -270,17 +272,17 @@ async def header_auth(request: Request): detail="No header_auth_callback defined", ) - app_user = await config.code.header_auth_callback(request.headers) + user = await config.code.header_auth_callback(request.headers) - if not app_user: + if not user: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Unauthorized", ) - access_token = create_jwt(app_user) - if chainlit_client: - await chainlit_client.create_app_user(app_user=app_user) + access_token = create_jwt(user) + if data_layer := get_data_layer(): + await data_layer.create_user(user) return { "access_token": access_token, "token_type": "bearer", @@ -369,21 +371,22 @@ async def oauth_callback( url = get_user_facing_url(request.url) token = await provider.get_token(code, url) - (raw_user_data, default_app_user) = await provider.get_user_info(token) + (raw_user_data, default_user) = await provider.get_user_info(token) - app_user = await config.code.oauth_callback( - provider_id, token, raw_user_data, default_app_user + user = await config.code.oauth_callback( + provider_id, token, raw_user_data, default_user ) - if not app_user: + if not user: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Unauthorized", ) - access_token = create_jwt(app_user) - if chainlit_client: - await chainlit_client.create_app_user(app_user=app_user) + access_token = create_jwt(user) + + if data_layer := get_data_layer(): + await data_layer.create_user(user) params = urllib.parse.urlencode( { @@ -399,23 +402,21 @@ async def oauth_callback( return response -@app.post("/completion") -async def completion( - request: CompletionRequest, - current_user: Annotated[ - Union[AppUser, PersistedAppUser], Depends(get_current_user) - ], +@app.post("/generation") +async def generation( + request: GenerationRequest, + current_user: Annotated[Union[User, PersistedUser], Depends(get_current_user)], ): """Handle a completion request from the prompt playground.""" providers = get_llm_providers() try: - provider = [p for p in providers if p.id == request.prompt.provider][0] + provider = [p for p in providers if p.id == request.generation.provider][0] except IndexError: raise HTTPException( status_code=404, - detail=f"LLM provider '{request.prompt.provider}' not found", + detail=f"LLM provider '{request.generation.provider}' not found", ) trace_event("pp_create_completion") @@ -426,7 +427,7 @@ async def completion( @app.get("/project/llm-providers") async def get_providers( - current_user: Annotated[Union[AppUser, PersistedAppUser], Depends(get_current_user)] + current_user: Annotated[Union[User, PersistedUser], Depends(get_current_user)] ): """List the providers.""" trace_event("pp_get_llm_providers") @@ -437,7 +438,7 @@ async def get_providers( @app.get("/project/settings") async def project_settings( - current_user: Annotated[Union[AppUser, PersistedAppUser], Depends(get_current_user)] + current_user: Annotated[Union[User, PersistedUser], Depends(get_current_user)] ): """Return project settings. This is called by the UI before the establishing the websocket connection.""" profiles = [] @@ -450,126 +451,181 @@ async def project_settings( "ui": config.ui.to_dict(), "features": config.features.to_dict(), "userEnv": config.project.user_env, - "dataPersistence": config.data_persistence, - "conversationResumable": bool(config.code.on_chat_resume), + "dataPersistence": get_data_layer() is not None, + "threadResumable": bool(config.code.on_chat_resume), "markdown": get_markdown_str(config.root), "chatProfiles": profiles, } ) -@app.put("/message/feedback") +@app.put("/feedback") async def update_feedback( request: Request, update: UpdateFeedbackRequest, - current_user: Annotated[ - Union[AppUser, PersistedAppUser], Depends(get_current_user) - ], + current_user: Annotated[Union[User, PersistedUser], Depends(get_current_user)], ): """Update the human feedback for a particular message.""" - - # TODO: check that message belong to a user's conversation - - if not chainlit_client: + data_layer = get_data_layer() + if not data_layer: raise HTTPException(status_code=400, detail="Data persistence is not enabled") try: - await chainlit_client.set_human_feedback( - message_id=update.messageId, - feedback=update.feedback, - feedbackComment=update.feedbackComment, - ) + feedback_id = await data_layer.upsert_feedback(feedback=update.feedback) except Exception as e: raise HTTPException(detail=str(e), status_code=401) - return JSONResponse(content={"success": True}) + return JSONResponse(content={"success": True, "feedbackId": feedback_id}) -@app.post("/project/conversations") -async def get_user_conversations( +@app.post("/project/threads") +async def get_user_threads( request: Request, - payload: GetConversationsRequest, - current_user: Annotated[ - Union[AppUser, PersistedAppUser], Depends(get_current_user) - ], + payload: GetThreadsRequest, + current_user: Annotated[Union[User, PersistedUser], Depends(get_current_user)], ): - """Get the conversations page by page.""" - # Only show the current user conversations + """Get the threads page by page.""" + # Only show the current user threads - if not chainlit_client: + data_layer = get_data_layer() + + if not data_layer: raise HTTPException(status_code=400, detail="Data persistence is not enabled") - payload.filter.username = current_user.username - res = await chainlit_client.get_conversations(payload.pagination, payload.filter) + payload.filter.userIdentifier = current_user.identifier + + res = await data_layer.list_threads(payload.pagination, payload.filter) return JSONResponse(content=res.to_dict()) -@app.get("/project/conversation/{conversation_id}") -async def get_conversation( +@app.get("/project/thread/{thread_id}") +async def get_thread( request: Request, - conversation_id: str, - current_user: Annotated[ - Union[AppUser, PersistedAppUser], Depends(get_current_user) - ], + thread_id: str, + current_user: Annotated[Union[User, PersistedUser], Depends(get_current_user)], ): - """Get a specific conversation.""" + """Get a specific thread.""" + data_layer = get_data_layer() - if not chainlit_client: + if not data_layer: raise HTTPException(status_code=400, detail="Data persistence is not enabled") - await is_conversation_author(current_user.username, conversation_id) + await is_thread_author(current_user.identifier, thread_id) - res = await chainlit_client.get_conversation(conversation_id) + res = await data_layer.get_thread(thread_id) return JSONResponse(content=res) -@app.get("/project/conversation/{conversation_id}/element/{element_id}") -async def get_conversation_element( +@app.get("/project/thread/{thread_id}/element/{element_id}") +async def get_thread_element( request: Request, - conversation_id: str, + thread_id: str, element_id: str, - current_user: Annotated[ - Union[AppUser, PersistedAppUser], Depends(get_current_user) - ], + current_user: Annotated[Union[User, PersistedUser], Depends(get_current_user)], ): - """Get a specific conversation element.""" + """Get a specific thread element.""" + data_layer = get_data_layer() - if not chainlit_client: + if not data_layer: raise HTTPException(status_code=400, detail="Data persistence is not enabled") - await is_conversation_author(current_user.username, conversation_id) + await is_thread_author(current_user.identifier, thread_id) - res = await chainlit_client.get_element(conversation_id, element_id) + res = await data_layer.get_element(thread_id, element_id) return JSONResponse(content=res) -@app.delete("/project/conversation") -async def delete_conversation( +@app.delete("/project/thread") +async def delete_thread( request: Request, - payload: DeleteConversationRequest, - current_user: Annotated[ - Union[AppUser, PersistedAppUser], Depends(get_current_user) - ], + payload: DeleteThreadRequest, + current_user: Annotated[Union[User, PersistedUser], Depends(get_current_user)], ): - """Delete a conversation.""" + """Delete a thread.""" + + data_layer = get_data_layer() - if not chainlit_client: + if not data_layer: raise HTTPException(status_code=400, detail="Data persistence is not enabled") - conversation_id = payload.conversationId + thread_id = payload.threadId - await is_conversation_author(current_user.username, conversation_id) + await is_thread_author(current_user.identifier, thread_id) - await chainlit_client.delete_conversation(conversation_id) + await data_layer.delete_thread(thread_id) return JSONResponse(content={"success": True}) +@app.post("/project/file") +async def upload_file( + session_id: str, + file: UploadFile, + current_user: Annotated[ + Union[None, User, PersistedUser], Depends(get_current_user) + ], +): + from chainlit.session import WebsocketSession + + session = WebsocketSession.get_by_id(session_id) + + if not session: + raise HTTPException( + status_code=404, + detail="Session not found", + ) + + if current_user: + if not session.user or session.user.identifier != current_user.identifier: + raise HTTPException( + status_code=401, + detail="You are not authorized to upload files for this session", + ) + + session.files_dir.mkdir(exist_ok=True) + + content = await file.read() + + file_response = await session.persist_file( + name=file.filename, content=content, mime=file.content_type + ) + + return JSONResponse(file_response) + + +@app.get("/project/file/{file_id}") +async def get_file( + file_id: str, + session_id: Optional[str] = None, + token: Optional[str] = None, +): + from chainlit.session import WebsocketSession + + session = WebsocketSession.get_by_id(session_id) if session_id else None + + if not session: + raise HTTPException( + status_code=404, + detail="Session not found", + ) + + if current_user := await get_current_user(token or ""): + if not session.user or session.user.identifier != current_user.identifier: + raise HTTPException( + status_code=401, + detail="You are not authorized to upload files for this session", + ) + + if file_id in session.files: + file = session.files[file_id] + return FileResponse(file["path"], media_type=file["type"]) + else: + raise HTTPException(status_code=404, detail="File not found") + + @app.get("/files/{filename:path}") async def serve_file( filename: str, - current_user: Annotated[ - Union[AppUser, PersistedAppUser], Depends(get_current_user) - ], + current_user: Annotated[Union[User, PersistedUser], Depends(get_current_user)], ): base_path = Path(config.project.local_fs_path).resolve() file_path = (base_path / filename).resolve() diff --git a/backend/chainlit/session.py b/backend/chainlit/session.py index ae97c5b070..3c30c2567e 100644 --- a/backend/chainlit/session.py +++ b/backend/chainlit/session.py @@ -1,13 +1,16 @@ -import asyncio import json -from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Union +import mimetypes +import shutil +import uuid +from typing import TYPE_CHECKING, Any, Callable, Deque, Dict, List, Optional, Union + +import aiofiles if TYPE_CHECKING: from chainlit.message import Message - from chainlit.types import AskResponse - -from chainlit.client.cloud import AppUser, PersistedAppUser -from chainlit.data import chainlit_client + from chainlit.step import Step + from chainlit.types import FileDict, FileReference + from chainlit.user import PersistedUser, User class JSONEncoderIgnoreNonSerializable(json.JSONEncoder): @@ -27,12 +30,17 @@ def clean_metadata(metadata: Dict): class BaseSession: """Base object.""" + active_steps: List["Step"] + thread_id_to_resume: Optional[str] = None + def __init__( self, # Id of the session id: str, + # Thread id + thread_id: Optional[str], # Logged-in user informations - user: Optional[Union["AppUser", "PersistedAppUser"]], + user: Optional[Union["User", "PersistedUser"]], # Logged-in user token token: Optional[str], # User specific environment variables. Empty if no user environment variables are required. @@ -41,41 +49,30 @@ def __init__( root_message: Optional["Message"] = None, # Chat profile selected before the session was created chat_profile: Optional[str] = None, - # Conversation id to resume - conversation_id: Optional[str] = None, ): + if thread_id: + self.thread_id_to_resume = thread_id + self.thread_id = thread_id or str(uuid.uuid4()) self.user = user self.token = token self.root_message = root_message self.has_user_message = False self.user_env = user_env or {} self.chat_profile = chat_profile + self.active_steps = [] self.id = id - self.conversation_id = conversation_id self.chat_settings: Dict[str, Any] = {} - self.lock = asyncio.Lock() - - async def get_conversation_id(self) -> Optional[str]: - if not chainlit_client: - return None - - if isinstance(self, HTTPSession): - tags = ["api"] - else: - tags = ["chat"] - - async with self.lock: - if not self.conversation_id: - app_user_id = ( - self.user.id if isinstance(self.user, PersistedAppUser) else None - ) - self.conversation_id = await chainlit_client.create_conversation( - app_user_id=app_user_id, tags=tags - ) - return self.conversation_id + async def persist_file( + self, + name: str, + mime: str, + path: Optional[str] = None, + content: Optional[Union[bytes, str]] = None, + ): + return None def to_persistable(self) -> Dict: from chainlit.user_session import user_sessions @@ -94,17 +91,24 @@ def __init__( self, # Id of the session id: str, + # Thread id + thread_id: Optional[str] = None, # Logged-in user informations - user: Optional[Union["AppUser", "PersistedAppUser"]], + user: Optional[Union["User", "PersistedUser"]] = None, # Logged-in user token - token: Optional[str], - user_env: Optional[Dict[str, str]], + token: Optional[str] = None, + user_env: Optional[Dict[str, str]] = None, # Last message at the root of the chat root_message: Optional["Message"] = None, # User specific environment variables. Empty if no user environment variables are required. ): super().__init__( - id=id, user=user, token=token, user_env=user_env, root_message=root_message + id=id, + thread_id=thread_id, + user=user, + token=token, + user_env=user_env, + root_message=root_message, ) @@ -129,28 +133,28 @@ def __init__( # Function to emit a message to the user emit: Callable[[str, Any], None], # Function to ask the user a question - ask_user: Callable[[Any, Optional[int]], Union["AskResponse", None]], + ask_user: Callable[[Any, Optional[int]], Any], # User specific environment variables. Empty if no user environment variables are required. user_env: Dict[str, str], + # Thread id + thread_id: Optional[str] = None, # Logged-in user informations - user: Optional[Union["AppUser", "PersistedAppUser"]], + user: Optional[Union["User", "PersistedUser"]] = None, # Logged-in user token - token: Optional[str], + token: Optional[str] = None, # Last message at the root of the chat root_message: Optional["Message"] = None, # Chat profile selected before the session was created chat_profile: Optional[str] = None, - # Conversation id to resume - conversation_id: Optional[str] = None, ): super().__init__( id=id, + thread_id=thread_id, user=user, token=token, user_env=user_env, root_message=root_message, chat_profile=chat_profile, - conversation_id=conversation_id, ) self.socket_id = socket_id @@ -160,9 +164,66 @@ def __init__( self.should_stop = False self.restored = False + self.thread_queues = {} # type: Dict[str, Deque[Callable]] + self.files = {} # type: Dict[str, "FileDict"] + ws_sessions_id[self.id] = self ws_sessions_sid[socket_id] = self + @property + def files_dir(self): + from chainlit.config import FILES_DIRECTORY + + return FILES_DIRECTORY / self.id + + async def persist_file( + self, + name: str, + mime: str, + path: Optional[str] = None, + content: Optional[Union[bytes, str]] = None, + ) -> "FileReference": + if not path and not content: + raise ValueError( + "Either path or content must be provided to persist a file" + ) + + self.files_dir.mkdir(exist_ok=True) + + file_id = str(uuid.uuid4()) + + file_path = self.files_dir / file_id + + file_extension = mimetypes.guess_extension(mime) + if file_extension: + file_path = file_path.with_suffix(file_extension) + + if path: + # Copy the file from the given path + async with aiofiles.open(path, "rb") as src, aiofiles.open( + file_path, "wb" + ) as dst: + await dst.write(await src.read()) + elif content: + # Write the provided content to the file + async with aiofiles.open(file_path, "wb") as buffer: + if isinstance(content, str): + content = content.encode("utf-8") + await buffer.write(content) + + # Get the file size + file_size = file_path.stat().st_size + # Store the file content in memory + self.files[file_id] = { + "id": file_id, + "path": file_path, + "name": name, + "type": mime, + "size": file_size, + } + + return {"id": file_id} + def restore(self, new_socket_id: str): """Associate a new socket id to the session.""" ws_sessions_sid.pop(self.socket_id, None) @@ -172,9 +233,17 @@ def restore(self, new_socket_id: str): def delete(self): """Delete the session.""" + if self.files_dir.is_dir(): + shutil.rmtree(self.files_dir) ws_sessions_sid.pop(self.socket_id, None) ws_sessions_id.pop(self.id, None) + async def flush_method_queue(self): + for method_name, queue in self.thread_queues.items(): + while queue: + method, self, args, kwargs = queue.popleft() + await method(self, *args, **kwargs) + @classmethod def get(cls, socket_id: str): """Get session by socket id.""" diff --git a/backend/chainlit/socket.py b/backend/chainlit/socket.py index 58e00e4c7d..90e77aca8e 100644 --- a/backend/chainlit/socket.py +++ b/backend/chainlit/socket.py @@ -6,7 +6,7 @@ from chainlit.auth import get_current_user, require_login from chainlit.config import config from chainlit.context import init_ws_context -from chainlit.data import chainlit_client +from chainlit.data import get_data_layer from chainlit.logger import logger from chainlit.message import ErrorMessage, Message from chainlit.server import socket @@ -27,39 +27,33 @@ def restore_existing_session(sid, session_id, emit_fn, ask_user_fn): return False -async def persist_user_session(conversation_id: str, metadata: Dict): - if not chainlit_client: - return - - await chainlit_client.update_conversation_metadata( - conversation_id=conversation_id, metadata=metadata - ) +async def persist_user_session(thread_id: str, metadata: Dict): + if data_layer := get_data_layer(): + await data_layer.update_thread(thread_id=thread_id, metadata=metadata) -async def resume_conversation(session: WebsocketSession): - if not chainlit_client or not session.user or not session.conversation_id: +async def resume_thread(session: WebsocketSession): + data_layer = get_data_layer() + if not data_layer or not session.user or not session.thread_id_to_resume: + return + thread = await data_layer.get_thread(thread_id=session.thread_id_to_resume) + if not thread: return - conversation = await chainlit_client.get_conversation( - conversation_id=session.conversation_id - ) - - author = ( - conversation["appUser"].get("username") if conversation["appUser"] else None - ) - user_is_author = author == session.user.username + author = thread.get("user").get("identifier") if thread["user"] else None + user_is_author = author == session.user.identifier - if conversation and user_is_author: - metadata = conversation["metadata"] or {} + if user_is_author: + metadata = thread["metadata"] or {} user_sessions[session.id] = metadata.copy() if chat_profile := metadata.get("chat_profile"): session.chat_profile = chat_profile if chat_settings := metadata.get("chat_settings"): session.chat_settings = chat_settings - trace_event("conversation_resumed") + trace_event("thread_resumed") - return conversation + return thread def load_user_env(user_env): @@ -128,9 +122,8 @@ def ask_user_fn(data, timeout): user=user, token=token, chat_profile=environ.get("HTTP_X_CHAINLIT_CHAT_PROFILE"), - conversation_id=environ.get("HTTP_X_CHAINLIT_CONVERSATION_ID"), + thread_id=environ.get("HTTP_X_CHAINLIT_THREAD_ID"), ) - trace_event("connection_successful") return True @@ -142,13 +135,13 @@ async def connection_successful(sid): if context.session.restored: return - if context.session.conversation_id and config.code.on_chat_resume: - conversation = await resume_conversation(context.session) - if conversation: + if context.session.thread_id_to_resume and config.code.on_chat_resume: + thread = await resume_thread(context.session) + if thread: context.session.has_user_message = True await context.emitter.clear_ask() - await config.code.on_chat_resume(conversation) - await context.emitter.resume_conversation(conversation) + await context.emitter.resume_thread(thread) + await config.code.on_chat_resume(thread) return if config.code.on_chat_start: @@ -174,13 +167,14 @@ async def clean_session(sid): @socket.on("disconnect") async def disconnect(sid): session = WebsocketSession.get(sid) + if session: + init_ws_context(session) if config.code.on_chat_end and session: - init_ws_context(session) await config.code.on_chat_end() - if session and session.conversation_id: - await persist_user_session(session.conversation_id, session.to_persistable()) + if session and session.thread_id and session.has_user_message: + await persist_user_session(session.thread_id, session.to_persistable()) async def disconnect_on_timeout(sid): await asyncio.sleep(config.project.session_timeout) @@ -200,7 +194,9 @@ async def stop(sid): trace_event("stop_task") init_ws_context(session) - await Message(author="System", content="Task stopped by the user.").send() + await Message( + author="System", content="Task stopped by the user.", disable_feedback=True + ).send() session.should_stop = True @@ -240,7 +236,8 @@ async def message(sid, payload: UIMessagePayload): async def process_action(action: Action): callback = config.code.action_callbacks.get(action.name) if callback: - await callback(action) + res = await callback(action) + return res else: logger.warning("No callback found for action %s", action.name) @@ -248,11 +245,25 @@ async def process_action(action: Action): @socket.on("action_call") async def call_action(sid, action): """Handle an action call from the UI.""" - init_ws_context(sid) + context = init_ws_context(sid) action = Action(**action) - await process_action(action) + try: + res = await process_action(action) + await context.emitter.send_action_response( + id=action.id, status=True, response=res if isinstance(res, str) else None + ) + + except InterruptedError: + await context.emitter.send_action_response( + id=action.id, status=False, response="Action interrupted by the user" + ) + except Exception as e: + logger.exception(e) + await context.emitter.send_action_response( + id=action.id, status=False, response="An error occured" + ) @socket.on("chat_settings_change") diff --git a/backend/chainlit/step.py b/backend/chainlit/step.py new file mode 100644 index 0000000000..788c295c80 --- /dev/null +++ b/backend/chainlit/step.py @@ -0,0 +1,393 @@ +import asyncio +import inspect +import json +import uuid +from datetime import datetime +from functools import wraps +from typing import Callable, Dict, List, Optional, TypedDict, Union + +from chainlit.config import config +from chainlit.context import context +from chainlit.data import get_data_layer +from chainlit.element import Element +from chainlit.logger import logger +from chainlit.telemetry import trace_event +from chainlit.types import FeedbackDict +from chainlit_client import BaseGeneration +from chainlit_client.step import StepType, TrueStepType + + +class StepDict(TypedDict, total=False): + name: str + type: StepType + id: str + threadId: str + parentId: Optional[str] + disableFeedback: bool + streaming: bool + waitForAnswer: Optional[bool] + isError: Optional[bool] + metadata: Dict + input: str + output: str + createdAt: Optional[str] + start: Optional[str] + end: Optional[str] + generation: Optional[Dict] + showInput: Optional[Union[bool, str]] + language: Optional[str] + indent: Optional[int] + feedback: Optional[FeedbackDict] + + +def step( + original_function: Optional[Callable] = None, + *, + name: Optional[str] = "", + type: TrueStepType = "undefined", + id: Optional[str] = None, + disable_feedback: bool = True, + root: bool = False, + language: Optional[str] = None, + show_input: Union[bool, str] = False, +): + """Step decorator for async and sync functions.""" + + def wrapper(func: Callable): + nonlocal name + if not name: + name = func.__name__ + + # Handle async decorator + + if inspect.iscoroutinefunction(func): + + @wraps(func) + async def async_wrapper(*args, **kwargs): + async with Step( + type=type, + name=name, + id=id, + disable_feedback=disable_feedback, + root=root, + language=language, + show_input=show_input, + ) as step: + try: + step.input = {"args": args, "kwargs": kwargs} + except: + pass + result = await func(*args, **kwargs) + try: + if result and not step.output: + step.output = result + except: + pass + return result + + return async_wrapper + else: + # Handle sync decorator + @wraps(func) + def sync_wrapper(*args, **kwargs): + with Step( + type=type, + name=name, + id=id, + disable_feedback=disable_feedback, + root=root, + language=language, + show_input=show_input, + ) as step: + try: + step.input = {"args": args, "kwargs": kwargs} + except: + pass + result = func(*args, **kwargs) + try: + if result and not step.output: + step.output = result + except: + pass + return result + + return sync_wrapper + + func = original_function + if not func: + return wrapper + else: + return wrapper(func) + + +class Step: + # Constructor + name: str + type: TrueStepType + id: str + parent_id: Optional[str] + disable_feedback: bool + + streaming: bool + persisted: bool + + root: bool + show_input: Union[bool, str] + + is_error: Optional[bool] + metadata: Dict + thread_id: str + created_at: Union[str, None] + start: Union[str, None] + end: Union[str, None] + generation: Optional[BaseGeneration] + language: Optional[str] + elements: Optional[List[Element]] + fail_on_persist_error: bool + + def __init__( + self, + name: Optional[str] = config.ui.name, + type: TrueStepType = "undefined", + id: Optional[str] = None, + parent_id: Optional[str] = None, + elements: Optional[List[Element]] = None, + disable_feedback: bool = True, + root: bool = False, + language: Optional[str] = None, + show_input: Union[bool, str] = False, + ): + trace_event(f"init {self.__class__.__name__} {type}") + self._input = "" + self._output = "" + self.thread_id = context.session.thread_id + self.name = name or "" + self.type = type + self.id = id or str(uuid.uuid4()) + self.disable_feedback = disable_feedback + self.metadata = {} + self.is_error = False + self.show_input = show_input + self.parent_id = parent_id + self.root = root + + self.language = language + self.generation = None + self.elements = elements or [] + + self.created_at = datetime.utcnow().isoformat() + self.start = None + self.end = None + + self.streaming = False + self.persisted = False + self.fail_on_persist_error = False + + def _process_content(self, content, set_language=False): + if content is None: + return "" + if isinstance(content, dict): + try: + processed_content = json.dumps(content, indent=4, ensure_ascii=False) + if set_language: + self.language = "json" + except TypeError: + processed_content = str(content) + if set_language: + self.language = "text" + elif isinstance(content, str): + processed_content = content + else: + processed_content = str(content) + if set_language: + self.language = "text" + return processed_content + + @property + def input(self): + return self._input + + @input.setter + def input(self, content: Union[Dict, str]): + self._input = self._process_content(content, set_language=False) + + @property + def output(self): + return self._output + + @output.setter + def output(self, content: Union[Dict, str]): + self._output = self._process_content(content, set_language=True) + + def to_dict(self) -> StepDict: + _dict: StepDict = { + "name": self.name, + "type": self.type, + "id": self.id, + "threadId": self.thread_id, + "parentId": self.parent_id, + "disableFeedback": self.disable_feedback, + "streaming": self.streaming, + "metadata": self.metadata, + "input": self.input, + "isError": self.is_error, + "output": self.output, + "createdAt": self.created_at, + "start": self.start, + "end": self.end, + "language": self.language, + "showInput": self.show_input, + "generation": self.generation.to_dict() if self.generation else None, + } + return _dict + + async def update(self): + """ + Update a step already sent to the UI. + """ + trace_event("update_step") + + if self.streaming: + self.streaming = False + + step_dict = self.to_dict() + data_layer = get_data_layer() + + if data_layer: + try: + asyncio.create_task(data_layer.update_step(step_dict)) + except Exception as e: + if self.fail_on_persist_error: + raise e + logger.error(f"Failed to persist step update: {str(e)}") + + tasks = [el.send(for_id=self.id) for el in self.elements] + await asyncio.gather(*tasks) + + if not config.features.prompt_playground and "generation" in step_dict: + step_dict.pop("generation", None) + + await context.emitter.update_step(step_dict) + + return True + + async def remove(self): + """ + Remove a step already sent to the UI. + """ + trace_event("remove_step") + + step_dict = self.to_dict() + data_layer = get_data_layer() + + if data_layer: + try: + asyncio.create_task(data_layer.delete_step(self.id)) + except Exception as e: + if self.fail_on_persist_error: + raise e + logger.error(f"Failed to persist step deletion: {str(e)}") + + await context.emitter.delete_step(step_dict) + + return True + + async def send(self): + if self.persisted: + return + + if config.code.author_rename: + self.name = await config.code.author_rename(self.name) + + if self.streaming: + self.streaming = False + + step_dict = self.to_dict() + + data_layer = get_data_layer() + + if data_layer: + try: + asyncio.create_task(data_layer.create_step(step_dict)) + self.persisted = True + except Exception as e: + if self.fail_on_persist_error: + raise e + logger.error(f"Failed to persist step creation: {str(e)}") + + tasks = [el.send(for_id=self.id) for el in self.elements] + await asyncio.gather(*tasks) + + if not config.features.prompt_playground and "generation" in step_dict: + step_dict.pop("generation", None) + + await context.emitter.send_step(step_dict) + + return self.id + + async def stream_token(self, token: str, is_sequence=False): + """ + Sends a token to the UI. + Once all tokens have been streamed, call .send() to end the stream and persist the step if persistence is enabled. + """ + + if not self.streaming: + self.streaming = True + step_dict = self.to_dict() + await context.emitter.stream_start(step_dict) + + if is_sequence: + self.output = token + else: + self.output += token + + assert self.id + await context.emitter.send_token( + id=self.id, token=token, is_sequence=is_sequence + ) + + # Handle parameter less decorator + def __call__(self, func): + return step( + original_function=func, + type=self.type, + name=self.name, + id=self.id, + parent_id=self.parent_id, + thread_id=self.thread_id, + disable_feedback=self.disable_feedback, + ) + + # Handle Context Manager Protocol + async def __aenter__(self): + self.start = datetime.utcnow().isoformat() + if not self.parent_id and not self.root: + if current_step := context.current_step: + self.parent_id = current_step.id + elif context.session.root_message: + self.parent_id = context.session.root_message.id + context.session.active_steps.append(self) + await self.send() + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + self.end = datetime.utcnow().isoformat() + context.session.active_steps.pop() + await self.update() + + def __enter__(self): + self.start = datetime.utcnow().isoformat() + if not self.parent_id and not self.root: + if current_step := context.current_step: + self.parent_id = current_step.id + elif context.session.root_message: + self.parent_id = context.session.root_message.id + context.session.active_steps.append(self) + + asyncio.create_task(self.send()) + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.end = datetime.utcnow().isoformat() + context.session.active_steps.pop() + asyncio.create_task(self.update()) diff --git a/backend/chainlit/types.py b/backend/chainlit/types.py index da394d0715..f601f6ee85 100644 --- a/backend/chainlit/types.py +++ b/backend/chainlit/types.py @@ -1,9 +1,12 @@ from enum import Enum -from typing import Dict, List, Literal, Optional, TypedDict, Union +from typing import TYPE_CHECKING, Dict, List, Literal, Optional, TypedDict, Union -from chainlit.client.base import ConversationFilter, MessageDict, Pagination -from chainlit.element import File -from chainlit.prompt import Prompt +if TYPE_CHECKING: + from chainlit.element import ElementDict + from chainlit.user import UserDict + from chainlit.step import StepDict + +from chainlit_client import ChatGeneration, CompletionGeneration from dataclasses_json import DataClassJsonMixin from pydantic import BaseModel from pydantic.dataclasses import dataclass @@ -13,6 +16,27 @@ ] +class ThreadDict(TypedDict): + id: str + createdAt: str + user: Optional["UserDict"] + tags: Optional[List[str]] + metadata: Optional[Dict] + steps: List["StepDict"] + elements: Optional[List["ElementDict"]] + + +class Pagination(BaseModel): + first: int + cursor: Optional[str] = None + + +class ThreadFilter(BaseModel): + feedback: Optional[Literal[-1, 0, 1]] = None + userIdentifier: Optional[str] = None + search: Optional[str] = None + + @dataclass class FileSpec(DataClassJsonMixin): accept: Union[List[str], Dict[str, List[str]]] @@ -43,23 +67,30 @@ class AskActionSpec(ActionSpec, AskSpec, DataClassJsonMixin): """Specification for asking the user an action""" -class AskResponse(TypedDict): - content: str - author: str +class FileReference(TypedDict): + id: str + + +class FileDict(TypedDict): + id: str + name: str + path: str + size: int + type: str class UIMessagePayload(TypedDict): - message: MessageDict - files: Optional[List[Dict]] + message: "StepDict" + fileReferences: Optional[List[FileReference]] @dataclass class AskFileResponse: + id: str name: str path: str size: int type: str - content: bytes class AskActionResponse(TypedDict): @@ -72,24 +103,28 @@ class AskActionResponse(TypedDict): collapsed: bool -class CompletionRequest(BaseModel): - prompt: Prompt +class GenerationRequest(BaseModel): + chatGeneration: Optional[ChatGeneration] = None + completionGeneration: Optional[CompletionGeneration] = None userEnv: Dict[str, str] + @property + def generation(self): + if self.chatGeneration: + return self.chatGeneration + return self.completionGeneration -class UpdateFeedbackRequest(BaseModel): - messageId: str - feedback: Literal[-1, 0, 1] - feedbackComment: Optional[str] = None + def is_chat(self): + return self.chatGeneration is not None -class DeleteConversationRequest(BaseModel): - conversationId: str +class DeleteThreadRequest(BaseModel): + threadId: str -class GetConversationsRequest(BaseModel): +class GetThreadsRequest(BaseModel): pagination: Pagination - filter: ConversationFilter + filter: ThreadFilter class Theme(str, Enum): @@ -99,8 +134,30 @@ class Theme(str, Enum): @dataclass class ChatProfile(DataClassJsonMixin): - """Specification for a chat profile that can be chosen by the user at the conversation start.""" + """Specification for a chat profile that can be chosen by the user at the thread start.""" name: str markdown_description: str icon: Optional[str] = None + + +FeedbackStrategy = Literal["BINARY"] + + +class FeedbackDict(TypedDict): + value: Literal[-1, 0, 1] + strategy: FeedbackStrategy + comment: Optional[str] + + +@dataclass +class Feedback: + forId: str + value: Literal[-1, 0, 1] + strategy: FeedbackStrategy = "BINARY" + id: Optional[str] = None + comment: Optional[str] = None + + +class UpdateFeedbackRequest(BaseModel): + feedback: Feedback diff --git a/backend/chainlit/user.py b/backend/chainlit/user.py new file mode 100644 index 0000000000..47ad1d2639 --- /dev/null +++ b/backend/chainlit/user.py @@ -0,0 +1,32 @@ +from typing import Dict, Literal, TypedDict + +from dataclasses_json import DataClassJsonMixin +from pydantic.dataclasses import Field, dataclass + +Provider = Literal[ + "credentials", "header", "github", "google", "azure-ad", "okta", "auth0", "descope" +] + + +class UserDict(TypedDict): + id: str + identifier: str + metadata: Dict + + +# Used when logging-in a user +@dataclass +class User(DataClassJsonMixin): + identifier: str + metadata: Dict = Field(default_factory=dict) + + +@dataclass +class PersistedUserFields: + id: str + createdAt: str + + +@dataclass +class PersistedUser(User, PersistedUserFields): + pass diff --git a/backend/chainlit/user_session.py b/backend/chainlit/user_session.py index 011ec9b09b..a4a3d15c10 100644 --- a/backend/chainlit/user_session.py +++ b/backend/chainlit/user_session.py @@ -1,8 +1,4 @@ -from typing import TYPE_CHECKING, Dict, Optional, TypedDict, Union - -if TYPE_CHECKING: - from chainlit.message import Message - from chainlit.client.base import AppUser, PersistedAppUser +from typing import Dict from chainlit.context import context diff --git a/backend/pyproject.toml b/backend/pyproject.toml index 5c54251585..98a90b3285 100644 --- a/backend/pyproject.toml +++ b/backend/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "chainlit" -version = "0.7.700" +version = "1.0.0rc0" keywords = ['LLM', 'Agents', 'gen ai', 'chat ui', 'chatbot ui', 'langchain'] description = "A faster way to build chatbot UIs." authors = ["Chainlit"] @@ -20,6 +20,7 @@ chainlit = 'chainlit.cli:cli' [tool.poetry.dependencies] python = ">=3.8.1,<3.12" +chainlit_client = "0.1.0rc4" dataclasses_json = "^0.5.7" uvicorn = "^0.23.2" fastapi = "^0.100" @@ -60,13 +61,14 @@ plotly = "^5.18.0" optional = true [tool.poetry.group.mypy.dependencies] -mypy = "^1.5.1" +mypy = "^1.7.1" types-requests = "^2.31.0.2" types-aiofiles = "^23.1.0.5" [tool.mypy] python_version = "3.8" + [[tool.mypy.overrides]] module = [ "anthropic", diff --git a/cypress/e2e/action/spec.cy.ts b/cypress/e2e/action/spec.cy.ts index 89c3d10273..f7db198384 100644 --- a/cypress/e2e/action/spec.cy.ts +++ b/cypress/e2e/action/spec.cy.ts @@ -9,16 +9,16 @@ describe('Action', () => { // Click on "first action" cy.get('#first-action').should('be.visible'); cy.get('#first-action').click(); - cy.get('.message').should('have.length', 3); - cy.get('.message') + cy.get('.step').should('have.length', 3); + cy.get('.step') .eq(2) .should('contain', 'Thanks for pressing: first-action'); // Click on "test action" cy.get("[id='test-action']").should('be.visible'); cy.get("[id='test-action']").click(); - cy.get('.message').should('have.length', 4); - cy.get('.message').eq(3).should('contain', 'Executed test action!'); + cy.get('.step').should('have.length', 4); + cy.get('.step').eq(3).should('contain', 'Executed test action!'); cy.get("[id='test-action']").should('exist'); cy.wait(100); @@ -26,8 +26,8 @@ describe('Action', () => { // Click on "removable action" cy.get("[id='removable-action']").should('be.visible'); cy.get("[id='removable-action']").click(); - cy.get('.message').should('have.length', 5); - cy.get('.message').eq(4).should('contain', 'Executed removable action!'); + cy.get('.step').should('have.length', 5); + cy.get('.step').eq(4).should('contain', 'Executed removable action!'); cy.get("[id='removable-action']").should('not.exist'); cy.wait(100); @@ -35,13 +35,13 @@ describe('Action', () => { // Click on "multiple action one" in the action drawer, should remove the correct action button cy.get("[id='actions-drawer-button']").should('be.visible'); cy.get("[id='actions-drawer-button']").click(); - cy.get('.message').should('have.length', 5); + cy.get('.step').should('have.length', 5); cy.wait(100); cy.get("[id='multiple-action-one']").should('be.visible'); cy.get("[id='multiple-action-one']").click(); - cy.get('.message') + cy.get('.step') .eq(5) .should('contain', 'Action(id=multiple-action-one) has been removed!'); cy.get("[id='multiple-action-one']").should('not.exist'); @@ -49,11 +49,11 @@ describe('Action', () => { cy.wait(100); // Click on "multiple action two", should remove the correct action button - cy.get('.message').should('have.length', 6); + cy.get('.step').should('have.length', 6); cy.get("[id='actions-drawer-button']").click(); cy.get("[id='multiple-action-two']").should('be.visible'); cy.get("[id='multiple-action-two']").click(); - cy.get('.message') + cy.get('.step') .eq(6) .should('contain', 'Action(id=multiple-action-two) has been removed!'); cy.get("[id='multiple-action-two']").should('not.exist'); @@ -63,9 +63,7 @@ describe('Action', () => { // Click on "all actions removed", should remove all buttons cy.get("[id='all-actions-removed']").should('be.visible'); cy.get("[id='all-actions-removed']").click(); - cy.get('.message') - .eq(7) - .should('contain', 'All actions have been removed!'); + cy.get('.step').eq(7).should('contain', 'All actions have been removed!'); cy.get("[id='all-actions-removed']").should('not.exist'); cy.get("[id='test-action']").should('not.exist'); cy.get("[id='actions-drawer-button']").should('not.exist'); diff --git a/cypress/e2e/ask_file/main.py b/cypress/e2e/ask_file/main.py index ad11c83278..427b61561f 100644 --- a/cypress/e2e/ask_file/main.py +++ b/cypress/e2e/ask_file/main.py @@ -1,3 +1,5 @@ +import aiofiles + import chainlit as cl @@ -7,20 +9,20 @@ async def start(): content="Please upload a text file to begin!", accept=["text/plain"] ).send() txt_file = files[0] - # Decode the file - text = txt_file.content.decode("utf-8") - await cl.Message( - content=f"`Text file {txt_file.name}` uploaded, it contains {len(text)} characters!" - ).send() + async with aiofiles.open(txt_file.path, "r", encoding="utf-8") as f: + content = await f.read() + await cl.Message( + content=f"`Text file {txt_file.name}` uploaded, it contains {len(content)} characters!" + ).send() files = await cl.AskFileMessage( content="Please upload a python file to begin!", accept={"text/plain": [".py"]} ).send() py_file = files[0] - # Decode the file - text = py_file.content.decode("utf-8") - await cl.Message( - content=f"`Python file {py_file.name}` uploaded, it contains {len(text)} characters!" - ).send() + async with aiofiles.open(py_file.path, "r", encoding="utf-8") as f: + content = await f.read() + await cl.Message( + content=f"`Python file {py_file.name}` uploaded, it contains {len(content)} characters!" + ).send() diff --git a/cypress/e2e/ask_file/spec.cy.ts b/cypress/e2e/ask_file/spec.cy.ts index 62a041b3d8..8455dde2ab 100644 --- a/cypress/e2e/ask_file/spec.cy.ts +++ b/cypress/e2e/ask_file/spec.cy.ts @@ -16,7 +16,7 @@ describe('Upload file', () => { // cy.get("#ask-upload-button-loading").should("exist"); // cy.get("#ask-upload-button-loading").should("not.exist"); - cy.get('.message') + cy.get('.step') .eq(1) .should( 'contain', @@ -29,13 +29,13 @@ describe('Upload file', () => { cy.fixture('hello.cpp', 'utf-8').as('cppFile'); cy.get('#ask-button-input').selectFile('@cppFile', { force: true }); - cy.get('.message').should('have.length', 3); + cy.get('.step').should('have.length', 3); // Upload a python file cy.fixture('hello.py', 'utf-8').as('pyFile'); cy.get('#ask-button-input').selectFile('@pyFile', { force: true }); - cy.get('.message') + cy.get('.step') .should('have.length', 4) .eq(3) .should('contain', 'Python file hello.py uploaded, it contains'); diff --git a/cypress/e2e/ask_multiple_files/.chainlit/config.toml b/cypress/e2e/ask_multiple_files/.chainlit/config.toml index 2006dfbc52..92e7b496fe 100644 --- a/cypress/e2e/ask_multiple_files/.chainlit/config.toml +++ b/cypress/e2e/ask_multiple_files/.chainlit/config.toml @@ -25,7 +25,7 @@ multi_modal = true # Name of the app and chatbot. name = "Chatbot" -# Show the readme while the conversation is empty. +# Show the readme while the thread is empty. show_readme_as_default = true # Description of the app and chatbot. This is used for HTML tags. diff --git a/cypress/e2e/ask_multiple_files/spec.cy.ts b/cypress/e2e/ask_multiple_files/spec.cy.ts index bb40cc5e19..dc6dba0001 100644 --- a/cypress/e2e/ask_multiple_files/spec.cy.ts +++ b/cypress/e2e/ask_multiple_files/spec.cy.ts @@ -19,7 +19,7 @@ describe('Upload multiple files', () => { // cy.get("#ask-upload-button-loading").should("exist"); // cy.get("#ask-upload-button-loading").should("not.exist"); - cy.get('.message') + cy.get('.step') .eq(1) .should('contain', '2 files uploaded: state_of_the_union.txt,hello.py'); diff --git a/cypress/e2e/ask_user/main.py b/cypress/e2e/ask_user/main.py index 60488b1de0..312d3730a1 100644 --- a/cypress/e2e/ask_user/main.py +++ b/cypress/e2e/ask_user/main.py @@ -6,5 +6,5 @@ async def main(): res = await cl.AskUserMessage(content="What is your name?", timeout=10).send() if res: await cl.Message( - content=f"Your name is: {res['content']}", + content=f"Your name is: {res['output']}", ).send() diff --git a/cypress/e2e/ask_user/spec.cy.ts b/cypress/e2e/ask_user/spec.cy.ts index a4feb60267..2e3fba3d48 100644 --- a/cypress/e2e/ask_user/spec.cy.ts +++ b/cypress/e2e/ask_user/spec.cy.ts @@ -6,11 +6,11 @@ describe('Ask User', () => { }); it('should send a new message containing the user input', () => { - cy.get('.message').should('have.length', 1); + cy.get('.step').should('have.length', 1); submitMessage('Jeeves'); cy.wait(2000); - cy.get('.message').should('have.length', 3); + cy.get('.step').should('have.length', 3); - cy.get('.message').eq(2).should('contain', 'Jeeves'); + cy.get('.step').eq(2).should('contain', 'Jeeves'); }); }); diff --git a/cypress/e2e/audio_element/spec.cy.ts b/cypress/e2e/audio_element/spec.cy.ts index 342241377d..05cb819789 100644 --- a/cypress/e2e/audio_element/spec.cy.ts +++ b/cypress/e2e/audio_element/spec.cy.ts @@ -6,8 +6,8 @@ describe('audio', () => { }); it('should be able to display an audio element', () => { - cy.get('.message').should('have.length', 1); - cy.get('.message').eq(0).find('.inline-audio').should('have.length', 1); + cy.get('.step').should('have.length', 1); + cy.get('.step').eq(0).find('.inline-audio').should('have.length', 1); cy.get('.inline-audio audio') .then(($el) => { diff --git a/cypress/e2e/author_rename/main.py b/cypress/e2e/author_rename/main.py index 6581d38c3c..eddd2d3419 100644 --- a/cypress/e2e/author_rename/main.py +++ b/cypress/e2e/author_rename/main.py @@ -7,7 +7,13 @@ def rename(orig_author: str): return rename_dict.get(orig_author, orig_author) +@cl.step +def LLMMathChain(): + return "2+2=4" + + @cl.on_chat_start async def main(): await cl.Message(author="LLMMathChain", content="2+2=4").send() + LLMMathChain() await cl.Message(content="The response is 4").send() diff --git a/cypress/e2e/author_rename/spec.cy.ts b/cypress/e2e/author_rename/spec.cy.ts index ccf77cafcb..60b30cbf77 100644 --- a/cypress/e2e/author_rename/spec.cy.ts +++ b/cypress/e2e/author_rename/spec.cy.ts @@ -1,12 +1,12 @@ -import { runTestServer } from "../../support/testUtils"; +import { runTestServer } from '../../support/testUtils'; -describe("Author rename", () => { +describe('Author rename', () => { before(() => { runTestServer(); }); - it("should be able to rename authors", () => { - cy.get(".message").eq(0).should("contain", "Albert Einstein"); - cy.get(".message").eq(1).should("contain", "Assistant"); + it('should be able to rename authors', () => { + cy.get('.step').eq(0).should('contain', 'Albert Einstein'); + cy.get('.step').eq(1).should('contain', 'Assistant'); }); }); diff --git a/cypress/e2e/global_elements/cat.jpeg b/cypress/e2e/avatar/cat.jpeg similarity index 100% rename from cypress/e2e/global_elements/cat.jpeg rename to cypress/e2e/avatar/cat.jpeg diff --git a/cypress/e2e/avatar/main.py b/cypress/e2e/avatar/main.py index dde9753d02..3973051fab 100644 --- a/cypress/e2e/avatar/main.py +++ b/cypress/e2e/avatar/main.py @@ -8,6 +8,13 @@ async def start(): url="https://avatars.githubusercontent.com/u/128686189?s=400&u=a1d1553023f8ea0921fba0debbe92a8c5f840dd9&v=4", ).send() + await cl.Avatar(name="Cat", path="./cat.jpeg").send() + + await cl.Avatar( + name="Tool 1", + url="https://avatars.githubusercontent.com/u/128686189?s=400&u=a1d1553023f8ea0921fba0debbe92a8c5f840dd9&v=4", + ).send() + await cl.Message( content="This message should not have an avatar!", author="Tool 0" ).send() @@ -19,3 +26,7 @@ async def start(): await cl.Message( content="This message should not have an avatar!", author="Tool 2" ).send() + + await cl.Message( + content="This message should have a cat avatar!", author="Cat" + ).send() diff --git a/cypress/e2e/avatar/spec.cy.ts b/cypress/e2e/avatar/spec.cy.ts index 1a43e55c6a..b7ac2aaa4b 100644 --- a/cypress/e2e/avatar/spec.cy.ts +++ b/cypress/e2e/avatar/spec.cy.ts @@ -6,11 +6,12 @@ describe('Avatar', () => { }); it('should be able to display avatars', () => { - cy.get('.message').should('have.length', 3); + cy.get('.step').should('have.length', 4); - cy.get('.message').eq(0).find('.message-avatar').should('have.length', 0); - cy.get('.message').eq(1).find('.message-avatar').should('have.length', 1); - cy.get('.message').eq(2).find('.message-avatar').should('have.length', 0); + cy.get('.step').eq(0).find('.message-avatar').should('have.length', 0); + cy.get('.step').eq(1).find('.message-avatar').should('have.length', 1); + cy.get('.step').eq(2).find('.message-avatar').should('have.length', 0); + cy.get('.step').eq(3).find('.message-avatar').should('have.length', 1); cy.get('.element-link').should('have.length', 0); }); diff --git a/cypress/e2e/chat_profiles/main.py b/cypress/e2e/chat_profiles/main.py index e9857a7002..bce724285d 100644 --- a/cypress/e2e/chat_profiles/main.py +++ b/cypress/e2e/chat_profiles/main.py @@ -4,8 +4,8 @@ @cl.set_chat_profiles -async def chat_profile(current_user: cl.AppUser): - if current_user.role != "ADMIN": +async def chat_profile(current_user: cl.User): + if current_user.metadata["role"] != "ADMIN": return None return [ @@ -27,22 +27,17 @@ async def chat_profile(current_user: cl.AppUser): @cl.password_auth_callback -def auth_callback(username: str, password: str) -> Optional[cl.AppUser]: +def auth_callback(username: str, password: str) -> Optional[cl.User]: if (username, password) == ("admin", "admin"): - return cl.AppUser(username="admin", role="ADMIN", provider="credentials") + return cl.User(identifier="admin", metadata={"role": "ADMIN"}) else: return None -# @cl.on_message -# async def on_message(message: str): -# await cl.Message(content=f"echo: {message}").send() - - @cl.on_chat_start async def on_chat_start(): - app_user = cl.user_session.get("user") + user = cl.user_session.get("user") chat_profile = cl.user_session.get("chat_profile") await cl.Message( - content=f"starting chat with {app_user.username} using the {chat_profile} chat profile" + content=f"starting chat with {user.identifier} using the {chat_profile} chat profile" ).send() diff --git a/cypress/e2e/chat_profiles/spec.cy.ts b/cypress/e2e/chat_profiles/spec.cy.ts index c17f3dcaf3..16156a6478 100644 --- a/cypress/e2e/chat_profiles/spec.cy.ts +++ b/cypress/e2e/chat_profiles/spec.cy.ts @@ -17,7 +17,7 @@ describe('Chat profiles', () => { cy.get('[data-test="chat-profile:GPT-4"]').should('exist'); cy.get('[data-test="chat-profile:GPT-5"]').should('exist'); - cy.get('.message') + cy.get('.step') .should('have.length', 1) .eq(0) .should( @@ -30,19 +30,18 @@ describe('Chat profiles', () => { cy.get('[data-test="chat-profile:GPT-4"]').click(); cy.get('#confirm').click(); - cy.get('.message') + cy.get('.step') .should('have.length', 1) .eq(0) .should( 'contain', 'starting chat with admin using the GPT-4 chat profile' ); - // New conversation cy.get('#new-chat-button').click(); cy.get('#confirm').click(); - cy.get('.message') + cy.get('.step') .should('have.length', 1) .eq(0) .should( diff --git a/cypress/e2e/chat_settings/spec.cy.ts b/cypress/e2e/chat_settings/spec.cy.ts index d93fc85764..a60e2457fe 100644 --- a/cypress/e2e/chat_settings/spec.cy.ts +++ b/cypress/e2e/chat_settings/spec.cy.ts @@ -33,8 +33,8 @@ describe('Customize chat settings', () => { cy.contains('Confirm').click(); - cy.get('.message').should('have.length', 1); - cy.get('.message').eq(0).should('contain', 'Settings updated!'); + cy.get('.step').should('have.length', 1); + cy.get('.step').eq(0).should('contain', 'Settings updated!'); // Check if inputs are updated cy.get('#chat-settings-open-modal').click(); diff --git a/cypress/e2e/conversations/.chainlit/config.toml b/cypress/e2e/context/.chainlit/config.toml similarity index 100% rename from cypress/e2e/conversations/.chainlit/config.toml rename to cypress/e2e/context/.chainlit/config.toml diff --git a/cypress/e2e/sdk_availability/main.py b/cypress/e2e/context/main.py similarity index 99% rename from cypress/e2e/sdk_availability/main.py rename to cypress/e2e/context/main.py index 63ad8f838e..47cd59eaa2 100644 --- a/cypress/e2e/sdk_availability/main.py +++ b/cypress/e2e/context/main.py @@ -1,7 +1,8 @@ -import chainlit as cl from chainlit.context import context from chainlit.sync import make_async, run_sync +import chainlit as cl + async def async_function_from_sync(): await cl.sleep(2) diff --git a/cypress/e2e/context/spec.cy.ts b/cypress/e2e/context/spec.cy.ts new file mode 100644 index 0000000000..48348d2b8c --- /dev/null +++ b/cypress/e2e/context/spec.cy.ts @@ -0,0 +1,19 @@ +import { runTestServer } from '../../support/testUtils'; + +describe('Context should be reachable', () => { + before(() => { + runTestServer(); + }); + + it('should find the Emitter from async, make_async and async_from_sync contexts', () => { + cy.get('.step').should('have.length', 3); + + cy.get('.step').eq(0).should('contain', 'emitter from async found!'); + + cy.get('.step').eq(1).should('contain', 'emitter from make_async found!'); + + cy.get('.step') + .eq(2) + .should('contain', 'emitter from async_from_sync found!'); + }); +}); diff --git a/cypress/e2e/conversations/main.py b/cypress/e2e/conversations/main.py deleted file mode 100644 index 214344808f..0000000000 --- a/cypress/e2e/conversations/main.py +++ /dev/null @@ -1,21 +0,0 @@ -from typing import Optional - -import chainlit as cl - - -@cl.on_chat_start -async def main(): - await cl.Message("Hello, send me a message!").send() - - -@cl.on_message -async def handle_message(): - await cl.Message("Ok!").send() - - -@cl.password_auth_callback -def auth_callback(username: str, password: str) -> Optional[cl.AppUser]: - if (username, password) == ("admin", "admin"): - return cl.AppUser(username="admin", role="ADMIN", provider="credentials") - else: - return None diff --git a/cypress/e2e/conversations/spec.cy.ts b/cypress/e2e/conversations/spec.cy.ts deleted file mode 100644 index 8de2cc88ca..0000000000 --- a/cypress/e2e/conversations/spec.cy.ts +++ /dev/null @@ -1,125 +0,0 @@ -import { runTestServer } from '../../support/testUtils'; - -describe('Conversations', () => { - before(() => { - runTestServer(undefined, { - CHAINLIT_API_KEY: 'fake_key', - CHAINLIT_AUTH_SECRET: - 'G=>I6me4>E_y,n$_%K%XqbTMKXGQy-jvZ6:1oR>~o8z@DPb*.QY~NkgctmBDg3T-' - }); - - cy.intercept('POST', '/login', { - statusCode: 200, - body: { - access_token: - 'eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJ1c2VybmFtZSI6ImFkbWluIiwicm9sZSI6IkFETUlOIiwidGFncyI6W10sImltYWdlIjpudWxsLCJwcm92aWRlciI6ImNyZWRlbnRpYWxzIiwiZXhwIjoxNjk4MTM5MzgwfQ.VQS_O0Zar1O3BVzJ_bu4_8r-1LW0Mfq2En7sIojzd04', - token_type: 'bearer' - } - }).as('postLogin'); - - cy.intercept('GET', '/project/settings', { - statusCode: 200, - body: { - ui: { - show_readme_as_default: true - }, - userEnv: [], - dataPersistence: true, - markdown: 'foo', - chatProfiles: [] - } - }).as('getSettings'); - - const makeData = (start: number, count: number) => - Array.from({ length: count }, (_, i) => ({ - id: String(start + i), - createdAt: Date.now(), - tags: ['chat'], - appUser: { username: 'admin' }, - messages: [{ content: `foo ${start + i}` }] - })); - - cy.intercept('POST', '/project/conversations', (req) => { - const { cursor } = req.body.pagination; - const dataCount = cursor ? 3 : 20; - const startId = cursor ? 21 : 1; - - req.reply({ - statusCode: 200, - body: { - pageInfo: { - hasNextPage: !cursor, - endCursor: cursor ? 'newCursor' : 'someCursor' - }, - data: makeData(startId, dataCount) - } - }); - }).as('getConversations'); - - cy.intercept('GET', '/project/conversation/*', (req) => { - const conversationId = req.url.split('/').pop(); - - req.reply({ - statusCode: 200, - body: { - id: conversationId, - createdAt: Date.now(), - tags: ['chat'], - messages: [ - { - id: '2b1755ab-f7e3-48fa-9fe1-535595142b96', - isError: false, - parentId: null, - indent: 0, - author: 'Chatbot', - content: `Foo ${conversationId} message`, - waitForAnswer: false, - humanFeedback: 0, - humanFeedbackComment: null, - disableHumanFeedback: false, - language: null, - prompt: null, - authorIsUser: false, - createdAt: 1696844037149 - } - ], - elements: [] - } - }); - }).as('getConversation'); - - cy.intercept('DELETE', '/project/conversation', { - statusCode: 200, - body: { - success: true - } - }).as('deleteConversation'); - }); - - describe('Conversations history', () => { - it('should perform conversations history operations', () => { - // Login to the app - cy.get("[id='email']").type('admin'); - cy.get("[id='password']").type('admin{enter}'); - - // Conversations are being displayed - cy.contains('Foo 1'); - cy.contains('Foo 2'); - - // Scroll chat and fetch new conversations - cy.get('.chat-history-drawer > div').scrollTo('bottom'); - cy.get('#chat-history-loader').should('be.visible'); - cy.contains('Foo 23'); - - // Select conversation - cy.get('#conversation-18').click(); - cy.get('#conversation-18').should('be.visible'); - cy.contains('Foo 18 message'); - - // Delete conversation - cy.get("[data-testid='DeleteOutlineIcon']").click(); - cy.get("[type='button']").contains('Confirm').click(); - cy.contains('Conversation deleted!'); - }); - }); -}); diff --git a/cypress/e2e/cot/main.py b/cypress/e2e/cot/main.py deleted file mode 100644 index b8665e5c12..0000000000 --- a/cypress/e2e/cot/main.py +++ /dev/null @@ -1,31 +0,0 @@ -import chainlit as cl - - -@cl.on_message -async def main(message: cl.Message): - tool1_msg = cl.Message(content="", author="Tool 1", parent_id=message.id) - await tool1_msg.send() - - await cl.sleep(1) - - tool1_msg.content = "I need to use tool 2" - - await tool1_msg.update() - - tool2_msg = cl.Message(content="", author="Tool 2", parent_id=tool1_msg.id) - - await tool2_msg.send() - - await cl.sleep(1) - - tool2_msg.content = "Response from tool 2" - - await tool2_msg.update() - - await cl.Message( - content="Response from tool 2", author="Tool 1", parent_id=message.id - ).send() - - await cl.Message( - content="Final response", - ).send() diff --git a/cypress/e2e/cot/spec.cy.ts b/cypress/e2e/cot/spec.cy.ts deleted file mode 100644 index 4a050423de..0000000000 --- a/cypress/e2e/cot/spec.cy.ts +++ /dev/null @@ -1,22 +0,0 @@ -import { runTestServer, submitMessage } from '../../support/testUtils'; - -describe('Chain of Thought', () => { - before(() => { - runTestServer(); - }); - - it('should be able to display a nested CoT', () => { - submitMessage('Hello'); - - cy.get('#tool-1-loading').should('exist'); - cy.get('#tool-1-loading').click(); - - cy.get('#tool-2-loading').should('exist'); - cy.get('#tool-2-loading').click(); - - cy.get('#tool-1-done').should('exist'); - cy.get('#tool-2-done').should('exist'); - - cy.get('.message').should('have.length', 5); - }); -}); diff --git a/cypress/e2e/cot_mixed/.chainlit/config.toml b/cypress/e2e/cot_mixed/.chainlit/config.toml deleted file mode 100644 index a8b882f1bc..0000000000 --- a/cypress/e2e/cot_mixed/.chainlit/config.toml +++ /dev/null @@ -1,75 +0,0 @@ -[project] -# Whether to enable telemetry (default: true). No personal data is collected. -enable_telemetry = true - -# List of environment variables to be provided by each user to use the app. -user_env = [] - -# Duration (in seconds) during which the session is saved when the connection is lost -session_timeout = 3600 - -# Enable third parties caching (e.g LangChain cache) -cache = false - -# Follow symlink for asset mount (see https://github.com/Chainlit/chainlit/issues/317) -# follow_symlink = false - -[features] -# Show the prompt playground -prompt_playground = true - -# Authorize users to upload files with messages -multi_modal = true - -# Allows user to use speech to text -# speech_to_text = true - -[UI] -# Name of the app and chatbot. -name = "Chatbot" - -# Show the readme while the conversation is empty. -show_readme_as_default = true - -# Description of the app and chatbot. This is used for HTML tags. -# description = "" - -# Large size content are by default collapsed for a cleaner ui -default_collapse_content = true - -# The default value for the expand messages settings. -default_expand_messages = true - -# Hide the chain of thought details from the user in the UI. -hide_cot = false - -# Link to your github repo. This will add a github button in the UI's header. -# github = "" - -# Specify a CSS file that can be used to customize the user interface. -# The CSS file can be served from the public directory or via an external link. -# custom_css = "/public/test.css" - -# Override default MUI light theme. (Check theme.ts) -[UI.theme.light] - #background = "#FAFAFA" - #paper = "#FFFFFF" - - [UI.theme.light.primary] - #main = "#F80061" - #dark = "#980039" - #light = "#FFE7EB" - -# Override default MUI dark theme. (Check theme.ts) -[UI.theme.dark] - #background = "#FAFAFA" - #paper = "#FFFFFF" - - [UI.theme.dark.primary] - #main = "#F80061" - #dark = "#980039" - #light = "#FFE7EB" - - -[meta] -generated_by = "0.7.1" diff --git a/cypress/e2e/cot_mixed/main.py b/cypress/e2e/cot_mixed/main.py deleted file mode 100644 index 3db281acbd..0000000000 --- a/cypress/e2e/cot_mixed/main.py +++ /dev/null @@ -1,21 +0,0 @@ -import chainlit as cl - - -@cl.on_message -async def main(message: cl.Message): - tool1_msg = cl.Message( - content="Response from tool 1", author="Tool 1", parent_id=message.id - ) - await tool1_msg.send() - - await cl.Message( - content="Response from tool 2", - author="Tool 2", - parent_id=tool1_msg.id, - ).send() - - await cl.Message( - content="Response from tool 3", - author="Tool 3", - indent=3, - ).send() diff --git a/cypress/e2e/cot_mixed/spec.cy.ts b/cypress/e2e/cot_mixed/spec.cy.ts deleted file mode 100644 index 82a70ce3c7..0000000000 --- a/cypress/e2e/cot_mixed/spec.cy.ts +++ /dev/null @@ -1,19 +0,0 @@ -import { runTestServer, submitMessage } from '../../support/testUtils'; - -describe('Chain of Thought', () => { - before(() => { - runTestServer(); - }); - - it('should be able to display a nested CoT', () => { - submitMessage('Hello'); - - cy.get(".message:contains('Hello')").contains('Response from tool 1'); - cy.get(".message:contains('Response from tool 1')").contains( - 'Response from tool 2' - ); - cy.get(".message:contains('Response from tool 2')").contains( - 'Response from tool 3' - ); - }); -}); diff --git a/cypress/e2e/cot/.chainlit/config.toml b/cypress/e2e/data_layer/.chainlit/config.toml similarity index 100% rename from cypress/e2e/cot/.chainlit/config.toml rename to cypress/e2e/data_layer/.chainlit/config.toml diff --git a/cypress/e2e/data_layer/main.py b/cypress/e2e/data_layer/main.py new file mode 100644 index 0000000000..0beefc9931 --- /dev/null +++ b/cypress/e2e/data_layer/main.py @@ -0,0 +1,130 @@ +from datetime import datetime +from typing import List, Optional + +import chainlit.data as cl_data +from chainlit.step import StepDict + +import chainlit as cl + +now = datetime.utcnow().isoformat() + +create_step_counter = 0 + +user_dict = {"id": "test", "createdAt": now, "identifier": "admin"} + +thread_history = [ + { + "id": "test1", + "metadata": {"name": "thread 1"}, + "createdAt": now, + "user": user_dict, + "steps": [ + { + "id": "test1", + "name": "test", + "createdAt": now, + "type": "user_message", + "output": "Message 1", + }, + { + "id": "test2", + "name": "test", + "createdAt": now, + "type": "assistant_message", + "output": "Message 2", + }, + ], + }, + { + "id": "test2", + "createdAt": now, + "user": user_dict, + "metadata": {"name": "thread 2"}, + "steps": [ + { + "id": "test3", + "createdAt": now, + "name": "test", + "type": "user_message", + "output": "Message 3", + }, + { + "id": "test4", + "createdAt": now, + "name": "test", + "type": "assistant_message", + "output": "Message 4", + }, + ], + }, +] # type: List[cl_data.ThreadDict] +deleted_thread_ids = [] # type: List[str] + + +class TestDataLayer(cl_data.BaseDataLayer): + async def get_user(self, identifier: str): + return cl.PersistedUser(id="test", createdAt=now, identifier=identifier) + + async def create_user(self, user: cl.User): + return cl.PersistedUser(id="test", createdAt=now, identifier=user.identifier) + + @cl_data.queue_until_user_message() + async def create_step(self, step_dict: StepDict): + global create_step_counter + create_step_counter += 1 + + async def get_thread_author(self, thread_id: str): + return "admin" + + async def list_threads( + self, pagination: cl_data.Pagination, filter: cl_data.ThreadFilter + ) -> cl_data.PaginatedResponse[cl_data.ThreadDict]: + return cl_data.PaginatedResponse( + data=[t for t in thread_history if t["id"] not in deleted_thread_ids], + pageInfo=cl_data.PageInfo(hasNextPage=False, endCursor=None), + ) + + async def get_thread(self, thread_id: str): + return next((t for t in thread_history if t["id"] == thread_id), None) + + async def delete_thread(self, thread_id: str): + deleted_thread_ids.append(thread_id) + + +cl_data._data_layer = TestDataLayer() + + +async def send_count(): + await cl.Message( + f"Create step counter: {create_step_counter}", disable_feedback=True + ).send() + + +@cl.on_chat_start +async def main(): + await cl.Message("Hello, send me a message!", disable_feedback=True).send() + await send_count() + + +@cl.on_message +async def handle_message(): + # Wait for queue to be flushed + await cl.sleep(2) + await send_count() + async with cl.Step(root=True, disable_feedback=True) as step: + step.output = "Thinking..." + await cl.Message("Ok!").send() + await send_count() + + +@cl.password_auth_callback +def auth_callback(username: str, password: str) -> Optional[cl.User]: + if (username, password) == ("admin", "admin"): + return cl.User(identifier="admin") + else: + return None + + +@cl.on_chat_resume +async def on_chat_resume(thread: cl_data.ThreadDict): + await cl.Message(f"Welcome back to {thread['metadata']['name']}").send() diff --git a/cypress/e2e/data_layer/spec.cy.ts b/cypress/e2e/data_layer/spec.cy.ts new file mode 100644 index 0000000000..640f1f7286 --- /dev/null +++ b/cypress/e2e/data_layer/spec.cy.ts @@ -0,0 +1,66 @@ +import { runTestServer, submitMessage } from '../../support/testUtils'; + +function login() { + cy.get("[id='email']").type('admin'); + cy.get("[id='password']").type('admin{enter}'); +} + +function feedback() { + submitMessage('Hello'); + cy.get('.negative-feedback-off').should('have.length', 1); + cy.get('.positive-feedback-off').should('have.length', 1).click(); + cy.get('#feedbackSubmit').click(); + cy.get('.positive-feedback-on').should('have.length', 1); +} + +function threadQueue() { + cy.get('.step').eq(1).should('contain', 'Create step counter: 0'); + cy.get('.step').eq(3).should('contain', 'Create step counter: 3'); + cy.get('.step').eq(6).should('contain', 'Create step counter: 6'); +} + +function threadList() { + cy.get('#thread-test1').should('contain', 'Thread 1'); + cy.get('#thread-test2').should('contain', 'Thread 2'); + + // Test thread page + cy.get('#thread-test1').click(); + cy.get('#thread-info').should('exist'); + cy.get('.step').should('have.length', 2); + cy.get('.step').eq(0).should('contain', 'Message 1'); + cy.get('.step').eq(1).should('contain', 'Message 2'); + + // Test thread delete + cy.get('#thread-test1').find("[data-testid='DeleteOutlineIcon']").click(); + cy.get("[type='button']").contains('Confirm').click(); + cy.get('#thread-test1').should('not.exist'); +} + +function resumeThread() { + cy.get('#thread-test2').click(); + cy.get(`#chat-input`).should('not.exist'); + cy.get('#resumeThread').click(); + cy.get(`#chat-input`).should('exist'); + + cy.get('.step').should('have.length', 3); + + cy.get('.step').eq(0).should('contain', 'Message 3'); + cy.get('.step').eq(1).should('contain', 'Message 4'); + cy.get('.step').eq(2).should('contain', 'Welcome back to thread 2'); +} + +describe('Data Layer', () => { + before(() => { + runTestServer(); + }); + + describe('Data Features with persistence', () => { + it('should login, submit feedback, wait for user input to create steps, browse thread history, delete a thread and then resume a thread', () => { + login(); + feedback(); + threadQueue(); + threadList(); + resumeThread(); + }); + }); +}); diff --git a/cypress/e2e/default_expand_cot/main.py b/cypress/e2e/default_expand_cot/main.py index 22cf103cc3..83f6d578b2 100644 --- a/cypress/e2e/default_expand_cot/main.py +++ b/cypress/e2e/default_expand_cot/main.py @@ -1,21 +1,26 @@ import chainlit as cl -@cl.on_message -async def main(message: cl.Message): - tool1_msg = cl.Message( - content="I need to use tool 2", author="Tool 1", parent_id=message.id - ) - await tool1_msg.send() +@cl.step(name="Tool 3", type="tool") +async def tool_3(): + return "Response from tool 3" - tool2_msg = cl.Message( - content="Response from tool 2", author="Tool 2", parent_id=tool1_msg.id - ) - await tool2_msg.send() - await cl.Message( - content="Response from tool 3", author="Tool 3", parent_id=message.id - ).send() +@cl.step(name="Tool 2", type="tool") +async def tool_2(): + await tool_3() + return "Response from tool 2" + + +@cl.step(name="Tool 1", type="tool") +async def tool_1(): + await tool_2() + return "Response from tool 1" + + +@cl.on_message +async def main(message: cl.Message): + await tool_1() await cl.Message( content="Final response", diff --git a/cypress/e2e/default_expand_cot/spec.cy.ts b/cypress/e2e/default_expand_cot/spec.cy.ts index b3a5e11bfd..1464d441b9 100644 --- a/cypress/e2e/default_expand_cot/spec.cy.ts +++ b/cypress/e2e/default_expand_cot/spec.cy.ts @@ -8,10 +8,11 @@ describe('Default Expand', () => { it('should be able to set the default_expand_messages field in the config to have the CoT expanded by default', () => { submitMessage('Hello'); - cy.get(".message:contains('Hello')").contains('I need to use tool 2'); - cy.get(".message:contains('I need to use tool 2')").contains( + cy.get(".step:contains('Hello')").contains('Response from tool 1'); + cy.get(".step:contains('Response from tool 1')").contains( 'Response from tool 2' ); - cy.get(".message:contains('Hello')").contains('Response from tool 3'); + cy.get(".step:contains('Hello')").contains('Response from tool 3'); + cy.get(".step:contains('Final response')"); }); }); diff --git a/cypress/e2e/global_elements/.chainlit/config.toml b/cypress/e2e/elements/.chainlit/config.toml similarity index 100% rename from cypress/e2e/global_elements/.chainlit/config.toml rename to cypress/e2e/elements/.chainlit/config.toml diff --git a/cypress/e2e/scoped_elements/cat.jpeg b/cypress/e2e/elements/cat.jpeg similarity index 100% rename from cypress/e2e/scoped_elements/cat.jpeg rename to cypress/e2e/elements/cat.jpeg diff --git a/cypress/e2e/scoped_elements/dummy.pdf b/cypress/e2e/elements/dummy.pdf similarity index 100% rename from cypress/e2e/scoped_elements/dummy.pdf rename to cypress/e2e/elements/dummy.pdf diff --git a/cypress/e2e/elements/main.py b/cypress/e2e/elements/main.py new file mode 100644 index 0000000000..aab0cd2cca --- /dev/null +++ b/cypress/e2e/elements/main.py @@ -0,0 +1,52 @@ +from chainlit.context import context + +import chainlit as cl + + +@cl.step +async def gen_img(): + if current_step := context.current_step: + current_step.elements = [ + cl.Image(path="./cat.jpeg", name="image1", display="inline") + ] + return "Here is a cat!" + + +@cl.on_chat_start +async def start(): + # Element should not be inlined or referenced + await cl.Message( + content="Here is image1, a nice image of a cat! As well as text1 and text2!", + ).send() + + # Step should be able to have elements + await gen_img() + + # Image should be inlined even if not referenced + await cl.Message( + content="Here a nice image of a cat! As well as text1 and text2!", + elements=[ + cl.Image(path="./cat.jpeg", name="image1", display="inline"), + cl.Pdf(path="./dummy.pdf", name="pdf1", display="inline"), + cl.Text( + content="Here is a side text document", name="text1", display="side" + ), + cl.Text( + content="Here is a page text document", name="text2", display="page" + ), + ], + ).send() + # Element references should work even if element names collide + await cl.Message( + content="Here a nice image of a cat! As well as text1 and text2!", + elements=[ + cl.Image(path="./cat.jpeg", name="image1", display="inline"), + cl.Pdf(path="./dummy.pdf", name="pdf1", display="inline"), + cl.Text( + content="Here is a side text document", name="text1", display="side" + ), + cl.Text( + content="Here is a page text document", name="text2", display="page" + ), + ], + ).send() diff --git a/cypress/e2e/elements/spec.cy.ts b/cypress/e2e/elements/spec.cy.ts new file mode 100644 index 0000000000..9e4ede1c3c --- /dev/null +++ b/cypress/e2e/elements/spec.cy.ts @@ -0,0 +1,52 @@ +import { runTestServer } from '../../support/testUtils'; + +describe('Elements', () => { + before(() => { + runTestServer(); + }); + + it('should be able to display inlined, side and page elements', () => { + cy.get('.step').eq(0).find('.inline-image').should('have.length', 0); + cy.get('.step').eq(0).find('.element-link').should('have.length', 0); + cy.get('.step').eq(0).find('.inline-pdf').should('have.length', 0); + + cy.get('#gen_img-done').should('exist').click(); + cy.get('.step').eq(1).find('.inline-image').should('have.length', 1); + + cy.get('.step').eq(2).find('.inline-image').should('have.length', 1); + cy.get('.step').eq(2).find('.element-link').should('have.length', 2); + cy.get('.step').eq(2).find('.inline-pdf').should('have.length', 1); + + cy.get('.step').eq(3).find('.inline-image').should('have.length', 1); + cy.get('.step').eq(3).find('.element-link').should('have.length', 2); + cy.get('.step').eq(3).find('.inline-pdf').should('have.length', 1); + + // Side + cy.get('.step') + .eq(2) + .find('.element-link') + .eq(0) + .should('contain', 'text1') + .click(); + const sideViewTitle = cy.get('#side-view-title'); + sideViewTitle.should('exist'); + sideViewTitle.should('contain', 'text1'); + + const sideViewContent = cy.get('#side-view-content'); + sideViewContent.should('exist'); + sideViewContent.should('contain', 'Here is a side text document'); + + // Page + cy.get('.step') + .eq(2) + .find('.element-link') + .eq(1) + .should('contain', 'text2') + .click(); + + const view = cy.get('#element-view'); + view.should('exist'); + view.should('contain', 'text2'); + view.should('contain', 'Here is a page text document'); + }); +}); diff --git a/cypress/e2e/error_handling/spec.cy.ts b/cypress/e2e/error_handling/spec.cy.ts index 2deddc41ca..71897a9a26 100644 --- a/cypress/e2e/error_handling/spec.cy.ts +++ b/cypress/e2e/error_handling/spec.cy.ts @@ -1,14 +1,14 @@ -import { runTestServer } from "../../support/testUtils"; +import { runTestServer } from '../../support/testUtils'; -describe("Error Handling", () => { +describe('Error Handling', () => { before(() => { runTestServer(); }); - it("should correctly display errors", () => { - cy.get(".message") - .should("have.length", 1) + it('should correctly display errors', () => { + cy.get('.step') + .should('have.length', 1) .eq(0) - .should("contain", "This is an error message"); + .should('contain', 'This is an error message'); }); }); diff --git a/cypress/e2e/file_element/spec.cy.ts b/cypress/e2e/file_element/spec.cy.ts index 50f74440c7..0e4fda3aa1 100644 --- a/cypress/e2e/file_element/spec.cy.ts +++ b/cypress/e2e/file_element/spec.cy.ts @@ -6,8 +6,8 @@ describe('file', () => { }); it('should be able to display a file element', () => { - cy.get('.message').should('have.length', 1); - cy.get('.message').eq(0).find('.inline-file').should('have.length', 4); + cy.get('.step').should('have.length', 1); + cy.get('.step').eq(0).find('.inline-file').should('have.length', 4); cy.get('.inline-file').should(($files) => { const downloads = $files diff --git a/cypress/e2e/global_elements/main.py b/cypress/e2e/global_elements/main.py deleted file mode 100644 index 2c02d5e3cd..0000000000 --- a/cypress/e2e/global_elements/main.py +++ /dev/null @@ -1,26 +0,0 @@ -import asyncio - -import chainlit as cl - - -@cl.on_chat_start -async def start(): - # Send elements to the UI concurrently - elements = [ - cl.Image(path="./cat.jpeg", name="image1", display="inline").send(), - cl.Text( - content="Here is a side text document", name="text1", display="side" - ).send(), - cl.Text( - content="Here is a page text document", name="text2", display="page" - ).send(), - ] - await asyncio.gather(*elements) - - await cl.Message( - content="Here is image1, a nice image of a cat! As well as text1 and text2!", - ).send() - - await cl.Message( - content="Here is a message without element reference!", - ).send() diff --git a/cypress/e2e/global_elements/spec.cy.ts b/cypress/e2e/global_elements/spec.cy.ts deleted file mode 100644 index 462a9b92a0..0000000000 --- a/cypress/e2e/global_elements/spec.cy.ts +++ /dev/null @@ -1,46 +0,0 @@ -import { runTestServer } from "../../support/testUtils"; - -describe("Global Elements", () => { - before(() => { - runTestServer(); - }); - - it("should be able to display inlined, side and page elements", () => { - cy.get(".message").should("have.length", 2); - - cy.get(".message").eq(1).find(".inline-image").should("have.length", 0); - cy.get(".message").eq(1).find(".element-link").should("have.length", 0); - - // Inlined - cy.get(".message").eq(0).find(".inline-image").should("have.length", 1); - cy.get(".message").eq(0).find(".element-link").should("have.length", 2); - cy.get(".message") - .eq(0) - .find(".element-link") - .eq(0) - .should("contain", "text1"); - cy.get(".message").eq(0).find(".element-link").eq(0).click(); - - // Side - const sideViewTitle = cy.get("#side-view-title"); - sideViewTitle.should("exist"); - sideViewTitle.should("contain", "text1"); - - const sideViewContent = cy.get("#side-view-content"); - sideViewContent.should("exist"); - sideViewContent.should("contain", "Here is a side text document"); - - // Page - cy.get(".message") - .eq(0) - .find(".element-link") - .eq(1) - .should("contain", "text2") - .click(); - - const view = cy.get("#element-view"); - view.should("exist"); - view.should("contain", "text2"); - view.should("contain", "Here is a page text document"); - }); -}); diff --git a/cypress/e2e/header_auth/main.py b/cypress/e2e/header_auth/main.py index 159bdd6d45..2693236a78 100644 --- a/cypress/e2e/header_auth/main.py +++ b/cypress/e2e/header_auth/main.py @@ -4,14 +4,14 @@ @cl.header_auth_callback -def header_auth_callback(headers) -> Optional[cl.AppUser]: +def header_auth_callback(headers) -> Optional[cl.User]: if headers.get("test-header"): - return cl.AppUser(username="admin", role="ADMIN", provider="header") + return cl.User(identifier="admin") else: return None @cl.on_chat_start async def on_chat_start(): - app_user = cl.user_session.get("user") - await cl.Message(f"Hello {app_user.username}").send() + user = cl.user_session.get("user") + await cl.Message(f"Hello {user.identifier}").send() diff --git a/cypress/e2e/header_auth/spec.cy.ts b/cypress/e2e/header_auth/spec.cy.ts index 8d7384f336..097d8be123 100644 --- a/cypress/e2e/header_auth/spec.cy.ts +++ b/cypress/e2e/header_auth/spec.cy.ts @@ -15,9 +15,9 @@ describe('Header auth', () => { }); cy.visit('/'); cy.get('.MuiAlert-message').should('not.exist'); - cy.get('.message').eq(0).should('contain', 'Hello admin'); + cy.get('.step').eq(0).should('contain', 'Hello admin'); cy.reload(); - cy.get('.message').eq(0).should('contain', 'Hello admin'); + cy.get('.step').eq(0).should('contain', 'Hello admin'); }); }); diff --git a/cypress/e2e/hide_prompt_playground/.chainlit/config.toml b/cypress/e2e/hide_prompt_playground/.chainlit/config.toml index 0c509af72c..c1b3ecf175 100644 --- a/cypress/e2e/hide_prompt_playground/.chainlit/config.toml +++ b/cypress/e2e/hide_prompt_playground/.chainlit/config.toml @@ -16,7 +16,7 @@ cache = false [features] # Show the prompt playground -prompt_playground = true +prompt_playground = false [UI] # Name of the app and chatbot. diff --git a/cypress/e2e/hide_prompt_playground/main.py b/cypress/e2e/hide_prompt_playground/main.py index 42ac299d7b..1bb9128f12 100644 --- a/cypress/e2e/hide_prompt_playground/main.py +++ b/cypress/e2e/hide_prompt_playground/main.py @@ -1,5 +1,3 @@ -from chainlit.prompt import Prompt - import chainlit as cl template = """Hello, this is a template. @@ -20,12 +18,19 @@ completion = "This is the original completion" +@cl.step(type="llm") +async def gen_response(): + res = "This is a message with a basic prompt" + if current_step := cl.context.current_step: + current_step.generation = cl.CompletionGeneration( + template=template, inputs=inputs, completion=res + ) + return res + + @cl.on_chat_start async def start(): + content = await gen_response() await cl.Message( - content="This is a message with a basic prompt", - prompt=Prompt( - template=template, - inputs=inputs, - ), + content=content, ).send() diff --git a/cypress/e2e/hide_prompt_playground/spec.cy.ts b/cypress/e2e/hide_prompt_playground/spec.cy.ts index 11b92a4e55..39534c273c 100644 --- a/cypress/e2e/hide_prompt_playground/spec.cy.ts +++ b/cypress/e2e/hide_prompt_playground/spec.cy.ts @@ -1,13 +1,12 @@ import { runTestServer } from '../../support/testUtils'; -describe('HidePromptPlayground', () => { +describe('DisablePromptPlayground', () => { before(() => { runTestServer(); }); - describe('Basic template', () => { - it('should not display the playground button', () => { - cy.get('.playground-button').should('not.exist'); - }); + it('should not display the playground button', () => { + cy.wait(2000); + cy.get('.playground-button').should('not.exist'); }); }); diff --git a/cypress/e2e/message_history/.chainlit/config.toml b/cypress/e2e/input_history/.chainlit/config.toml similarity index 100% rename from cypress/e2e/message_history/.chainlit/config.toml rename to cypress/e2e/input_history/.chainlit/config.toml diff --git a/cypress/e2e/message_history/main.py b/cypress/e2e/input_history/main.py similarity index 100% rename from cypress/e2e/message_history/main.py rename to cypress/e2e/input_history/main.py diff --git a/cypress/e2e/input_history/spec.cy.ts b/cypress/e2e/input_history/spec.cy.ts new file mode 100644 index 0000000000..e53461916a --- /dev/null +++ b/cypress/e2e/input_history/spec.cy.ts @@ -0,0 +1,34 @@ +import { + closeHistory, + openHistory, + runTestServer, + submitMessage +} from '../../support/testUtils'; + +describe('Input History', () => { + before(() => { + runTestServer(); + }); + + it('should be able to show the last message in the message history', () => { + openHistory(); + + cy.get('.history-item').should('have.length', 0); + cy.get('#history-empty').should('exist'); + + closeHistory(); + + const timestamp = Date.now().toString(); + + submitMessage(timestamp); + + openHistory(); + + cy.get('#history-empty').should('not.exist'); + cy.get('.history-item').should('have.length', 1); + cy.get('.history-item').eq(0).should('contain', timestamp).click(); + cy.get('.history-item').should('have.length', 0); + + cy.get('#chat-input').should('have.value', timestamp); + }); +}); diff --git a/cypress/e2e/llama_index_cb/main.py b/cypress/e2e/llama_index_cb/main.py index fd6e8bad56..a4411a9e79 100644 --- a/cypress/e2e/llama_index_cb/main.py +++ b/cypress/e2e/llama_index_cb/main.py @@ -13,6 +13,8 @@ async def start(): cb.on_event_start(CBEventType.RETRIEVE, payload={}) + await cl.sleep(0.2) + cb.on_event_end( CBEventType.RETRIEVE, payload={ @@ -24,6 +26,8 @@ async def start(): cb.on_event_start(CBEventType.LLM) + await cl.sleep(0.2) + response = ChatResponse(message=ChatMessage(content="This is the LLM response")) cb.on_event_end( CBEventType.LLM, diff --git a/cypress/e2e/llama_index_cb/spec.cy.ts b/cypress/e2e/llama_index_cb/spec.cy.ts index 1c9dbc80f0..408eafba71 100644 --- a/cypress/e2e/llama_index_cb/spec.cy.ts +++ b/cypress/e2e/llama_index_cb/spec.cy.ts @@ -6,13 +6,13 @@ describe('Llama Index Callback', () => { }); it('should be able to send messages to the UI with prompts and elements', () => { - cy.get('.message').should('have.length', 1); + cy.get('.step').should('have.length', 1); cy.get('#llm-done').should('exist').click(); - cy.get('.message').should('have.length', 3); + cy.get('.step').should('have.length', 3); - cy.get('.message') + cy.get('.step') .eq(1) .find('.element-link') .eq(0) diff --git a/cypress/e2e/message_history/spec.cy.ts b/cypress/e2e/message_history/spec.cy.ts deleted file mode 100644 index 6bf433fb3a..0000000000 --- a/cypress/e2e/message_history/spec.cy.ts +++ /dev/null @@ -1,34 +0,0 @@ -import { - closeHistory, - openHistory, - runTestServer, - submitMessage, -} from "../../support/testUtils"; - -describe("Message History", () => { - before(() => { - runTestServer(); - }); - - it("should be able to show the last message in the message history", () => { - openHistory(); - - cy.get(".history-item").should("have.length", 0); - cy.get("#history-empty").should("exist"); - - closeHistory(); - - const timestamp = Date.now().toString(); - - submitMessage(timestamp); - - openHistory(); - - cy.get("#history-empty").should("not.exist"); - cy.get(".history-item").should("have.length", 1); - cy.get(".history-item").eq(0).should("contain", timestamp).click(); - cy.get(".history-item").should("have.length", 0); - - cy.get("#chat-input").should("have.value", timestamp); - }); -}); diff --git a/cypress/e2e/on_chat_start/spec.cy.ts b/cypress/e2e/on_chat_start/spec.cy.ts index 44840c8a9d..fa13076d0f 100644 --- a/cypress/e2e/on_chat_start/spec.cy.ts +++ b/cypress/e2e/on_chat_start/spec.cy.ts @@ -6,7 +6,7 @@ describe('on_chat_start', () => { }); it('should correctly run on_chat_start', () => { - const messages = cy.get('.message'); + const messages = cy.get('.step'); messages.should('have.length', 1); messages.eq(0).should('contain.text', 'Hello!'); diff --git a/cypress/e2e/password_auth/main.py b/cypress/e2e/password_auth/main.py index 61497ca3ce..025f463e90 100644 --- a/cypress/e2e/password_auth/main.py +++ b/cypress/e2e/password_auth/main.py @@ -4,14 +4,14 @@ @cl.password_auth_callback -def auth_callback(username: str, password: str) -> Optional[cl.AppUser]: +def auth_callback(username: str, password: str) -> Optional[cl.User]: if (username, password) == ("admin", "admin"): - return cl.AppUser(username="admin", role="ADMIN", provider="credentials") + return cl.User(identifier="admin") else: return None @cl.on_chat_start async def on_chat_start(): - app_user = cl.user_session.get("user") - await cl.Message(f"Hello {app_user.username}").send() + user = cl.user_session.get("user") + await cl.Message(f"Hello {user.identifier}").send() diff --git a/cypress/e2e/password_auth/spec.cy.ts b/cypress/e2e/password_auth/spec.cy.ts index fb6a3d541f..ebe7447b1b 100644 --- a/cypress/e2e/password_auth/spec.cy.ts +++ b/cypress/e2e/password_auth/spec.cy.ts @@ -18,10 +18,10 @@ describe('Password Auth', () => { cy.get("input[name='password']").type('admin'); cy.get("button[type='submit']").click(); cy.get('.MuiAlert-message').should('not.exist'); - cy.get('.message').eq(0).should('contain', 'Hello admin'); + cy.get('.step').eq(0).should('contain', 'Hello admin'); cy.reload(); cy.get("input[name='email']").should('not.exist'); - cy.get('.message').eq(0).should('contain', 'Hello admin'); + cy.get('.step').eq(0).should('contain', 'Hello admin'); }); }); diff --git a/cypress/e2e/plotly/spec.cy.ts b/cypress/e2e/plotly/spec.cy.ts index 10dd6c7ebe..c4fad29dfa 100644 --- a/cypress/e2e/plotly/spec.cy.ts +++ b/cypress/e2e/plotly/spec.cy.ts @@ -6,7 +6,7 @@ describe('plotly', () => { }); it('should be able to display an inline chart', () => { - cy.get('.message').should('have.length', 1); - cy.get('.message').eq(0).find('.inline-plotly').should('have.length', 1); + cy.get('.step').should('have.length', 1); + cy.get('.step').eq(0).find('.inline-plotly').should('have.length', 1); }); }); diff --git a/cypress/e2e/prompt_playground/main.py b/cypress/e2e/prompt_playground/main.py index 4b2a39f17b..40ff5a7b7a 100644 --- a/cypress/e2e/prompt_playground/main.py +++ b/cypress/e2e/prompt_playground/main.py @@ -1,7 +1,6 @@ from provider import ChatTestLLM, TestLLM import chainlit as cl -from chainlit.prompt import Prompt, PromptMessage template = """Hello, this is a template. This is a simple variable {variable1} @@ -25,45 +24,38 @@ @cl.on_chat_start async def start(): - await cl.Message( - content="This is a message with a basic prompt", - prompt=Prompt( - provider=TestLLM.id, - completion=completion, - template=template, - inputs=inputs, - ), - ).send() - - await cl.Message( - content="This is a message with only a formatted basic prompt", - prompt=Prompt(provider=TestLLM.id, completion=completion, formatted=formatted), - ).send() - - await cl.Message( - content="This is a message with a chat prompt", - prompt=Prompt( + async with cl.Step() as step: + step.generation = cl.CompletionGeneration( + provider=TestLLM.id, template=template, inputs=inputs, completion=completion + ) + step.output = "This is a message with a basic prompt" + + async with cl.Step() as step: + step.generation = cl.CompletionGeneration( + provider=TestLLM.id, completion=completion, formatted=formatted + ) + step.output = "This is a message with only a formatted basic prompt" + + async with cl.Step() as step: + step.generation = cl.ChatGeneration( provider=ChatTestLLM.id, completion=completion, - template=template, inputs=inputs, messages=[ - PromptMessage(template=template, role="system"), - PromptMessage(template=template, role="system"), + cl.GenerationMessage(template=template, role="system"), + cl.GenerationMessage(template=template, role="system"), ], - ), - ).send() + ) + step.output = "This is a message with a chat prompt" - await cl.Message( - content="This is a message with only a formatted chat prompt", - prompt=Prompt( + async with cl.Step() as step: + step.generation = cl.ChatGeneration( provider=ChatTestLLM.id, completion=completion, - template=template, inputs=inputs, messages=[ - PromptMessage(formatted=formatted, role="system"), - PromptMessage(formatted=formatted, role="system"), + cl.GenerationMessage(formatted=formatted, role="system"), + cl.GenerationMessage(formatted=formatted, role="system"), ], - ), - ).send() + ) + step.output = "This is a message with only a formatted chat prompt" diff --git a/cypress/e2e/prompt_playground/provider.py b/cypress/e2e/prompt_playground/provider.py index be9e367bc8..ff5aeb072b 100644 --- a/cypress/e2e/prompt_playground/provider.py +++ b/cypress/e2e/prompt_playground/provider.py @@ -1,12 +1,12 @@ import os +from chainlit.input_widget import Select, Slider +from chainlit.playground.config import BaseProvider, add_llm_provider +from chainlit.playground.providers.langchain import LangchainGenericProvider from fastapi.responses import StreamingResponse from langchain.llms.fake import FakeListLLM import chainlit as cl -from chainlit.input_widget import Select, Slider -from chainlit.playground.config import BaseProvider, add_llm_provider -from chainlit.playground.providers.langchain import LangchainGenericProvider os.environ["TEST_LLM_API_KEY"] = "sk..." @@ -15,8 +15,8 @@ class TestLLMProvider(BaseProvider): async def create_completion(self, request): await super().create_completion(request) - self.create_prompt(request) - self.require_settings(request.prompt.settings) + self.create_generation(request) + self.require_settings(request.generation.settings) stream = ["This ", "is ", "the ", "test ", "completion"] diff --git a/cypress/e2e/pyplot/spec.cy.ts b/cypress/e2e/pyplot/spec.cy.ts index 9b5f1ab609..22447cfb4a 100644 --- a/cypress/e2e/pyplot/spec.cy.ts +++ b/cypress/e2e/pyplot/spec.cy.ts @@ -6,7 +6,7 @@ describe('pyplot', () => { }); it('should be able to display an inline chart', () => { - cy.get('.message').should('have.length', 1); - cy.get('.message').eq(0).find('.inline-image').should('have.length', 1); + cy.get('.step').should('have.length', 1); + cy.get('.step').eq(0).find('.inline-image').should('have.length', 1); }); }); diff --git a/cypress/e2e/remove_elements/main.py b/cypress/e2e/remove_elements/main.py index c4081e088e..2ef09eb115 100644 --- a/cypress/e2e/remove_elements/main.py +++ b/cypress/e2e/remove_elements/main.py @@ -3,13 +3,26 @@ @cl.on_chat_start async def start(): - image = cl.Image(name="image1", display="inline", path="../../fixtures/cat.jpeg") + step_image = cl.Image( + name="image1", display="inline", path="../../fixtures/cat.jpeg" + ) + msg_image = cl.Image( + name="image1", display="inline", path="../../fixtures/cat.jpeg" + ) + + async with cl.Step() as step: + step.elements = [ + step_image, + cl.Image(name="image2", display="inline", path="../../fixtures/cat.jpeg"), + ] + step.output = "This step has an image" await cl.Message( content="This message has an image", elements=[ - image, + msg_image, cl.Image(name="image2", display="inline", path="../../fixtures/cat.jpeg"), ], ).send() - await image.remove() + await msg_image.remove() + await step_image.remove() diff --git a/cypress/e2e/remove_elements/spec.cy.ts b/cypress/e2e/remove_elements/spec.cy.ts index b0875199fa..8947478f7a 100644 --- a/cypress/e2e/remove_elements/spec.cy.ts +++ b/cypress/e2e/remove_elements/spec.cy.ts @@ -6,7 +6,8 @@ describe('remove_elements', () => { }); it('should be able to remove elements', () => { - cy.get('.message').should('have.length', 1); - cy.get('.inline-image').should('have.length', 1); + cy.get('.step').should('have.length', 2); + cy.get('.step').eq(0).find('.inline-image').should('have.length', 1); + cy.get('.step').eq(1).find('.inline-image').should('have.length', 1); }); }); diff --git a/cypress/e2e/remove_message/spec.cy.ts b/cypress/e2e/remove_message/spec.cy.ts deleted file mode 100644 index 9630e3486e..0000000000 --- a/cypress/e2e/remove_message/spec.cy.ts +++ /dev/null @@ -1,31 +0,0 @@ -import { runTestServer, submitMessage } from '../../support/testUtils'; - -describe('Delete Message', () => { - before(() => { - runTestServer(); - }); - - it('should be able to delete a message', () => { - cy.get('.message').should('have.length', 1); - cy.get('.message').eq(0).should('contain', 'Message 1'); - - cy.get('#chatbot-loading').should('exist'); - cy.get('#chatbot-loading').click(); - cy.get('.message').eq(1).should('contain', 'Child 1'); - - cy.get('.message').should('have.length', 2); - - cy.get('.message').eq(1).should('contain', 'Message 2'); - cy.get('.message').should('have.length', 1); - cy.get('.message').eq(0).should('contain', 'Message 2'); - cy.get('.message').should('have.length', 0); - - cy.get('.message').should('have.length', 1); - cy.get('.message').eq(0).should('contain', 'Message 3'); - - submitMessage('foo'); - - cy.get('.message').should('have.length', 1); - cy.get('.message').eq(0).should('contain', 'foo'); - }); -}); diff --git a/cypress/e2e/remove_message/.chainlit/config.toml b/cypress/e2e/remove_step/.chainlit/config.toml similarity index 100% rename from cypress/e2e/remove_message/.chainlit/config.toml rename to cypress/e2e/remove_step/.chainlit/config.toml diff --git a/cypress/e2e/remove_message/main.py b/cypress/e2e/remove_step/main.py similarity index 78% rename from cypress/e2e/remove_message/main.py rename to cypress/e2e/remove_step/main.py index 35d1a3cada..00ff44d12d 100644 --- a/cypress/e2e/remove_message/main.py +++ b/cypress/e2e/remove_step/main.py @@ -6,10 +6,11 @@ async def main(): msg1 = cl.Message(content="Message 1") await msg1.send() - msg1_child1 = cl.Message(content="Child 1", parent_id=msg1.id) - await msg1_child1.send() + async with cl.Step() as child1: + child1.output = "Child 1" + await cl.sleep(1) - await msg1_child1.remove() + await child1.remove() msg2 = cl.Message(content="Message 2") await msg2.send() diff --git a/cypress/e2e/remove_step/spec.cy.ts b/cypress/e2e/remove_step/spec.cy.ts new file mode 100644 index 0000000000..3e0b16ef07 --- /dev/null +++ b/cypress/e2e/remove_step/spec.cy.ts @@ -0,0 +1,31 @@ +import { runTestServer, submitMessage } from '../../support/testUtils'; + +describe('Remove Step', () => { + before(() => { + runTestServer(); + }); + + it('should be able to remove a step', () => { + cy.get('.step').should('have.length', 1); + cy.get('.step').eq(0).should('contain', 'Message 1'); + + cy.get('#chatbot-loading').should('exist'); + cy.get('#chatbot-loading').click(); + cy.get('.step').eq(1).should('contain', 'Child 1'); + + cy.get('.step').should('have.length', 2); + + cy.get('.step').eq(1).should('contain', 'Message 2'); + cy.get('.step').should('have.length', 1); + cy.get('.step').eq(0).should('contain', 'Message 2'); + cy.get('.step').should('have.length', 0); + + cy.get('.step').should('have.length', 1); + cy.get('.step').eq(0).should('contain', 'Message 3'); + + submitMessage('foo'); + + cy.get('.step').should('have.length', 1); + cy.get('.step').eq(0).should('contain', 'foo'); + }); +}); diff --git a/cypress/e2e/scoped_elements/main.py b/cypress/e2e/scoped_elements/main.py deleted file mode 100644 index b34afdf37d..0000000000 --- a/cypress/e2e/scoped_elements/main.py +++ /dev/null @@ -1,26 +0,0 @@ -import chainlit as cl - - -@cl.on_chat_start -async def start(): - elements = [ - cl.Image(path="./cat.jpeg", name="image1", display="inline"), - cl.Pdf(path="./dummy.pdf", name="pdf1", display="inline"), - cl.Text(content="Here is a side text document", name="text1", display="side"), - cl.Text(content="Here is a page text document", name="text2", display="page"), - ] - - # Element should not be inlined or referenced - await cl.Message( - content="Here is image1, a nice image of a cat! As well as text1 and text2!", - ).send() - # Image should be inlined even if not referenced - await cl.Message( - content="Here a nice image of a cat! As well as text1 and text2!", - elements=elements, - ).send() - # Element references should work even if element names collide - await cl.Message( - content="Here a nice image of a cat! As well as text1 and text2!", - elements=elements, - ).send() diff --git a/cypress/e2e/scoped_elements/spec.cy.ts b/cypress/e2e/scoped_elements/spec.cy.ts deleted file mode 100644 index 9e10affa14..0000000000 --- a/cypress/e2e/scoped_elements/spec.cy.ts +++ /dev/null @@ -1,23 +0,0 @@ -import { runTestServer } from "../../support/testUtils"; - -describe("Scoped Elements", () => { - before(() => { - runTestServer(); - }); - - it("should be able to display inlined, side and page elements", () => { - cy.get(".message").should("have.length", 3); - - cy.get(".message").eq(0).find(".inline-image").should("have.length", 0); - cy.get(".message").eq(0).find(".element-link").should("have.length", 0); - cy.get(".message").eq(0).find(".inline-pdf").should("have.length", 0); - - cy.get(".message").eq(1).find(".inline-image").should("have.length", 1); - cy.get(".message").eq(1).find(".element-link").should("have.length", 2); - cy.get(".message").eq(1).find(".inline-pdf").should("have.length", 1); - - cy.get(".message").eq(2).find(".inline-image").should("have.length", 1); - cy.get(".message").eq(2).find(".element-link").should("have.length", 2); - cy.get(".message").eq(2).find(".inline-pdf").should("have.length", 1); - }); -}); diff --git a/cypress/e2e/sdk_availability/spec.cy.ts b/cypress/e2e/sdk_availability/spec.cy.ts deleted file mode 100644 index 393fc5318a..0000000000 --- a/cypress/e2e/sdk_availability/spec.cy.ts +++ /dev/null @@ -1,21 +0,0 @@ -import { runTestServer } from "../../support/testUtils"; - -describe("Emitter should be reachable from all contexts", () => { - before(() => { - runTestServer(); - }); - - it("should find the Emitter from async, make_async and async_from_sync contexts", () => { - cy.get(".message").should("have.length", 3); - - cy.get(".message").eq(0).should("contain", "emitter from async found!"); - - cy.get(".message") - .eq(1) - .should("contain", "emitter from make_async found!"); - - cy.get(".message") - .eq(2) - .should("contain", "emitter from async_from_sync found!"); - }); -}); diff --git a/cypress/e2e/scoped_elements/.chainlit/config.toml b/cypress/e2e/step/.chainlit/config.toml similarity index 100% rename from cypress/e2e/scoped_elements/.chainlit/config.toml rename to cypress/e2e/step/.chainlit/config.toml diff --git a/cypress/e2e/step/main.py b/cypress/e2e/step/main.py new file mode 100644 index 0000000000..543ad1ad7b --- /dev/null +++ b/cypress/e2e/step/main.py @@ -0,0 +1,25 @@ +import chainlit as cl + + +def tool_3(): + with cl.Step(name="Tool 3", type="tool") as s: + cl.run_sync(cl.sleep(2)) + s.output = "Response from tool 3" + + +@cl.step +def tool_2(): + tool_3() + cl.run_sync(cl.Message(content="Message from tool 2").send()) + return "Response from tool 2" + + +@cl.step(name="Tool 1", type="tool") +def tool_1(): + tool_2() + return "Response from tool 1" + + +@cl.on_message +async def main(message: cl.Message): + tool_1() diff --git a/cypress/e2e/step/main_async.py b/cypress/e2e/step/main_async.py new file mode 100644 index 0000000000..bb8218cc12 --- /dev/null +++ b/cypress/e2e/step/main_async.py @@ -0,0 +1,25 @@ +import chainlit as cl + + +async def tool_3(): + async with cl.Step(name="Tool 3", type="tool") as s: + await cl.sleep(2) + s.output = "Response from tool 3" + + +@cl.step +async def tool_2(): + await tool_3() + await cl.Message(content="Message from tool 2").send() + return "Response from tool 2" + + +@cl.step(name="Tool 1", type="tool") +async def tool_1(): + await tool_2() + return "Response from tool 1" + + +@cl.on_message +async def main(message: cl.Message): + await tool_1() diff --git a/cypress/e2e/step/spec.cy.ts b/cypress/e2e/step/spec.cy.ts new file mode 100644 index 0000000000..705aed2266 --- /dev/null +++ b/cypress/e2e/step/spec.cy.ts @@ -0,0 +1,30 @@ +import { + describeSyncAsync, + runTestServer, + submitMessage +} from '../../support/testUtils'; + +describeSyncAsync('Step', () => { + before(() => { + runTestServer(); + }); + + it('should be able to nest steps', () => { + submitMessage('Hello'); + + cy.get('#tool-1-loading').should('exist'); + cy.get('#tool-1-loading').click(); + + cy.get('#tool_2-loading').should('exist'); + cy.get('#tool_2-loading').click(); + + cy.get('#tool-3-loading').should('exist'); + cy.get('#tool-3-loading').click(); + + cy.get('#tool-1-done').should('exist'); + cy.get('#tool_2-done').should('exist'); + cy.get('#tool-3-done').should('exist'); + + cy.get('.step').should('have.length', 5); + }); +}); diff --git a/cypress/e2e/stop_task/main_async.py b/cypress/e2e/stop_task/main_async.py index 93fba67f43..12ed98bd23 100644 --- a/cypress/e2e/stop_task/main_async.py +++ b/cypress/e2e/stop_task/main_async.py @@ -4,7 +4,7 @@ @cl.on_chat_start async def start(): await cl.Message(content="Message 1").send() - await cl.sleep(5) + await cl.sleep(1) await cl.Message(content="Message 2").send() diff --git a/cypress/e2e/stop_task/main_sync.py b/cypress/e2e/stop_task/main_sync.py index e5cd0fa82d..2af957552d 100644 --- a/cypress/e2e/stop_task/main_sync.py +++ b/cypress/e2e/stop_task/main_sync.py @@ -4,7 +4,7 @@ def sync_function(): - time.sleep(5) + time.sleep(1) @cl.on_chat_start diff --git a/cypress/e2e/stop_task/spec.cy.ts b/cypress/e2e/stop_task/spec.cy.ts index 51187096c4..f8c3d51930 100644 --- a/cypress/e2e/stop_task/spec.cy.ts +++ b/cypress/e2e/stop_task/spec.cy.ts @@ -1,35 +1,33 @@ import { describeSyncAsync, runTestServer, - submitMessage, -} from "../../support/testUtils"; + submitMessage +} from '../../support/testUtils'; -describeSyncAsync("Stop task", (mode) => { +describeSyncAsync('Stop task', (mode) => { before(() => { runTestServer(mode); }); - it("should be able to stop a task", () => { - cy.get(".message").should("have.length", 1); + it('should be able to stop a task', () => { + cy.get('.step').should('have.length', 1); - cy.get(".message").last().should("contain.text", "Message 1"); - cy.get("#stop-button").should("exist").click(); - cy.get("#stop-button").should("not.exist"); + cy.get('.step').last().should('contain.text', 'Message 1'); + cy.get('#stop-button').should('exist').click(); + cy.get('#stop-button').should('not.exist'); - cy.get(".message").should("have.length", 2); + cy.get('.step').should('have.length', 2); - cy.get(".message") - .last() - .should("contain.text", "Task stopped by the user."); + cy.get('.step').last().should('contain.text', 'Task stopped by the user.'); cy.wait(5000); - cy.get(".message").should("have.length", 2); + cy.get('.step').should('have.length', 2); - submitMessage("Hello"); + submitMessage('Hello'); - cy.get(".message").should("have.length", 4); + cy.get('.step').should('have.length', 4); - cy.get(".message").last().should("contain.text", "World"); + cy.get('.step').last().should('contain.text', 'World'); }); }); diff --git a/cypress/e2e/streaming/main.py b/cypress/e2e/streaming/main.py index fd81d05ee6..88dd019e4c 100644 --- a/cypress/e2e/streaming/main.py +++ b/cypress/e2e/streaming/main.py @@ -10,13 +10,27 @@ async def main(): msg = cl.Message(content="") for token in token_list: await msg.stream_token(token) - await cl.sleep(0.5) + await cl.sleep(0.2) await msg.send() msg = cl.Message(content="") for seq in sequence_list: await msg.stream_token(token=seq, is_sequence=True) - await cl.sleep(0.5) + await cl.sleep(0.2) await msg.send() + + step = cl.Step() + for token in token_list: + await step.stream_token(token) + await cl.sleep(0.2) + + await step.send() + + step = cl.Step() + for seq in sequence_list: + await step.stream_token(token=seq, is_sequence=True) + await cl.sleep(0.2) + + await step.send() diff --git a/cypress/e2e/streaming/spec.cy.ts b/cypress/e2e/streaming/spec.cy.ts index 235d140e49..a3bfb1c330 100644 --- a/cypress/e2e/streaming/spec.cy.ts +++ b/cypress/e2e/streaming/spec.cy.ts @@ -1,27 +1,35 @@ -import { runTestServer } from "../../support/testUtils"; +import { runTestServer } from '../../support/testUtils'; -function testStreamedMessage(index: number) { - const tokenList = ["the", "quick", "brown", "fox"]; +function testStreamedTest(index: number) { + const tokenList = ['the', 'quick', 'brown', 'fox']; for (const token of tokenList) { - cy.get(".message").eq(index).should("contain", token); + cy.get('.step').eq(index).should('contain', token); } - cy.get(".message").eq(index).should("contain", tokenList.join(" ")); + cy.get('.step').eq(index).should('contain', tokenList.join(' ')); } -describe("Streaming", () => { +describe('Streaming', () => { before(() => { runTestServer(); }); - it("should be able to stream a message", () => { - cy.get(".message").should("have.length", 1); + it('should be able to stream a message', () => { + cy.get('.step').should('have.length', 1); - testStreamedMessage(0); + testStreamedTest(0); - cy.get(".message").should("have.length", 1); + cy.get('.step').should('have.length', 1); - testStreamedMessage(1); + testStreamedTest(1); - cy.get(".message").should("have.length", 2); + cy.get('.step').should('have.length', 2); + + testStreamedTest(2); + + cy.get('.step').should('have.length', 3); + + testStreamedTest(3); + + cy.get('.step').should('have.length', 4); }); }); diff --git a/cypress/e2e/tasklist/main.py b/cypress/e2e/tasklist/main.py index 2765f423c4..57c33d6c40 100644 --- a/cypress/e2e/tasklist/main.py +++ b/cypress/e2e/tasklist/main.py @@ -41,6 +41,7 @@ async def main(): task_list.status = "Running..." for i in range(17): task = cl.Task(title=fake_tasks[i]) + await cl.sleep(0.2) await task_list.add_task(task) await task_list.send() @@ -50,6 +51,7 @@ async def main(): for i in range(9): task_list.tasks[i].status = cl.TaskStatus.DONE task_list.tasks[i + 1].status = cl.TaskStatus.RUNNING + await cl.sleep(0.2) await task_list.send() task_list.tasks[9].status = cl.TaskStatus.FAILED diff --git a/cypress/e2e/tasklist/spec.cy.ts b/cypress/e2e/tasklist/spec.cy.ts index 7ead3b1cc0..30bdc33c51 100644 --- a/cypress/e2e/tasklist/spec.cy.ts +++ b/cypress/e2e/tasklist/spec.cy.ts @@ -6,7 +6,7 @@ describe('tasklist', () => { }); it('should display the tasklist ', () => { - cy.get('.message').should('have.length', 0); + cy.get('.step').should('have.length', 0); cy.get('.tasklist').should('have.length', 2); cy.get('.tasklist.tasklist-mobile').should('not.be.visible'); cy.get('.tasklist.tasklist-mobile .task').should('not.be.visible'); diff --git a/cypress/e2e/update_message/.chainlit/config.toml b/cypress/e2e/update_message/.chainlit/config.toml deleted file mode 100644 index 0c509af72c..0000000000 --- a/cypress/e2e/update_message/.chainlit/config.toml +++ /dev/null @@ -1,62 +0,0 @@ -[project] -# Whether to enable telemetry (default: true). No personal data is collected. -enable_telemetry = true - -# List of environment variables to be provided by each user to use the app. -user_env = [] - -# Duration (in seconds) during which the session is saved when the connection is lost -session_timeout = 3600 - -# Enable third parties caching (e.g LangChain cache) -cache = false - -# Follow symlink for asset mount (see https://github.com/Chainlit/chainlit/issues/317) -# follow_symlink = false - -[features] -# Show the prompt playground -prompt_playground = true - -[UI] -# Name of the app and chatbot. -name = "Chatbot" - -# Description of the app and chatbot. This is used for HTML tags. -# description = "" - -# Large size content are by default collapsed for a cleaner ui -default_collapse_content = true - -# The default value for the expand messages settings. -default_expand_messages = false - -# Hide the chain of thought details from the user in the UI. -hide_cot = false - -# Link to your github repo. This will add a github button in the UI's header. -# github = "" - -# Override default MUI light theme. (Check theme.ts) -[UI.theme.light] - #background = "#FAFAFA" - #paper = "#FFFFFF" - - [UI.theme.light.primary] - #main = "#F80061" - #dark = "#980039" - #light = "#FFE7EB" - -# Override default MUI dark theme. (Check theme.ts) -[UI.theme.dark] - #background = "#FAFAFA" - #paper = "#FFFFFF" - - [UI.theme.dark.primary] - #main = "#F80061" - #dark = "#980039" - #light = "#FFE7EB" - - -[meta] -generated_by = "0.6.402" diff --git a/cypress/e2e/update_message/spec.cy.ts b/cypress/e2e/update_message/spec.cy.ts deleted file mode 100644 index b0c0a57133..0000000000 --- a/cypress/e2e/update_message/spec.cy.ts +++ /dev/null @@ -1,13 +0,0 @@ -import { runTestServer } from "../../support/testUtils"; - -describe("Update Message", () => { - before(() => { - runTestServer(); - }); - - it("should be able to update a message", () => { - cy.get(".message").should("have.length", 1); - cy.get(".message").eq(0).should("contain", "Hello"); - cy.get(".message").eq(0).should("contain", "Hello again!"); - }); -}); diff --git a/cypress/e2e/sdk_availability/.chainlit/config.toml b/cypress/e2e/update_step/.chainlit/config.toml similarity index 100% rename from cypress/e2e/sdk_availability/.chainlit/config.toml rename to cypress/e2e/update_step/.chainlit/config.toml diff --git a/cypress/e2e/update_message/main.py b/cypress/e2e/update_step/main.py similarity index 56% rename from cypress/e2e/update_message/main.py rename to cypress/e2e/update_step/main.py index 1daa6e6618..fb407253cc 100644 --- a/cypress/e2e/update_message/main.py +++ b/cypress/e2e/update_step/main.py @@ -5,6 +5,13 @@ async def main(): msg = cl.Message(content="Hello!") await msg.send() - await cl.sleep(2) + + async with cl.Step() as step: + step.output = "Foo" + + await cl.sleep(1) msg.content = "Hello again!" await msg.update() + + step.output += " Bar" + await step.update() diff --git a/cypress/e2e/update_step/spec.cy.ts b/cypress/e2e/update_step/spec.cy.ts new file mode 100644 index 0000000000..cc8fd02bec --- /dev/null +++ b/cypress/e2e/update_step/spec.cy.ts @@ -0,0 +1,18 @@ +import { runTestServer } from '../../support/testUtils'; + +describe('Update Step', () => { + before(() => { + runTestServer(); + }); + + it('should be able to update a step', () => { + cy.get('.step').should('have.length', 1); + cy.get('#chatbot-loading').should('exist').click(); + cy.get('.step').should('have.length', 2); + cy.get('.step').eq(0).should('contain', 'Hello!'); + cy.get('.step').eq(1).should('contain', 'Foo'); + + cy.get('.step').eq(0).should('contain', 'Hello again!'); + cy.get('.step').eq(1).should('contain', 'Foo Bar'); + }); +}); diff --git a/cypress/e2e/upload_attachments/spec.cy.ts b/cypress/e2e/upload_attachments/spec.cy.ts index 7b5f7b3461..23cb6070fc 100644 --- a/cypress/e2e/upload_attachments/spec.cy.ts +++ b/cypress/e2e/upload_attachments/spec.cy.ts @@ -7,17 +7,17 @@ describe('Upload attachments', () => { const shouldHaveInlineAttachments = () => { submitMessage('Message with attachments'); - cy.get('.message').should('have.length', 5); - cy.get('.message') + cy.get('.step').should('have.length', 5); + cy.get('.step') .eq(1) .should('contain', 'Content: Message with attachments'); - cy.get('.message') + cy.get('.step') .eq(2) .should('contain', 'Received element 0: state_of_the_union.txt'); - cy.get('.message').eq(3).should('contain', 'Received element 1: hello.cpp'); - cy.get('.message').eq(4).should('contain', 'Received element 2: hello.py'); + cy.get('.step').eq(3).should('contain', 'Received element 1: hello.cpp'); + cy.get('.step').eq(4).should('contain', 'Received element 2: hello.py'); - cy.get('.message').eq(0).find('.inline-file').should('have.length', 3); + cy.get('.step').eq(0).find('.inline-file').should('have.length', 3); cy.get('.inline-file') .eq(0) .should('have.attr', 'download', 'state_of_the_union.txt'); @@ -30,13 +30,15 @@ describe('Upload attachments', () => { cy.fixture('hello.cpp', 'utf-8').as('cppFile'); cy.fixture('hello.py', 'utf-8').as('pyFile'); + // Wait for the socket connection to be created + cy.wait(1000); + /** * Should be able to upload file from D&D input */ cy.get("[id='#upload-drop-input']").should('exist'); // Upload a text file cy.get("[id='#upload-drop-input']").selectFile('@txtFile', { force: true }); - // cy.get('#upload-drop-input').selectFile('@txtFile', { force: true }); cy.get('#attachments').should('contain', 'state_of_the_union.txt'); // Upload a C++ file diff --git a/cypress/e2e/user_env/spec.cy.ts b/cypress/e2e/user_env/spec.cy.ts index 8d60d6a295..735f2681b1 100644 --- a/cypress/e2e/user_env/spec.cy.ts +++ b/cypress/e2e/user_env/spec.cy.ts @@ -16,7 +16,7 @@ describe('User Env', () => { submitMessage('Hello'); - cy.get('.message').should('have.length', 2); - cy.get('.message').eq(1).should('contain', keyValue); + cy.get('.step').should('have.length', 2); + cy.get('.step').eq(1).should('contain', keyValue); }); }); diff --git a/cypress/e2e/user_session/spec.cy.ts b/cypress/e2e/user_session/spec.cy.ts index 1cc713346e..e62cc86a9e 100644 --- a/cypress/e2e/user_session/spec.cy.ts +++ b/cypress/e2e/user_session/spec.cy.ts @@ -16,24 +16,24 @@ describe('User Session', () => { it('should be able to store data related per user session', () => { submitMessage('Hello 1'); - cy.get('.message').should('have.length', 2); - cy.get('.message').eq(1).should('contain', 'Prev message: None'); + cy.get('.step').should('have.length', 2); + cy.get('.step').eq(1).should('contain', 'Prev message: None'); submitMessage('Hello 2'); - cy.get('.message').should('have.length', 4); - cy.get('.message').eq(3).should('contain', 'Prev message: Hello 1'); + cy.get('.step').should('have.length', 4); + cy.get('.step').eq(3).should('contain', 'Prev message: Hello 1'); newSession(); submitMessage('Hello 3'); - cy.get('.message').should('have.length', 2); - cy.get('.message').eq(1).should('contain', 'Prev message: None'); + cy.get('.step').should('have.length', 2); + cy.get('.step').eq(1).should('contain', 'Prev message: None'); submitMessage('Hello 4'); - cy.get('.message').should('have.length', 4); - cy.get('.message').eq(3).should('contain', 'Prev message: Hello 3'); + cy.get('.step').should('have.length', 4); + cy.get('.step').eq(3).should('contain', 'Prev message: Hello 3'); }); }); diff --git a/cypress/e2e/video_element/spec.cy.ts b/cypress/e2e/video_element/spec.cy.ts index 89d742c9cc..da33188e81 100644 --- a/cypress/e2e/video_element/spec.cy.ts +++ b/cypress/e2e/video_element/spec.cy.ts @@ -6,8 +6,8 @@ describe('video', () => { }); it('should be able to display a video element', () => { - cy.get('.message').should('have.length', 1); - cy.get('.message').eq(0).find('.inline-video').should('have.length', 1); + cy.get('.step').should('have.length', 1); + cy.get('.step').eq(0).find('.inline-video').should('have.length', 1); cy.get('video.inline-video') .then(($el) => { diff --git a/frontend/package.json b/frontend/package.json index 722f2a3579..766448d998 100644 --- a/frontend/package.json +++ b/frontend/package.json @@ -23,7 +23,7 @@ "lodash": "^4.17.21", "react": "^18.2.0", "react-dom": "^18.2.0", - "react-hot-toast": "^2.4.1", + "sonner": "^1.2.3", "react-hotkeys-hook": "^4.4.1", "react-router-dom": "^6.15.0", "react-speech-recognition": "^3.10.0", diff --git a/frontend/pnpm-lock.yaml b/frontend/pnpm-lock.yaml index ce9f8c6c1f..1530040dc8 100644 --- a/frontend/pnpm-lock.yaml +++ b/frontend/pnpm-lock.yaml @@ -38,9 +38,6 @@ dependencies: react-dom: specifier: ^18.2.0 version: 18.2.0(react@18.2.0) - react-hot-toast: - specifier: ^2.4.1 - version: 2.4.1(csstype@3.1.2)(react-dom@18.2.0)(react@18.2.0) react-hotkeys-hook: specifier: ^4.4.1 version: 4.4.1(react-dom@18.2.0)(react@18.2.0) @@ -56,6 +53,9 @@ dependencies: regenerator-runtime: specifier: ^0.14.0 version: 0.14.0 + sonner: + specifier: ^1.2.3 + version: 1.2.3(react-dom@18.2.0)(react@18.2.0) usehooks-ts: specifier: ^2.9.1 version: 2.9.1(react-dom@18.2.0)(react@18.2.0) @@ -97,11 +97,11 @@ devDependencies: packages: - /@babel/code-frame@7.22.13: - resolution: {integrity: sha512-XktuhWlJ5g+3TJXc5upd9Ks1HutSArik6jf2eAjYFyIOf4ej3RN+184cZbzDvbPnuTJIUhPKKJE3cIsYTiAT3w==} + /@babel/code-frame@7.23.5: + resolution: {integrity: sha512-CgH3s1a96LipHCmSUmYFPwY7MNx8C3avkq7i4Wl3cfa662ldtUe4VM1TPXX70pfmrlWTb6jLqTYrZyT2ZTJBgA==} engines: {node: '>=6.9.0'} dependencies: - '@babel/highlight': 7.22.20 + '@babel/highlight': 7.23.4 chalk: 2.4.2 dev: false @@ -109,11 +109,11 @@ packages: resolution: {integrity: sha512-0pYVBnDKZO2fnSPCrgM/6WMc7eS20Fbok+0r88fp+YtWVLZrp4CkafFGIp+W0VKw4a22sgebPT99y+FDNMdP4w==} engines: {node: '>=6.9.0'} dependencies: - '@babel/types': 7.23.3 + '@babel/types': 7.23.5 dev: false - /@babel/helper-string-parser@7.22.5: - resolution: {integrity: sha512-mM4COjgZox8U+JcXQwPijIZLElkgEpO5rsERVDJTc2qfCDfERyob6k5WegS14SX18IIjv+XD+GrqNumY5JRCDw==} + /@babel/helper-string-parser@7.23.4: + resolution: {integrity: sha512-803gmbQdqwdf4olxrX4AJyFBV/RTr3rSmOj0rKwesmzlfhYNDEs+/iOcznzpNWlJlIlTJC2QfPFcHB6DlzdVLQ==} engines: {node: '>=6.9.0'} dev: false @@ -122,8 +122,8 @@ packages: engines: {node: '>=6.9.0'} dev: false - /@babel/highlight@7.22.20: - resolution: {integrity: sha512-dkdMCN3py0+ksCgYmGG8jKeGA/8Tk+gJwSYYlFGxG5lmhfKNoAy004YpLxpS1W2J8m/EK2Ew+yOs9pVRwO89mg==} + /@babel/highlight@7.23.4: + resolution: {integrity: sha512-acGdbYSfp2WheJoJm/EBBBLh/ID8KDc64ISZ9DYtBmC8/Q204PZJLHyzeB5qMzJ5trcOkybd78M4x2KWsUq++A==} engines: {node: '>=6.9.0'} dependencies: '@babel/helper-validator-identifier': 7.22.20 @@ -131,18 +131,18 @@ packages: js-tokens: 4.0.0 dev: false - /@babel/runtime@7.23.2: - resolution: {integrity: sha512-mM8eg4yl5D6i3lu2QKPuPH4FArvJ8KhTofbE7jwMUv9KX5mBvwPAqnV3MlyBNqdp9RyRKP6Yck8TrfYrPvX3bg==} + /@babel/runtime@7.23.5: + resolution: {integrity: sha512-NdUTHcPe4C99WxPub+K9l9tK5/lV4UXIoaHSYgzco9BCyjKAAwzdBI+wWtYqHt7LJdbo74ZjRPJgzVweq1sz0w==} engines: {node: '>=6.9.0'} dependencies: regenerator-runtime: 0.14.0 dev: false - /@babel/types@7.23.3: - resolution: {integrity: sha512-OZnvoH2l8PK5eUvEcUyCt/sXgr/h+UWpVuBbOljwcrAgUl6lpchoQ++PHGyQy1AtYnVA6CEq3y5xeEI10brpXw==} + /@babel/types@7.23.5: + resolution: {integrity: sha512-ON5kSOJwVO6xXVRTvOI0eOnWe7VdUcIpsovGo9U/Br4Ie4UVFQTboO2cYnDhAGU6Fp+UxSiT+pMft0SMHfuq6w==} engines: {node: '>=6.9.0'} dependencies: - '@babel/helper-string-parser': 7.22.5 + '@babel/helper-string-parser': 7.23.4 '@babel/helper-validator-identifier': 7.22.20 to-fast-properties: 2.0.0 dev: false @@ -151,7 +151,7 @@ packages: resolution: {integrity: sha512-m4HEDZleaaCH+XgDDsPF15Ht6wTLsgDTeR3WYj9Q/k76JtWhrJjcP4+/XlG8LGT/Rol9qUfOIztXeA84ATpqPQ==} dependencies: '@babel/helper-module-imports': 7.22.15 - '@babel/runtime': 7.23.2 + '@babel/runtime': 7.23.5 '@emotion/hash': 0.9.1 '@emotion/memoize': 0.8.1 '@emotion/serialize': 1.1.2 @@ -196,7 +196,7 @@ packages: '@types/react': optional: true dependencies: - '@babel/runtime': 7.23.2 + '@babel/runtime': 7.23.5 '@emotion/babel-plugin': 11.11.0 '@emotion/cache': 11.11.0 '@emotion/serialize': 1.1.2 @@ -232,7 +232,7 @@ packages: '@types/react': optional: true dependencies: - '@babel/runtime': 7.23.2 + '@babel/runtime': 7.23.5 '@emotion/babel-plugin': 11.11.0 '@emotion/is-prop-valid': 1.2.1 '@emotion/react': 11.11.1(@types/react@18.2.0)(react@18.2.0) @@ -461,8 +461,8 @@ packages: dev: true optional: true - /@floating-ui/core@1.5.0: - resolution: {integrity: sha512-kK1h4m36DQ0UHGj5Ah4db7R0rHemTqqO0QLvUqi1/mUUp3LuAWbWxdxSIf/XsnH9VS6rRVPLJCncjRzUvyCLXg==} + /@floating-ui/core@1.5.1: + resolution: {integrity: sha512-QgcKYwzcc8vvZ4n/5uklchy8KVdjJwcOeI+HnnTNclJjs2nYsy23DOCf+sSV1kBwD9yDAoVKCkv/gEPzgQU3Pw==} dependencies: '@floating-ui/utils': 0.1.6 dev: false @@ -470,7 +470,7 @@ packages: /@floating-ui/dom@1.5.3: resolution: {integrity: sha512-ClAbQnEqJAKCJOEbbLo5IUlZHkNszqhuxS4fHAVxRPXPya6Ysf2G8KypnYcOTpx6I8xcgF9bbHb6g/2KpbV8qA==} dependencies: - '@floating-ui/core': 1.5.0 + '@floating-ui/core': 1.5.1 '@floating-ui/utils': 0.1.6 dev: false @@ -500,10 +500,10 @@ packages: '@types/react': optional: true dependencies: - '@babel/runtime': 7.23.2 + '@babel/runtime': 7.23.5 '@emotion/is-prop-valid': 1.2.1 - '@mui/types': 7.2.9(@types/react@18.2.0) - '@mui/utils': 5.14.18(@types/react@18.2.0)(react@18.2.0) + '@mui/types': 7.2.10(@types/react@18.2.0) + '@mui/utils': 5.14.19(@types/react@18.2.0)(react@18.2.0) '@popperjs/core': 2.11.8 '@types/react': 18.2.0 clsx: 1.2.1 @@ -524,10 +524,10 @@ packages: '@types/react': optional: true dependencies: - '@babel/runtime': 7.23.2 + '@babel/runtime': 7.23.5 '@floating-ui/react-dom': 2.0.4(react-dom@18.2.0)(react@18.2.0) - '@mui/types': 7.2.9(@types/react@18.2.0) - '@mui/utils': 5.14.18(@types/react@18.2.0)(react@18.2.0) + '@mui/types': 7.2.10(@types/react@18.2.0) + '@mui/utils': 5.14.19(@types/react@18.2.0)(react@18.2.0) '@popperjs/core': 2.11.8 '@types/react': 18.2.0 clsx: 2.0.0 @@ -536,8 +536,8 @@ packages: react-dom: 18.2.0(react@18.2.0) dev: false - /@mui/core-downloads-tracker@5.14.18: - resolution: {integrity: sha512-yFpF35fEVDV81nVktu0BE9qn2dD/chs7PsQhlyaV3EnTeZi9RZBuvoEfRym1/jmhJ2tcfeWXiRuHG942mQXJJQ==} + /@mui/core-downloads-tracker@5.14.19: + resolution: {integrity: sha512-y4JseIen5pmZs1n9hHy95HKKioKco8f6N2lford2AmjJigVJOv0KsU0qryiCpyuEUZmi/xCduVilHsK9DSkPcA==} dev: false /@mui/icons-material@5.14.9(@mui/material@5.14.10)(@types/react@18.2.0)(react@18.2.0): @@ -551,7 +551,7 @@ packages: '@types/react': optional: true dependencies: - '@babel/runtime': 7.23.2 + '@babel/runtime': 7.23.5 '@mui/material': 5.14.10(@emotion/react@11.11.1)(@emotion/styled@11.11.0)(@types/react@18.2.0)(react-dom@18.2.0)(react@18.2.0) '@types/react': 18.2.0 react: 18.2.0 @@ -575,14 +575,14 @@ packages: '@types/react': optional: true dependencies: - '@babel/runtime': 7.23.2 + '@babel/runtime': 7.23.5 '@emotion/react': 11.11.1(@types/react@18.2.0)(react@18.2.0) '@emotion/styled': 11.11.0(@emotion/react@11.11.1)(@types/react@18.2.0)(react@18.2.0) '@mui/base': 5.0.0-alpha.120(@types/react@18.2.0)(react-dom@18.2.0)(react@18.2.0) '@mui/material': 5.14.10(@emotion/react@11.11.1)(@emotion/styled@11.11.0)(@types/react@18.2.0)(react-dom@18.2.0)(react@18.2.0) - '@mui/system': 5.14.18(@emotion/react@11.11.1)(@emotion/styled@11.11.0)(@types/react@18.2.0)(react@18.2.0) - '@mui/types': 7.2.9(@types/react@18.2.0) - '@mui/utils': 5.14.18(@types/react@18.2.0)(react@18.2.0) + '@mui/system': 5.14.19(@emotion/react@11.11.1)(@emotion/styled@11.11.0)(@types/react@18.2.0)(react@18.2.0) + '@mui/types': 7.2.10(@types/react@18.2.0) + '@mui/utils': 5.14.19(@types/react@18.2.0)(react@18.2.0) '@types/react': 18.2.0 clsx: 1.2.1 prop-types: 15.8.1 @@ -608,14 +608,14 @@ packages: '@types/react': optional: true dependencies: - '@babel/runtime': 7.23.2 + '@babel/runtime': 7.23.5 '@emotion/react': 11.11.1(@types/react@18.2.0)(react@18.2.0) '@emotion/styled': 11.11.0(@emotion/react@11.11.1)(@types/react@18.2.0)(react@18.2.0) '@mui/base': 5.0.0-beta.16(@types/react@18.2.0)(react-dom@18.2.0)(react@18.2.0) - '@mui/core-downloads-tracker': 5.14.18 - '@mui/system': 5.14.18(@emotion/react@11.11.1)(@emotion/styled@11.11.0)(@types/react@18.2.0)(react@18.2.0) - '@mui/types': 7.2.9(@types/react@18.2.0) - '@mui/utils': 5.14.18(@types/react@18.2.0)(react@18.2.0) + '@mui/core-downloads-tracker': 5.14.19 + '@mui/system': 5.14.19(@emotion/react@11.11.1)(@emotion/styled@11.11.0)(@types/react@18.2.0)(react@18.2.0) + '@mui/types': 7.2.10(@types/react@18.2.0) + '@mui/utils': 5.14.19(@types/react@18.2.0)(react@18.2.0) '@types/react': 18.2.0 '@types/react-transition-group': 4.4.9 clsx: 2.0.0 @@ -627,8 +627,8 @@ packages: react-transition-group: 4.4.5(react-dom@18.2.0)(react@18.2.0) dev: false - /@mui/private-theming@5.14.18(@types/react@18.2.0)(react@18.2.0): - resolution: {integrity: sha512-WSgjqRlzfHU+2Rou3HlR2Gqfr4rZRsvFgataYO3qQ0/m6gShJN+lhVEvwEiJ9QYyVzMDvNpXZAcqp8Y2Vl+PAw==} + /@mui/private-theming@5.14.19(@types/react@18.2.0)(react@18.2.0): + resolution: {integrity: sha512-U9w39VpXLGVM8wZlUU/47YGTsBSk60ZQRRxQZtdqPfN1N7OVllQeN4cEKZKR8PjqqR3aYRcSciQ4dc6CttRoXQ==} engines: {node: '>=12.0.0'} peerDependencies: '@types/react': ^17.0.0 || ^18.0.0 @@ -637,15 +637,15 @@ packages: '@types/react': optional: true dependencies: - '@babel/runtime': 7.23.2 - '@mui/utils': 5.14.18(@types/react@18.2.0)(react@18.2.0) + '@babel/runtime': 7.23.5 + '@mui/utils': 5.14.19(@types/react@18.2.0)(react@18.2.0) '@types/react': 18.2.0 prop-types: 15.8.1 react: 18.2.0 dev: false - /@mui/styled-engine@5.14.18(@emotion/react@11.11.1)(@emotion/styled@11.11.0)(react@18.2.0): - resolution: {integrity: sha512-pW8bpmF9uCB5FV2IPk6mfbQCjPI5vGI09NOLhtGXPeph/4xIfC3JdIX0TILU0WcTs3aFQqo6s2+1SFgIB9rCXA==} + /@mui/styled-engine@5.14.19(@emotion/react@11.11.1)(@emotion/styled@11.11.0)(react@18.2.0): + resolution: {integrity: sha512-jtj/Pyn/bS8PM7NXdFNTHWZfE3p+vItO4/HoQbUeAv3u+cnWXcTBGHHY/xdIn446lYGFDczTh1YyX8G4Ts0Rtg==} engines: {node: '>=12.0.0'} peerDependencies: '@emotion/react': ^11.4.1 @@ -657,7 +657,7 @@ packages: '@emotion/styled': optional: true dependencies: - '@babel/runtime': 7.23.2 + '@babel/runtime': 7.23.5 '@emotion/cache': 11.11.0 '@emotion/react': 11.11.1(@types/react@18.2.0)(react@18.2.0) '@emotion/styled': 11.11.0(@emotion/react@11.11.1)(@types/react@18.2.0)(react@18.2.0) @@ -666,8 +666,8 @@ packages: react: 18.2.0 dev: false - /@mui/system@5.14.18(@emotion/react@11.11.1)(@emotion/styled@11.11.0)(@types/react@18.2.0)(react@18.2.0): - resolution: {integrity: sha512-hSQQdb3KF72X4EN2hMEiv8EYJZSflfdd1TRaGPoR7CIAG347OxCslpBUwWngYobaxgKvq6xTrlIl+diaactVww==} + /@mui/system@5.14.19(@emotion/react@11.11.1)(@emotion/styled@11.11.0)(@types/react@18.2.0)(react@18.2.0): + resolution: {integrity: sha512-4e3Q+2nx+vgEsd0h5ftxlZGB7XtkkPos/zWqCqnxUs1l/T70s0lF2YNrWHHdSQ7LgtBu0eQ0qweZG2pR7KwkAw==} engines: {node: '>=12.0.0'} peerDependencies: '@emotion/react': ^11.5.0 @@ -682,13 +682,13 @@ packages: '@types/react': optional: true dependencies: - '@babel/runtime': 7.23.2 + '@babel/runtime': 7.23.5 '@emotion/react': 11.11.1(@types/react@18.2.0)(react@18.2.0) '@emotion/styled': 11.11.0(@emotion/react@11.11.1)(@types/react@18.2.0)(react@18.2.0) - '@mui/private-theming': 5.14.18(@types/react@18.2.0)(react@18.2.0) - '@mui/styled-engine': 5.14.18(@emotion/react@11.11.1)(@emotion/styled@11.11.0)(react@18.2.0) - '@mui/types': 7.2.9(@types/react@18.2.0) - '@mui/utils': 5.14.18(@types/react@18.2.0)(react@18.2.0) + '@mui/private-theming': 5.14.19(@types/react@18.2.0)(react@18.2.0) + '@mui/styled-engine': 5.14.19(@emotion/react@11.11.1)(@emotion/styled@11.11.0)(react@18.2.0) + '@mui/types': 7.2.10(@types/react@18.2.0) + '@mui/utils': 5.14.19(@types/react@18.2.0)(react@18.2.0) '@types/react': 18.2.0 clsx: 2.0.0 csstype: 3.1.2 @@ -696,8 +696,8 @@ packages: react: 18.2.0 dev: false - /@mui/types@7.2.9(@types/react@18.2.0): - resolution: {integrity: sha512-k1lN/PolaRZfNsRdAqXtcR71sTnv3z/VCCGPxU8HfdftDkzi335MdJ6scZxvofMAd/K/9EbzCZTFBmlNpQVdCg==} + /@mui/types@7.2.10(@types/react@18.2.0): + resolution: {integrity: sha512-wX1vbDC+lzF7FlhT6A3ffRZgEoKWPF8VqRoTu4lZwouFX2t90KyCMsgepMw5DxLak1BSp/KP86CmtZttikb/gQ==} peerDependencies: '@types/react': ^17.0.0 || ^18.0.0 peerDependenciesMeta: @@ -707,8 +707,8 @@ packages: '@types/react': 18.2.0 dev: false - /@mui/utils@5.14.18(@types/react@18.2.0)(react@18.2.0): - resolution: {integrity: sha512-HZDRsJtEZ7WMSnrHV9uwScGze4wM/Y+u6pDVo+grUjt5yXzn+wI8QX/JwTHh9YSw/WpnUL80mJJjgCnWj2VrzQ==} + /@mui/utils@5.14.19(@types/react@18.2.0)(react@18.2.0): + resolution: {integrity: sha512-qAHvTXzk7basbyqPvhgWqN6JbmI2wLB/mf97GkSlz5c76MiKYV6Ffjvw9BjKZQ1YRb8rDX9kgdjRezOcoB91oQ==} engines: {node: '>=12.0.0'} peerDependencies: '@types/react': ^17.0.0 || ^18.0.0 @@ -717,8 +717,8 @@ packages: '@types/react': optional: true dependencies: - '@babel/runtime': 7.23.2 - '@types/prop-types': 15.7.10 + '@babel/runtime': 7.23.5 + '@types/prop-types': 15.7.11 '@types/react': 18.2.0 prop-types: 15.8.1 react: 18.2.0 @@ -734,8 +734,8 @@ packages: engines: {node: '>=14.0.0'} dev: false - /@swc/core-darwin-arm64@1.3.96: - resolution: {integrity: sha512-8hzgXYVd85hfPh6mJ9yrG26rhgzCmcLO0h1TIl8U31hwmTbfZLzRitFQ/kqMJNbIBCwmNH1RU2QcJnL3d7f69A==} + /@swc/core-darwin-arm64@1.3.100: + resolution: {integrity: sha512-XVWFsKe6ei+SsDbwmsuRkYck1SXRpO60Hioa4hoLwR8fxbA9eVp6enZtMxzVVMBi8ej5seZ4HZQeAWepbukiBw==} engines: {node: '>=10'} cpu: [arm64] os: [darwin] @@ -743,8 +743,8 @@ packages: dev: true optional: true - /@swc/core-darwin-x64@1.3.96: - resolution: {integrity: sha512-mFp9GFfuPg+43vlAdQZl0WZpZSE8sEzqL7sr/7Reul5McUHP0BaLsEzwjvD035ESfkY8GBZdLpMinblIbFNljQ==} + /@swc/core-darwin-x64@1.3.100: + resolution: {integrity: sha512-KF/MXrnH1nakm1wbt4XV8FS7kvqD9TGmVxeJ0U4bbvxXMvzeYUurzg3AJUTXYmXDhH/VXOYJE5N5RkwZZPs5iA==} engines: {node: '>=10'} cpu: [x64] os: [darwin] @@ -752,17 +752,8 @@ packages: dev: true optional: true - /@swc/core-linux-arm-gnueabihf@1.3.96: - resolution: {integrity: sha512-8UEKkYJP4c8YzYIY/LlbSo8z5Obj4hqcv/fUTHiEePiGsOddgGf7AWjh56u7IoN/0uEmEro59nc1ChFXqXSGyg==} - engines: {node: '>=10'} - cpu: [arm] - os: [linux] - requiresBuild: true - dev: true - optional: true - - /@swc/core-linux-arm64-gnu@1.3.96: - resolution: {integrity: sha512-c/IiJ0s1y3Ymm2BTpyC/xr6gOvoqAVETrivVXHq68xgNms95luSpbYQ28rqaZC8bQC8M5zdXpSc0T8DJu8RJGw==} + /@swc/core-linux-arm64-gnu@1.3.100: + resolution: {integrity: sha512-p8hikNnAEJrw5vHCtKiFT4hdlQxk1V7vqPmvUDgL/qe2menQDK/i12tbz7/3BEQ4UqUPnvwpmVn2d19RdEMNxw==} engines: {node: '>=10'} cpu: [arm64] os: [linux] @@ -770,8 +761,8 @@ packages: dev: true optional: true - /@swc/core-linux-arm64-musl@1.3.96: - resolution: {integrity: sha512-i5/UTUwmJLri7zhtF6SAo/4QDQJDH2fhYJaBIUhrICmIkRO/ltURmpejqxsM/ye9Jqv5zG7VszMC0v/GYn/7BQ==} + /@swc/core-linux-arm64-musl@1.3.100: + resolution: {integrity: sha512-BWx/0EeY89WC4q3AaIaBSGfQxkYxIlS3mX19dwy2FWJs/O+fMvF9oLk/CyJPOZzbp+1DjGeeoGFuDYpiNO91JA==} engines: {node: '>=10'} cpu: [arm64] os: [linux] @@ -779,8 +770,8 @@ packages: dev: true optional: true - /@swc/core-linux-x64-gnu@1.3.96: - resolution: {integrity: sha512-USdaZu8lTIkm4Yf9cogct/j5eqtdZqTgcTib4I+NloUW0E/hySou3eSyp3V2UAA1qyuC72ld1otXuyKBna0YKQ==} + /@swc/core-linux-x64-gnu@1.3.100: + resolution: {integrity: sha512-XUdGu3dxAkjsahLYnm8WijPfKebo+jHgHphDxaW0ovI6sTdmEGFDew7QzKZRlbYL2jRkUuuKuDGvD6lO5frmhA==} engines: {node: '>=10'} cpu: [x64] os: [linux] @@ -788,8 +779,8 @@ packages: dev: true optional: true - /@swc/core-linux-x64-musl@1.3.96: - resolution: {integrity: sha512-QYErutd+G2SNaCinUVobfL7jWWjGTI0QEoQ6hqTp7PxCJS/dmKmj3C5ZkvxRYcq7XcZt7ovrYCTwPTHzt6lZBg==} + /@swc/core-linux-x64-musl@1.3.100: + resolution: {integrity: sha512-PhoXKf+f0OaNW/GCuXjJ0/KfK9EJX7z2gko+7nVnEA0p3aaPtbP6cq1Ubbl6CMoPL+Ci3gZ7nYumDqXNc3CtLQ==} engines: {node: '>=10'} cpu: [x64] os: [linux] @@ -797,8 +788,8 @@ packages: dev: true optional: true - /@swc/core-win32-arm64-msvc@1.3.96: - resolution: {integrity: sha512-hjGvvAduA3Un2cZ9iNP4xvTXOO4jL3G9iakhFsgVhpkU73SGmK7+LN8ZVBEu4oq2SUcHO6caWvnZ881cxGuSpg==} + /@swc/core-win32-arm64-msvc@1.3.100: + resolution: {integrity: sha512-PwLADZN6F9cXn4Jw52FeP/MCLVHm8vwouZZSOoOScDtihjY495SSjdPnlosMaRSR4wJQssGwiD/4MbpgQPqbAw==} engines: {node: '>=10'} cpu: [arm64] os: [win32] @@ -806,8 +797,8 @@ packages: dev: true optional: true - /@swc/core-win32-ia32-msvc@1.3.96: - resolution: {integrity: sha512-Far2hVFiwr+7VPCM2GxSmbh3ikTpM3pDombE+d69hkedvYHYZxtTF+2LTKl/sXtpbUnsoq7yV/32c9R/xaaWfw==} + /@swc/core-win32-ia32-msvc@1.3.100: + resolution: {integrity: sha512-0f6nicKSLlDKlyPRl2JEmkpBV4aeDfRQg6n8mPqgL7bliZIcDahG0ej+HxgNjZfS3e0yjDxsNRa6sAqWU2Z60A==} engines: {node: '>=10'} cpu: [ia32] os: [win32] @@ -815,8 +806,8 @@ packages: dev: true optional: true - /@swc/core-win32-x64-msvc@1.3.96: - resolution: {integrity: sha512-4VbSAniIu0ikLf5mBX81FsljnfqjoVGleEkCQv4+zRlyZtO3FHoDPkeLVoy6WRlj7tyrRcfUJ4mDdPkbfTO14g==} + /@swc/core-win32-x64-msvc@1.3.100: + resolution: {integrity: sha512-b7J0rPoMkRTa3XyUGt8PwCaIBuYWsL2DqbirrQKRESzgCvif5iNpqaM6kjIjI/5y5q1Ycv564CB51YDpiS8EtQ==} engines: {node: '>=10'} cpu: [x64] os: [win32] @@ -824,8 +815,8 @@ packages: dev: true optional: true - /@swc/core@1.3.96: - resolution: {integrity: sha512-zwE3TLgoZwJfQygdv2SdCK9mRLYluwDOM53I+dT6Z5ZvrgVENmY3txvWDvduzkV+/8IuvrRbVezMpxcojadRdQ==} + /@swc/core@1.3.100: + resolution: {integrity: sha512-7dKgTyxJjlrMwFZYb1auj3Xq0D8ZBe+5oeIgfMlRU05doXZypYJe0LAk0yjj3WdbwYzpF+T1PLxwTWizI0pckw==} engines: {node: '>=10'} requiresBuild: true peerDependencies: @@ -837,16 +828,15 @@ packages: '@swc/counter': 0.1.2 '@swc/types': 0.1.5 optionalDependencies: - '@swc/core-darwin-arm64': 1.3.96 - '@swc/core-darwin-x64': 1.3.96 - '@swc/core-linux-arm-gnueabihf': 1.3.96 - '@swc/core-linux-arm64-gnu': 1.3.96 - '@swc/core-linux-arm64-musl': 1.3.96 - '@swc/core-linux-x64-gnu': 1.3.96 - '@swc/core-linux-x64-musl': 1.3.96 - '@swc/core-win32-arm64-msvc': 1.3.96 - '@swc/core-win32-ia32-msvc': 1.3.96 - '@swc/core-win32-x64-msvc': 1.3.96 + '@swc/core-darwin-arm64': 1.3.100 + '@swc/core-darwin-x64': 1.3.100 + '@swc/core-linux-arm64-gnu': 1.3.100 + '@swc/core-linux-arm64-musl': 1.3.100 + '@swc/core-linux-x64-gnu': 1.3.100 + '@swc/core-linux-x64-musl': 1.3.100 + '@swc/core-win32-arm64-msvc': 1.3.100 + '@swc/core-win32-ia32-msvc': 1.3.100 + '@swc/core-win32-x64-msvc': 1.3.100 dev: true /@swc/counter@0.1.2: @@ -873,8 +863,8 @@ packages: resolution: {integrity: sha512-dISoDXWWQwUquiKsyZ4Ng+HX2KsPL7LyHKHQwgGFEA3IaKac4Obd+h2a/a6waisAoepJlBcx9paWqjA8/HVjCw==} dev: false - /@types/prop-types@15.7.10: - resolution: {integrity: sha512-mxSnDQxPqsZxmeShFH+uwQ4kO4gcJcGahjjMFeLbKE95IAZiiZyiEepGZjtXJ7hN/yfu0bu9xN2ajcU0JcxX6A==} + /@types/prop-types@15.7.11: + resolution: {integrity: sha512-ga8y9v9uyeiLdpKddhxYQkxNDrfvuPrlFb0N1qnZZByvcElJaXthF1UhvCh9TLWJBEHeNtdnbysW7Y6Uq8CVng==} /@types/react-speech-recognition@3.9.2: resolution: {integrity: sha512-LS13Z4A8nluGWyT1NQncWoyaWARJdEojxmcRvaFDT9nTHpRkMgPjaYBJIc/9GBRYYFy8TQGaiCmUdH2g4M9INg==} @@ -891,12 +881,12 @@ packages: /@types/react@18.2.0: resolution: {integrity: sha512-0FLj93y5USLHdnhIhABk83rm8XEGA7kH3cr+YUlvxoUGp1xNt/DINUMvqPxLyOQMzLmZe8i4RTHbvb8MC7NmrA==} dependencies: - '@types/prop-types': 15.7.10 - '@types/scheduler': 0.16.6 + '@types/prop-types': 15.7.11 + '@types/scheduler': 0.16.8 csstype: 3.1.2 - /@types/scheduler@0.16.6: - resolution: {integrity: sha512-Vlktnchmkylvc9SnwwwozTv04L/e1NykF5vgoQ0XTmI8DD+wxfjQuHuvHS3p0r2jz2x2ghPs2h1FVeDirIteWA==} + /@types/scheduler@0.16.8: + resolution: {integrity: sha512-WZLiwShhwLRmeV6zH+GkbOFT6Z6VklCItrDioxUnv+u4Ll+8vKeFySoFyK/0ctcRpOmwAicELfmys1sDc/Rw+A==} /@types/uuid@9.0.3: resolution: {integrity: sha512-taHQQH/3ZyI3zP8M/puluDEIEvtQHVYcC6y3N8ijFtAd28+Ey/G4sg1u2gB01S8MwybLOKAp9/yCMu/uR5l3Ug==} @@ -907,7 +897,7 @@ packages: peerDependencies: vite: ^4 dependencies: - '@swc/core': 1.3.96 + '@swc/core': 1.3.100 vite: 4.4.9(@types/node@20.5.7) transitivePeerDependencies: - '@swc/helpers' @@ -924,7 +914,7 @@ packages: resolution: {integrity: sha512-Cg7TFGpIr01vOQNODXOOaGz2NpCU5gl8x1qJFbb6hbZxR7XrcE2vtbAsTAbJ7/xwJtUuJEw8K8Zr/AE0LHlesg==} engines: {node: '>=10', npm: '>=6'} dependencies: - '@babel/runtime': 7.23.2 + '@babel/runtime': 7.23.5 cosmiconfig: 7.1.0 resolve: 1.22.8 dev: false @@ -1001,7 +991,7 @@ packages: /dom-helpers@5.2.1: resolution: {integrity: sha512-nRCa7CK3VTrM2NmGkIy4cbK7IZlgBE/PYMn55rrXefr5xXDP0LdtfPnblFDoVdcAfslJ7or6iqAUnx0CCGIWQA==} dependencies: - '@babel/runtime': 7.23.2 + '@babel/runtime': 7.23.5 csstype: 3.1.2 dev: false @@ -1086,14 +1076,6 @@ packages: resolution: {integrity: sha512-uHJgbwAMwNFf5mLst7IWLNg14x1CkeqglJb/K3doi4dw6q2IvAAmM/Y81kevy83wP+Sst+nutFTYOGg3d1lsxg==} dev: true - /goober@2.1.13(csstype@3.1.2): - resolution: {integrity: sha512-jFj3BQeleOoy7t93E9rZ2de+ScC4lQICLwiAQmKMg9F6roKGaLSHoCDYKkWlSafg138jejvq/mTdvmnwDQgqoQ==} - peerDependencies: - csstype: ^3.0.10 - dependencies: - csstype: 3.1.2 - dev: false - /hamt_plus@1.0.2: resolution: {integrity: sha512-t2JXKaehnMb9paaYA7J0BX8QQAY8lwfQ9Gjf4pg/mk4krt+cmwmU652HOoWonf+7+EQV97ARPMhhVgU1ra2GhA==} dev: false @@ -1187,7 +1169,7 @@ packages: resolution: {integrity: sha512-ayCKvm/phCGxOkYRSCM82iDwct8/EonSEgCSxWxD7ve6jHggsFl4fZVQBPRNgQoKiuV/odhFrGzQXZwbifC8Rg==} engines: {node: '>=8'} dependencies: - '@babel/code-frame': 7.22.13 + '@babel/code-frame': 7.23.5 error-ex: 1.3.2 json-parse-even-better-errors: 2.3.1 lines-and-columns: 1.2.4 @@ -1241,20 +1223,6 @@ packages: resolution: {integrity: sha512-suNP+J1VU1MWFKcyt7RtjiSWUjvidmQSlqu+eHslq+342xCbGTYmC0mEhPCOHxlW0CywylOC1u2DFAT+bv4dBw==} dev: false - /react-hot-toast@2.4.1(csstype@3.1.2)(react-dom@18.2.0)(react@18.2.0): - resolution: {integrity: sha512-j8z+cQbWIM5LY37pR6uZR6D4LfseplqnuAO4co4u8917hBUvXlEqyP1ZzqVLcqoyUesZZv/ImreoCeHVDpE5pQ==} - engines: {node: '>=10'} - peerDependencies: - react: '>=16' - react-dom: '>=16' - dependencies: - goober: 2.1.13(csstype@3.1.2) - react: 18.2.0 - react-dom: 18.2.0(react@18.2.0) - transitivePeerDependencies: - - csstype - dev: false - /react-hotkeys-hook@4.4.1(react-dom@18.2.0)(react@18.2.0): resolution: {integrity: sha512-sClBMBioFEgFGYLTWWRKvhxcCx1DRznd+wkFHwQZspnRBkHTgruKIHptlK/U/2DPX8BhHoRGzpMVWUXMmdZlmw==} peerDependencies: @@ -1310,7 +1278,7 @@ packages: react: '>=16.6.0' react-dom: '>=16.6.0' dependencies: - '@babel/runtime': 7.23.2 + '@babel/runtime': 7.23.5 dom-helpers: 5.2.1 loose-envify: 1.4.0 prop-types: 15.8.1 @@ -1374,6 +1342,16 @@ packages: loose-envify: 1.4.0 dev: false + /sonner@1.2.3(react-dom@18.2.0)(react@18.2.0): + resolution: {integrity: sha512-LMr155izOFA8hudzuUVQT0H93VqmcF9ODP475YjjC/4INESYWN1/ioC5SYRG20jmDmwuQDR8ugP7y6ELghT6JQ==} + peerDependencies: + react: ^18.0.0 + react-dom: ^18.0.0 + dependencies: + react: 18.2.0 + react-dom: 18.2.0(react@18.2.0) + dev: false + /source-map-js@1.0.2: resolution: {integrity: sha512-R0XvVJ9WusLiqTCEiGCmICCMplcCkIwwR11mOSD9CR5u+IXYdiseeEuXCVAjS54zqwkLcPNnmU4OeJ6tUrWhDw==} engines: {node: '>=0.10.0'} diff --git a/frontend/src/App.tsx b/frontend/src/App.tsx index e93623a243..204675d71f 100644 --- a/frontend/src/App.tsx +++ b/frontend/src/App.tsx @@ -1,10 +1,10 @@ -import { wsEndpoint } from 'api'; +import { apiClient } from 'api'; import { useAuth } from 'api/auth'; import { useEffect } from 'react'; -import { Toaster } from 'react-hot-toast'; import { RouterProvider } from 'react-router-dom'; import { useRecoilValue } from 'recoil'; import { router } from 'router'; +import { Toaster } from 'sonner'; import { Box, GlobalStyles } from '@mui/material'; import { Theme, ThemeProvider } from '@mui/material/styles'; @@ -89,7 +89,7 @@ function App() { return; } else { connect({ - wsEndpoint, + client: apiClient, userEnv, accessToken }); @@ -109,19 +109,14 @@ function App() { }} /> diff --git a/frontend/src/api/index.ts b/frontend/src/api/index.ts index 8aa90c263f..0ef3e9fa33 100644 --- a/frontend/src/api/index.ts +++ b/frontend/src/api/index.ts @@ -1,4 +1,4 @@ -import toast from 'react-hot-toast'; +import { toast } from 'sonner'; import { ChainlitAPI, ClientError } from '@chainlit/react-client'; diff --git a/frontend/src/components/atoms/buttons/progressIconButton.tsx b/frontend/src/components/atoms/buttons/progressIconButton.tsx new file mode 100644 index 0000000000..9931ccfbaf --- /dev/null +++ b/frontend/src/components/atoms/buttons/progressIconButton.tsx @@ -0,0 +1,32 @@ +import CircularProgress from '@mui/material/CircularProgress'; +import IconButton, { IconButtonProps } from '@mui/material/IconButton'; + +interface Props extends IconButtonProps { + progress: number; + children: React.ReactNode; +} + +export default function CircularProgressIconButton({ + progress, + children, + ...iconButtonProps +}: Props) { + return ( +
+ {children} + {progress < 100 && ( + + )} +
+ ); +} diff --git a/frontend/src/components/atoms/buttons/userButton/avatar.tsx b/frontend/src/components/atoms/buttons/userButton/avatar.tsx index bb7c35a94c..50dacffe0b 100644 --- a/frontend/src/components/atoms/buttons/userButton/avatar.tsx +++ b/frontend/src/components/atoms/buttons/userButton/avatar.tsx @@ -14,9 +14,9 @@ export default function UserAvatar() { bgcolor: 'primary.main', color: 'primary.contrastText' }} - src={user.image || undefined} + src={user.metadata?.image || undefined} > - {user.username?.[0]} + {user.identifier?.[0]} ); } else { diff --git a/frontend/src/components/atoms/buttons/userButton/menu.tsx b/frontend/src/components/atoms/buttons/userButton/menu.tsx index a46574aaaa..df1585d7e5 100644 --- a/frontend/src/components/atoms/buttons/userButton/menu.tsx +++ b/frontend/src/components/atoms/buttons/userButton/menu.tsx @@ -36,7 +36,7 @@ export default function UserMenu({ anchorEl, open, handleClose }: Props) { {user.id} - {user.username} + {user.identifier} ); diff --git a/frontend/src/components/molecules/attachments.tsx b/frontend/src/components/molecules/attachments.tsx new file mode 100644 index 0000000000..dcdbf94240 --- /dev/null +++ b/frontend/src/components/molecules/attachments.tsx @@ -0,0 +1,99 @@ +import { useRecoilValue } from 'recoil'; + +import Close from '@mui/icons-material/Close'; +import Box from '@mui/material/Box'; +import IconButton from '@mui/material/IconButton'; +import Stack from '@mui/material/Stack'; +import Tooltip from '@mui/material/Tooltip'; + +import { Attachment } from '@chainlit/react-components'; + +import CircularProgressIconButton from 'components/atoms/buttons/progressIconButton'; + +import { attachmentsState } from 'state/chat'; + +const Attachments = (): JSX.Element => { + const attachments = useRecoilValue(attachmentsState); + + if (attachments.length === 0) return <>; + + return ( + + {attachments.map((attachment) => { + const showProgress = !attachment.uploaded && attachment.cancel; + + const progress = showProgress ? ( + + + `1px solid ${theme.palette.divider}`, + '&:hover': { + backgroundColor: 'background.default' + } + }} + > + + + + + ) : null; + + const remove = + !showProgress && attachment.remove ? ( + + `1px solid ${theme.palette.divider}`, + '&:hover': { + backgroundColor: 'background.default' + } + }} + onClick={attachment.remove} + > + + + + ) : null; + + return ( + + {progress} + {remove} + + ); + })} + + ); +}; + +export { Attachments }; diff --git a/frontend/src/components/molecules/tasklist/TaskList.tsx b/frontend/src/components/molecules/tasklist/TaskList.tsx index cc9d90dbf1..2dbfc1f14d 100644 --- a/frontend/src/components/molecules/tasklist/TaskList.tsx +++ b/frontend/src/components/molecules/tasklist/TaskList.tsx @@ -1,3 +1,5 @@ +import { useFetch } from 'usehooks-ts'; + import { Box, Chip, List, Theme, useTheme } from '@mui/material'; import { useChatData } from '@chainlit/react-client'; @@ -50,18 +52,22 @@ const TaskList = ({ isMobile }: { isMobile: boolean }) => { const theme = useTheme(); const { tasklists } = useChatData(); - let content: ITaskList | null = null; const tasklist = tasklists[tasklists.length - 1]; - try { - if (tasklist?.content) { - content = JSON.parse(tasklist.content); - } - } catch (e) { - console.error(e); - content = null; + const url = tasklist?.url; + + const { data, error } = useFetch(url); + + if (!url) return null; + + if (!data && !error) { + return
Loading...
; + } else if (error) { + return
An error occured
; } + const content = data as ITaskList; + if (!content) { return null; } diff --git a/frontend/src/components/organisms/chat/Messages/container.tsx b/frontend/src/components/organisms/chat/Messages/container.tsx index 911a589d12..ee2e2d80f2 100644 --- a/frontend/src/components/organisms/chat/Messages/container.tsx +++ b/frontend/src/components/organisms/chat/Messages/container.tsx @@ -1,20 +1,19 @@ import { apiClient } from 'api'; import { memo, useCallback, useMemo } from 'react'; -import toast from 'react-hot-toast'; import { useNavigate } from 'react-router-dom'; import { useRecoilValue, useSetRecoilState } from 'recoil'; +import { toast } from 'sonner'; import { IAction, IAsk, IAvatarElement, + IFeedback, IFunction, - IMessage, IMessageElement, + IStep, ITool, - accessTokenState, - messagesState, - updateMessageById + useChatInteract } from '@chainlit/react-client'; import { MessageContainer as CMessageContainer } from '@chainlit/react-components'; @@ -28,9 +27,14 @@ interface Props { actions: IAction[]; elements: IMessageElement[]; avatars: IAvatarElement[]; - messages: IMessage[]; + messages: IStep[]; askUser?: IAsk; autoScroll?: boolean; + onFeedbackUpdated: ( + message: IStep, + onSuccess: () => void, + feedback: IFeedback + ) => void; callAction?: (action: IAction) => void; setAutoScroll?: (autoScroll: boolean) => void; } @@ -44,29 +48,36 @@ const MessageContainer = memo( autoScroll, elements, messages, + onFeedbackUpdated, callAction, setAutoScroll }: Props) => { - const accessToken = useRecoilValue(accessTokenState); const appSettings = useRecoilValue(settingsState); const projectSettings = useRecoilValue(projectSettingsState); const setPlayground = useSetRecoilState(playgroundState); - const setMessages = useSetRecoilState(messagesState); const setSideView = useSetRecoilState(sideViewState); const highlightedMessage = useRecoilValue(highlightMessage); + const { uploadFile: _uploadFile } = useChatInteract(); + + const uploadFile = useCallback( + (file: File, onProgress: (progress: number) => void) => { + return _uploadFile(apiClient, file, onProgress); + }, + [_uploadFile] + ); const enableFeedback = !!projectSettings?.dataPersistence; const navigate = useNavigate(); const onPlaygroundButtonClick = useCallback( - (message: IMessage) => { + (message: IStep) => { setPlayground((old) => { + const generation = message.generation; let functions = - (message.prompt?.settings?.functions as unknown as IFunction[]) || - []; + (generation?.settings?.functions as unknown as IFunction[]) || []; const tools = - (message.prompt?.settings?.tools as unknown as ITool[]) || []; + (generation?.settings?.tools as unknown as ITool[]) || []; if (tools.length) { functions = [ ...functions, @@ -77,15 +88,15 @@ const MessageContainer = memo( } return { ...old, - prompt: message.prompt + generation: generation ? { - ...message.prompt, + ...generation, functions } : undefined, - originalPrompt: message.prompt + originalGeneration: generation ? { - ...message.prompt, + ...generation, functions } : undefined @@ -95,44 +106,6 @@ const MessageContainer = memo( [setPlayground] ); - const onFeedbackUpdated = useCallback( - async ( - message: IMessage, - feedback: number, - onSuccess: () => void, - feedbackComment?: string - ) => { - try { - await toast.promise( - apiClient.setHumanFeedback( - message.id, - feedback, - feedbackComment, - accessToken - ), - { - loading: 'Updating...', - success: 'Feedback updated!', - error: (err) => { - return {err.message}; - } - } - ); - setMessages((prev) => - updateMessageById(prev, message.id, { - ...message, - humanFeedback: feedback, - humanFeedbackComment: feedbackComment - }) - ); - onSuccess(); - } catch (err) { - console.log(err); - } - }, - [] - ); - const onElementRefClick = useCallback( (element: IMessageElement) => { let path = `/element/${element.id}`; @@ -142,8 +115,8 @@ const MessageContainer = memo( return; } - if (element.conversationId) { - path += `?conversation=${element.conversationId}`; + if (element.threadId) { + path += `?thread=${element.threadId}`; } return navigate(element.display === 'page' ? path : '#'); @@ -174,6 +147,7 @@ const MessageContainer = memo( // This prevents unnecessary re-renders of children components when no props have changed. const memoizedContext = useMemo(() => { return { + uploadFile, askUser, allowHtml: projectSettings?.features?.unsafe_allow_html, latex: projectSettings?.features?.latex, diff --git a/frontend/src/components/organisms/chat/Messages/index.tsx b/frontend/src/components/organisms/chat/Messages/index.tsx index 8ad0a6d5da..91a66a0873 100644 --- a/frontend/src/components/organisms/chat/Messages/index.tsx +++ b/frontend/src/components/organisms/chat/Messages/index.tsx @@ -1,4 +1,15 @@ +import { apiClient } from 'api'; +import { useCallback } from 'react'; +import { useRecoilValue, useSetRecoilState } from 'recoil'; +import { toast } from 'sonner'; + import { + IAction, + IFeedback, + IStep, + accessTokenState, + messagesState, + updateMessageById, useChatData, useChatInteract, useChatMessages, @@ -25,6 +36,63 @@ const Messages = ({ const { messages } = useChatMessages(); const { callAction } = useChatInteract(); const { idToResume } = useChatSession(); + const accessToken = useRecoilValue(accessTokenState); + const setMessages = useSetRecoilState(messagesState); + + const callActionWithToast = useCallback( + (action: IAction) => { + const promise = callAction(action); + if (promise) { + toast.promise(promise, { + loading: `Running ${action.name}`, + success: (res) => { + if (res.response) { + return res.response; + } else { + return `${action.name} executed successfully`; + } + }, + error: (res) => { + if (res.response) { + return res.response; + } else { + return `${action.name} failed`; + } + } + }); + } + }, + [callAction] + ); + + const onFeedbackUpdated = useCallback( + async (message: IStep, onSuccess: () => void, feedback: IFeedback) => { + try { + toast.promise(apiClient.setFeedback(feedback, accessToken), { + loading: 'Updating', + success: (res) => { + setMessages((prev) => + updateMessageById(prev, message.id, { + ...message, + feedback: { + ...feedback, + id: res.feedbackId + } + }) + ); + onSuccess(); + return 'Feedback updated!'; + }, + error: (err) => { + return {err.message}; + } + }); + } catch (err) { + console.log(err); + } + }, + [] + ); return !idToResume && !messages.length && @@ -43,7 +111,8 @@ const Messages = ({ elements={elements} messages={messages} autoScroll={autoScroll} - callAction={callAction} + onFeedbackUpdated={onFeedbackUpdated} + callAction={callActionWithToast} setAutoScroll={setAutoScroll} /> ); diff --git a/frontend/src/components/organisms/chat/history/index.tsx b/frontend/src/components/organisms/chat/history/index.tsx index b633231662..27a2ebd252 100644 --- a/frontend/src/components/organisms/chat/history/index.tsx +++ b/frontend/src/components/organisms/chat/history/index.tsx @@ -14,18 +14,18 @@ import { Typography } from '@mui/material'; -import { MessageHistory } from '@chainlit/react-client'; +import { UserInput } from '@chainlit/react-client'; import { grey } from '@chainlit/react-components/theme'; -import { chatHistoryState } from 'state/chatHistory'; +import { inputHistoryState } from 'state/userInputHistory'; interface Props { disabled?: boolean; onClick: (content: string) => void; } -function buildHistory(historyMessages: MessageHistory[]) { - const history: Record< +function buildInputHistory(userInputs: UserInput[]) { + const inputHistory: Record< string, { key: number | string; @@ -34,48 +34,48 @@ function buildHistory(historyMessages: MessageHistory[]) { }[] > = {}; - const reversedHistory = cloneDeep(historyMessages).reverse(); + const reversedHistory = cloneDeep(userInputs).reverse(); - reversedHistory?.forEach((hm) => { - const { createdAt, content } = hm; + reversedHistory?.forEach((userInput) => { + const { createdAt, content } = userInput; const dateOptions: Intl.DateTimeFormatOptions = { day: 'numeric', month: 'numeric', year: 'numeric' }; const date = new Date(createdAt).toLocaleDateString(undefined, dateOptions); - if (!history[date]) { - history[date] = []; + if (!inputHistory[date]) { + inputHistory[date] = []; } const timeOptions: Intl.DateTimeFormatOptions = { hour: 'numeric', minute: 'numeric' }; - history[date].push({ + inputHistory[date].push({ key: createdAt, hour: new Date(createdAt).toLocaleTimeString(undefined, timeOptions), content: content }); }); - return history; + return inputHistory; } -export default function HistoryButton({ disabled, onClick }: Props) { - const [chatHistory, setChatHistory] = useRecoilState(chatHistoryState); +export default function InputHistoryButton({ disabled, onClick }: Props) { + const [inputHistory, setInputHistory] = useRecoilState(inputHistoryState); const ref = useRef(); const [anchorEl, setAnchorEl] = useState(null); - if (chatHistory.open && !anchorEl) { + if (inputHistory.open && !anchorEl) { if (ref.current) { setAnchorEl(ref.current); } } const toggleChatHistoryMenu = (open: boolean) => - setChatHistory((old) => ({ ...old, open })); + setInputHistory((old) => ({ ...old, open })); const header = ( // @ts-ignore @@ -91,10 +91,10 @@ export default function HistoryButton({ disabled, onClick }: Props) { color="text.primary" sx={{ fontSize: '14px', fontWeight: 700 }} > - Last messages + Last inputs setChatHistory((old) => ({ ...old, messages: [] }))} + onClick={() => setInputHistory((old) => ({ ...old, inputs: [] }))} > @@ -102,7 +102,7 @@ export default function HistoryButton({ disabled, onClick }: Props) { ); const empty = - chatHistory?.messages.length === 0 ? ( + inputHistory?.inputs.length === 0 ? ( // @ts-ignore
) : null; - const loading = !chatHistory.messages ? ( + const loading = !inputHistory.inputs ? ( // @ts-ignore
{ menuEls.push( // @ts-ignore @@ -207,7 +207,7 @@ export default function HistoryButton({ disabled, onClick }: Props) { toggleChatHistoryMenu(false)} PaperProps={{ sx: { @@ -244,7 +244,7 @@ export default function HistoryButton({ disabled, onClick }: Props) { toggleChatHistoryMenu(!chatHistory.open)} + onClick={() => toggleChatHistoryMenu(!inputHistory.open)} ref={ref} > diff --git a/frontend/src/components/organisms/chat/index.tsx b/frontend/src/components/organisms/chat/index.tsx index 8d5718890b..b29efa3f5f 100644 --- a/frontend/src/components/organisms/chat/index.tsx +++ b/frontend/src/components/organisms/chat/index.tsx @@ -1,14 +1,15 @@ +import { apiClient } from 'api'; import { useCallback, useEffect, useMemo, useState } from 'react'; -import toast from 'react-hot-toast'; import { useRecoilValue, useSetRecoilState } from 'recoil'; +import { toast } from 'sonner'; import { v4 as uuidv4 } from 'uuid'; import { Alert, Box } from '@mui/material'; import { - IFileResponse, - conversationsHistoryState, - useChatData + threadHistoryState, + useChatData, + useChatInteract } from '@chainlit/react-client'; import { ErrorBoundary, useUpload } from '@chainlit/react-components'; @@ -16,7 +17,7 @@ import SideView from 'components/atoms/element/sideView'; import ChatProfiles from 'components/molecules/chatProfiles'; import { TaskList } from 'components/molecules/tasklist/TaskList'; -import { attachmentsState } from 'state/chat'; +import { IAttachment, attachmentsState } from 'state/chat'; import { projectSettingsState, sideViewState } from 'state/project'; import Messages from './Messages'; @@ -26,24 +27,79 @@ import InputBox from './inputBox'; const Chat = () => { const projectSettings = useRecoilValue(projectSettingsState); const setAttachments = useSetRecoilState(attachmentsState); - const setConversations = useSetRecoilState(conversationsHistoryState); + const setThreads = useSetRecoilState(threadHistoryState); const sideViewElement = useRecoilValue(sideViewState); const [autoScroll, setAutoScroll] = useState(true); const { error, disabled } = useChatData(); + const { uploadFile } = useChatInteract(); - const fileSpec = useMemo(() => ({ max_size_mb: 20 }), []); + const fileSpec = useMemo(() => ({ max_size_mb: 500 }), []); - const onFileUpload = useCallback((payloads: IFileResponse[]) => { - const fileElements = payloads.map((file) => ({ - id: uuidv4(), - type: 'file' as const, - display: 'inline' as const, - name: file.name, - mime: file.type, - content: file.content - })); - setAttachments((prev) => prev.concat(fileElements)); + const onFileUpload = useCallback((payloads: File[]) => { + const attachements: IAttachment[] = payloads.map((file) => { + const id = uuidv4(); + + const { xhr, promise } = uploadFile(apiClient, file, (progress) => { + setAttachments((prev) => + prev.map((attachment) => { + if (attachment.id === id) { + return { + ...attachment, + uploadProgress: progress + }; + } + return attachment; + }) + ); + }); + + promise + .then((res) => { + setAttachments((prev) => + prev.map((attachment) => { + if (attachment.id === id) { + return { + ...attachment, + // Update with the server ID + serverId: res.id, + uploaded: true, + uploadProgress: 100, + cancel: undefined + }; + } + return attachment; + }) + ); + }) + .catch((error) => { + toast.error(`Failed to upload ${file.name}: ${error.message}`); + setAttachments((prev) => + prev.filter((attachment) => attachment.id !== id) + ); + }); + + return { + id, + type: file.type, + name: file.name, + size: file.size, + uploadProgress: 0, + cancel: () => { + toast.info(`Cancelled upload of ${file.name}`); + xhr.abort(); + setAttachments((prev) => + prev.filter((attachment) => attachment.id !== id) + ); + }, + remove: () => { + setAttachments((prev) => + prev.filter((attachment) => attachment.id !== id) + ); + } + }; + }); + setAttachments((prev) => prev.concat(attachements)); }, []); const onFileUploadError = useCallback( @@ -59,9 +115,9 @@ const Chat = () => { }); useEffect(() => { - setConversations((prev) => ({ + setThreads((prev) => ({ ...prev, - currentConversationId: undefined + currentThreadId: undefined })); }, []); diff --git a/frontend/src/components/organisms/chat/inputBox/UploadButton.tsx b/frontend/src/components/organisms/chat/inputBox/UploadButton.tsx index 59efbd8257..851f8d65a3 100644 --- a/frontend/src/components/organisms/chat/inputBox/UploadButton.tsx +++ b/frontend/src/components/organisms/chat/inputBox/UploadButton.tsx @@ -3,7 +3,7 @@ import { useRecoilValue } from 'recoil'; import AttachFile from '@mui/icons-material/AttachFile'; import { IconButton, Tooltip } from '@mui/material'; -import { FileSpec, IFileResponse } from '@chainlit/react-client'; +import { FileSpec } from '@chainlit/react-client'; import { useUpload } from '@chainlit/react-components'; import { projectSettingsState } from 'state/project'; @@ -11,7 +11,7 @@ import { projectSettingsState } from 'state/project'; type Props = { disabled?: boolean; fileSpec: FileSpec; - onFileUpload: (files: IFileResponse[]) => void; + onFileUpload: (files: File[]) => void; onFileUploadError: (error: string) => void; }; @@ -25,21 +25,21 @@ const UploadButton = ({ const upload = useUpload({ spec: fileSpec, - onResolved: (payloads: IFileResponse[]) => onFileUpload(payloads), + onResolved: (payloads: File[]) => onFileUpload(payloads), onError: onFileUploadError, options: { noDrag: true } }); if (!upload || !pSettings?.features?.multi_modal) return null; - const { getRootProps, getInputProps, uploading } = upload; + const { getRootProps, getInputProps } = upload; return ( diff --git a/frontend/src/components/organisms/chat/inputBox/index.tsx b/frontend/src/components/organisms/chat/inputBox/index.tsx index d594916343..7d3239f723 100644 --- a/frontend/src/components/organisms/chat/inputBox/index.tsx +++ b/frontend/src/components/organisms/chat/inputBox/index.tsx @@ -5,16 +5,11 @@ import { v4 as uuidv4 } from 'uuid'; import { Box } from '@mui/material'; -import { - FileSpec, - IFileElement, - IFileResponse, - IMessage, - useChatInteract -} from '@chainlit/react-client'; - -import { chatHistoryState } from 'state/chatHistory'; +import { FileSpec, IStep, useChatInteract } from '@chainlit/react-client'; + +import { IAttachment } from 'state/chat'; import { IProjectSettings } from 'state/project'; +import { inputHistoryState } from 'state/userInputHistory'; import StopButton from '../stopButton'; import Input from './input'; @@ -22,7 +17,7 @@ import WaterMark from './waterMark'; interface Props { fileSpec: FileSpec; - onFileUpload: (payload: IFileResponse[]) => void; + onFileUpload: (payload: File[]) => void; onFileUploadError: (error: string) => void; setAutoScroll: (autoScroll: boolean) => void; projectSettings?: IProjectSettings; @@ -36,52 +31,58 @@ const InputBox = memo( setAutoScroll, projectSettings }: Props) => { - const setChatHistory = useSetRecoilState(chatHistoryState); + const setInputHistory = useSetRecoilState(inputHistoryState); const { user } = useAuth(); const { sendMessage, replyMessage } = useChatInteract(); // const tokenCount = useRecoilValue(tokenCountState); const onSubmit = useCallback( - async (msg: string, files?: IFileElement[]) => { - const message: IMessage = { + async (msg: string, attachments?: IAttachment[]) => { + const message: IStep = { + threadId: '', id: uuidv4(), - author: user?.username || 'User', - authorIsUser: true, - content: msg, + name: user?.identifier || 'User', + type: 'user_message', + output: msg, createdAt: new Date().toISOString() }; - setChatHistory((old) => { + setInputHistory((old) => { const MAX_SIZE = 50; - const messages = [...(old.messages || [])]; - messages.push({ + const inputs = [...(old.inputs || [])]; + inputs.push({ content: msg, createdAt: new Date().getTime() }); return { ...old, - messages: - messages.length > MAX_SIZE - ? messages.slice(messages.length - MAX_SIZE) - : messages + inputs: + inputs.length > MAX_SIZE + ? inputs.slice(inputs.length - MAX_SIZE) + : inputs }; }); + const fileReferences = attachments + ?.filter((a) => !!a.serverId) + .map((a) => ({ id: a.serverId! })); + setAutoScroll(true); - sendMessage(message, files); + sendMessage(message, fileReferences); }, [user, projectSettings, sendMessage] ); const onReply = useCallback( async (msg: string) => { - const message = { + const message: IStep = { + threadId: '', id: uuidv4(), - author: user?.username || 'User', - authorIsUser: true, - content: msg, + name: user?.identifier || 'User', + type: 'user_message', + output: msg, createdAt: new Date().toISOString() }; diff --git a/frontend/src/components/organisms/chat/inputBox/input.tsx b/frontend/src/components/organisms/chat/inputBox/input.tsx index 45a7ee1224..15793ab1b6 100644 --- a/frontend/src/components/organisms/chat/inputBox/input.tsx +++ b/frontend/src/components/organisms/chat/inputBox/input.tsx @@ -7,28 +7,23 @@ import TuneIcon from '@mui/icons-material/Tune'; import { Box, IconButton, Stack, TextField } from '@mui/material'; import InputAdornment from '@mui/material/InputAdornment'; -import { - FileSpec, - IFileElement, - IFileResponse, - useChatData -} from '@chainlit/react-client'; -import { Attachments } from '@chainlit/react-components'; +import { FileSpec, useChatData } from '@chainlit/react-client'; +import { Attachments } from 'components/molecules/attachments'; import HistoryButton from 'components/organisms/chat/history'; -import { attachmentsState } from 'state/chat'; -import { chatHistoryState } from 'state/chatHistory'; +import { IAttachment, attachmentsState } from 'state/chat'; import { chatSettingsOpenState, projectSettingsState } from 'state/project'; +import { inputHistoryState } from 'state/userInputHistory'; import UploadButton from './UploadButton'; import SpeechButton from './speechButton'; interface Props { fileSpec: FileSpec; - onFileUpload: (payload: IFileResponse[]) => void; + onFileUpload: (payload: File[]) => void; onFileUploadError: (error: string) => void; - onSubmit: (message: string, files?: IFileElement[]) => void; + onSubmit: (message: string, attachments?: IAttachment[]) => void; onReply: (message: string) => void; } @@ -43,13 +38,20 @@ function getLineCount(el: HTMLDivElement) { const Input = memo( ({ fileSpec, onFileUpload, onFileUploadError, onSubmit, onReply }: Props) => { - const [fileElements, setFileElements] = useRecoilState(attachmentsState); + const [attachments, setAttachments] = useRecoilState(attachmentsState); const [pSettings] = useRecoilState(projectSettingsState); - const setChatHistory = useSetRecoilState(chatHistoryState); + const setInputHistory = useSetRecoilState(inputHistoryState); const setChatSettingsOpen = useSetRecoilState(chatSettingsOpenState); const ref = useRef(null); - const { loading, askUser, chatSettingsInputs, disabled } = useChatData(); + const { + loading, + askUser, + chatSettingsInputs, + disabled: _disabled + } = useChatData(); + + const disabled = _disabled || !!attachments.find((a) => !a.uploaded); const [value, setValue] = useState(''); const [isComposing, setIsComposing] = useState(false); @@ -64,21 +66,7 @@ const Input = memo( if (item.kind === 'file') { const file = item.getAsFile(); if (file) { - const reader = new FileReader(); - reader.onload = function (e) { - const content = e.target?.result as ArrayBuffer; - if (content) { - onFileUpload([ - { - name: file.name, - type: file.type, - content, - size: file.size - } - ]); - } - }; - reader.readAsArrayBuffer(file); + onFileUpload([file]); } } }); @@ -111,17 +99,17 @@ const Input = memo( if (askUser) { onReply(value); } else { - onSubmit(value, fileElements); + onSubmit(value, attachments); } - setFileElements([]); + setAttachments([]); setValue(''); }, [ value, disabled, setValue, askUser, - fileElements, - setFileElements, + attachments, + setAttachments, onSubmit ]); @@ -135,11 +123,11 @@ const Input = memo( } else if (e.key === 'ArrowUp') { const lineCount = getLineCount(e.currentTarget as HTMLDivElement); if (lineCount <= 1) { - setChatHistory((old) => ({ ...old, open: true })); + setInputHistory((old) => ({ ...old, open: true })); } } }, - [submit, setChatHistory, isComposing] + [submit, setInputHistory, isComposing] ); const onHistoryClick = useCallback((content: string) => { @@ -201,7 +189,7 @@ const Input = memo( } }} > - {fileElements.length > 0 ? ( + {attachments.length > 0 ? ( - + ) : null} diff --git a/frontend/src/components/organisms/conversationsHistory/Conversation.tsx b/frontend/src/components/organisms/conversationsHistory/Conversation.tsx deleted file mode 100644 index c392ac5047..0000000000 --- a/frontend/src/components/organisms/conversationsHistory/Conversation.tsx +++ /dev/null @@ -1,98 +0,0 @@ -import { Link } from 'react-router-dom'; - -import { Alert, Box, Button, Skeleton, Stack } from '@mui/material'; - -import { - IAction, - IConversation, - IMessageElement, - nestMessages -} from '@chainlit/react-client'; - -import SideView from 'components/atoms/element/sideView'; -import MessageContainer from 'components/organisms/chat/Messages/container'; - -type ConversationProps = { - conversation?: IConversation; - error?: Error; - isLoading?: boolean; -}; - -const Conversation = ({ - conversation, - error, - isLoading -}: ConversationProps) => { - if (isLoading) { - return [1, 2, 3].map((index) => ( - - - - - - - - )); - } - - if (!conversation || error) { - return null; - } - - const elements = conversation.elements; - const actions: IAction[] = []; - - return ( - - - - - Go back to chat - - } - > - This conversation was created on{' '} - {new Intl.DateTimeFormat().format(conversation.createdAt as number)} - . - - - - - - ); -}; - -export { Conversation }; diff --git a/frontend/src/components/organisms/header.tsx b/frontend/src/components/organisms/header.tsx index 5c284c8cfe..a974093910 100644 --- a/frontend/src/components/organisms/header.tsx +++ b/frontend/src/components/organisms/header.tsx @@ -23,7 +23,7 @@ import NewChatButton from 'components/molecules/newChatButton'; import { IProjectSettings } from 'state/project'; -import OpenChatHistoryButton from './conversationsHistory/sidebar/OpenChatHistoryButton'; +import OpenChatHistoryButton from './threadHistory/sidebar/OpenThreadListButton'; interface INavItem { to: string; diff --git a/frontend/src/components/organisms/playground/index.tsx b/frontend/src/components/organisms/playground/index.tsx index a49f5cef97..2659fa9e5e 100644 --- a/frontend/src/components/organisms/playground/index.tsx +++ b/frontend/src/components/organisms/playground/index.tsx @@ -1,9 +1,9 @@ import { apiClient } from 'api'; import { useCallback } from 'react'; -import { toast } from 'react-hot-toast'; import { useRecoilState, useRecoilValue } from 'recoil'; +import { toast } from 'sonner'; -import { IPrompt, accessTokenState } from '@chainlit/react-client'; +import { IGeneration, accessTokenState } from '@chainlit/react-client'; import { IPlaygroundContext, PromptPlayground @@ -28,7 +28,7 @@ export default function PlaygroundWrapper() { const [promptMode, setPromptMode] = useRecoilState(modeState); const shoulFetchProviders = - playground?.prompt && !playground?.providers?.length; + playground?.generation && !playground?.providers?.length; useLLMProviders(shoulFetchProviders); @@ -50,12 +50,12 @@ export default function PlaygroundWrapper() { const createCompletion = useCallback( ( - prompt: IPrompt, + generation: IGeneration, controller: AbortController, cb: (done: boolean, token: string) => void ) => { - return apiClient.getCompletion( - prompt, + return apiClient.getGeneration( + generation, userEnv, controller, accessToken, diff --git a/frontend/src/components/organisms/threadHistory/Thread.tsx b/frontend/src/components/organisms/threadHistory/Thread.tsx new file mode 100644 index 0000000000..aa513fddb0 --- /dev/null +++ b/frontend/src/components/organisms/threadHistory/Thread.tsx @@ -0,0 +1,153 @@ +import { apiClient } from 'api'; +import { useCallback, useEffect, useState } from 'react'; +import { Link } from 'react-router-dom'; +import { useRecoilValue } from 'recoil'; +import { toast } from 'sonner'; + +import { Alert, Box, Button, Skeleton, Stack } from '@mui/material'; + +import { + IAction, + IFeedback, + IMessageElement, + IStep, + IThread, + accessTokenState, + nestMessages +} from '@chainlit/react-client'; + +import SideView from 'components/atoms/element/sideView'; +import MessageContainer from 'components/organisms/chat/Messages/container'; + +type Props = { + thread?: IThread; + error?: Error; + isLoading?: boolean; +}; + +const Thread = ({ thread, error, isLoading }: Props) => { + const accessToken = useRecoilValue(accessTokenState); + const [steps, setSteps] = useState([]); + + useEffect(() => { + if (!thread) return; + setSteps(thread.steps); + }, [thread]); + + const onFeedbackUpdated = useCallback( + async (message: IStep, onSuccess: () => void, feedback: IFeedback) => { + try { + toast.promise(apiClient.setFeedback(feedback, accessToken), { + loading: 'Updating', + success: (res) => { + setSteps((prev) => + prev.map((step) => { + if (step.id === message.id) { + return { + ...step, + feedback: { + ...feedback, + id: res.feedbackId + } + }; + } + return step; + }) + ); + + onSuccess(); + return 'Feedback updated!'; + }, + error: (err) => { + return {err.message}; + } + }); + } catch (err) { + console.log(err); + } + }, + [] + ); + + if (isLoading) { + return [1, 2, 3].map((index) => ( + + + + + + + + )); + } + + if (!thread || error) { + return null; + } + + const elements = thread.elements; + const actions: IAction[] = []; + const messages = nestMessages(steps); + + return ( + + + + + Go back to chat + + } + > + This chat was created on{' '} + {new Intl.DateTimeFormat(undefined, { + day: 'numeric', + month: 'numeric', + year: 'numeric', + hour: 'numeric', + minute: 'numeric' + }).format(new Date(thread.createdAt))} + . + + + + + + ); +}; + +export { Thread }; diff --git a/frontend/src/components/organisms/conversationsHistory/sidebar/DeleteConversationButton.tsx b/frontend/src/components/organisms/threadHistory/sidebar/DeleteThreadButton.tsx similarity index 73% rename from frontend/src/components/organisms/conversationsHistory/sidebar/DeleteConversationButton.tsx rename to frontend/src/components/organisms/threadHistory/sidebar/DeleteThreadButton.tsx index 1aa30637da..8c4ef0537f 100644 --- a/frontend/src/components/organisms/conversationsHistory/sidebar/DeleteConversationButton.tsx +++ b/frontend/src/components/organisms/threadHistory/sidebar/DeleteThreadButton.tsx @@ -1,7 +1,7 @@ import { apiClient } from 'api'; import { useState } from 'react'; -import toast from 'react-hot-toast'; import { useRecoilValue } from 'recoil'; +import { toast } from 'sonner'; import DeleteOutline from '@mui/icons-material/DeleteOutline'; import LoadingButton from '@mui/lab/LoadingButton'; @@ -16,11 +16,11 @@ import DialogTitle from '@mui/material/DialogTitle'; import { ClientError, accessTokenState } from '@chainlit/react-client'; interface Props { - conversationId: string; + threadId: string; onDelete: () => void; } -const DeleteConversationButton = ({ conversationId, onDelete }: Props) => { +const DeleteThreadButton = ({ threadId, onDelete }: Props) => { const [open, setOpen] = useState(false); const accessToken = useRecoilValue(accessTokenState); @@ -33,21 +33,20 @@ const DeleteConversationButton = ({ conversationId, onDelete }: Props) => { }; const handleConfirm = async () => { - await toast.promise( - apiClient.deleteConversation(conversationId, accessToken), - { - loading: 'Deleting conversation...', - success: 'Conversation deleted!', - error: (err) => { - if (err instanceof ClientError) { - return {err.message}; - } else { - return ; - } + toast.promise(apiClient.deleteThread(threadId, accessToken), { + loading: 'Deleting chat', + success: () => { + onDelete(); + return 'Chat deleted!'; + }, + error: (err) => { + if (err instanceof ClientError) { + return {err.message}; + } else { + return ; } } - ); - onDelete(); + }); handleClose(); }; @@ -68,13 +67,10 @@ const DeleteConversationButton = ({ conversationId, onDelete }: Props) => { } }} > - - {'Delete conversation?'} - + {'Delete Thread?'} - This will delete the conversation as well as it's messages and - elements. + This will delete the thread as well as it's messages and elements. @@ -96,4 +92,4 @@ const DeleteConversationButton = ({ conversationId, onDelete }: Props) => { ); }; -export { DeleteConversationButton }; +export { DeleteThreadButton }; diff --git a/frontend/src/components/organisms/conversationsHistory/sidebar/OpenChatHistoryButton.tsx b/frontend/src/components/organisms/threadHistory/sidebar/OpenThreadListButton.tsx similarity index 90% rename from frontend/src/components/organisms/conversationsHistory/sidebar/OpenChatHistoryButton.tsx rename to frontend/src/components/organisms/threadHistory/sidebar/OpenThreadListButton.tsx index 6fdc0ef725..ac6b01990d 100644 --- a/frontend/src/components/organisms/conversationsHistory/sidebar/OpenChatHistoryButton.tsx +++ b/frontend/src/components/organisms/threadHistory/sidebar/OpenThreadListButton.tsx @@ -6,7 +6,7 @@ import IconButton from '@mui/material/IconButton'; import { settingsState } from 'state/settings'; -const OpenChatHistoryButton = ({ mode }: { mode: 'mobile' | 'desktop' }) => { +const OpenThreadListButton = ({ mode }: { mode: 'mobile' | 'desktop' }) => { const [settings, setSettings] = useRecoilState(settingsState); const isDesktop = mode === 'desktop'; @@ -45,4 +45,4 @@ const OpenChatHistoryButton = ({ mode }: { mode: 'mobile' | 'desktop' }) => { ) : null; }; -export default OpenChatHistoryButton; +export default OpenThreadListButton; diff --git a/frontend/src/components/organisms/conversationsHistory/sidebar/ConversationsHistoryList.tsx b/frontend/src/components/organisms/threadHistory/sidebar/ThreadList.tsx similarity index 75% rename from frontend/src/components/organisms/conversationsHistory/sidebar/ConversationsHistoryList.tsx rename to frontend/src/components/organisms/threadHistory/sidebar/ThreadList.tsx index 7cc4da7c9f..5ac81713e5 100644 --- a/frontend/src/components/organisms/conversationsHistory/sidebar/ConversationsHistoryList.tsx +++ b/frontend/src/components/organisms/threadHistory/sidebar/ThreadList.tsx @@ -14,40 +14,39 @@ import Stack from '@mui/material/Stack'; import Typography from '@mui/material/Typography'; import { - ConversationsHistory, + ThreadHistory, useChatInteract, useChatSession } from '@chainlit/react-client'; import { grey } from '@chainlit/react-components'; -import { DeleteConversationButton } from './DeleteConversationButton'; +import { DeleteThreadButton } from './DeleteThreadButton'; -interface ConversationsHistoryProps { - conversations?: ConversationsHistory; +interface Props { + threadHistory?: ThreadHistory; error?: string; - fetchConversations: () => void; + fetchThreads: () => void; isFetching: boolean; isLoadingMore: boolean; } -const ConversationsHistoryList = ({ - conversations, +const ThreadList = ({ + threadHistory, error, - fetchConversations, + fetchThreads, isFetching, isLoadingMore -}: ConversationsHistoryProps) => { +}: Props) => { const { idToResume } = useChatSession(); const { clear } = useChatInteract(); const navigate = useNavigate(); - - if (isFetching || (!conversations?.groupedConversations && isLoadingMore)) { + if (isFetching || (!threadHistory?.timeGroupedThreads && isLoadingMore)) { return [1, 2, 3].map((index) => ( - + {[1, 2].map((childIndex) => ( Empty... @@ -83,14 +82,14 @@ const ConversationsHistoryList = ({ ); } - const handleDeleteConversation = (conversationId: string) => { - if (conversationId === idToResume) { + const handleDeleteThread = (threadId: string) => { + if (threadId === idToResume) { clear(); } - if (conversationId === conversations.currentConversationId) { + if (threadId === threadHistory.currentThreadId) { navigate('/'); } - fetchConversations(); + fetchThreads(); }; return ( @@ -104,7 +103,7 @@ const ConversationsHistoryList = ({ }} subheader={
  • } > - {map(conversations.groupedConversations, (items, index) => { + {map(threadHistory.timeGroupedThreads, (items, index) => { return (
    • @@ -121,20 +120,18 @@ const ConversationsHistoryList = ({ {index} - {map(items, (conversation) => { + {map(items, (thread) => { const isResumed = - idToResume === conversation.id && - !conversations.currentConversationId; + idToResume === thread.id && !threadHistory.currentThreadId; const isSelected = - isResumed || - conversations.currentConversationId === conversation.id; + isResumed || threadHistory.currentThreadId === thread.id; return ( ({ textDecoration: 'none', cursor: 'pointer', @@ -156,7 +153,7 @@ const ConversationsHistoryList = ({ : 'grey.200' } })} - to={isResumed ? '' : `/conversation/${conversation.id}`} + to={isResumed ? '' : `/thread/${thread.id}`} > - {capitalize(conversation.messages[0]?.content)} + {capitalize(thread.metadata?.name || 'Unknown')} {isSelected ? ( - - handleDeleteConversation(conversation.id) - } + handleDeleteThread(thread.id)} /> ) : null} @@ -218,4 +213,4 @@ const ConversationsHistoryList = ({ ); }; -export { ConversationsHistoryList }; +export { ThreadList }; diff --git a/frontend/src/components/organisms/conversationsHistory/sidebar/filters/FeedbackSelect.tsx b/frontend/src/components/organisms/threadHistory/sidebar/filters/FeedbackSelect.tsx similarity index 95% rename from frontend/src/components/organisms/conversationsHistory/sidebar/filters/FeedbackSelect.tsx rename to frontend/src/components/organisms/threadHistory/sidebar/filters/FeedbackSelect.tsx index 8b99879823..0923ca76a4 100644 --- a/frontend/src/components/organisms/conversationsHistory/sidebar/filters/FeedbackSelect.tsx +++ b/frontend/src/components/organisms/threadHistory/sidebar/filters/FeedbackSelect.tsx @@ -11,7 +11,7 @@ import Stack from '@mui/material/Stack'; import { grey } from '@chainlit/react-components'; -import { conversationsFiltersState } from 'state/conversations'; +import { threadsFiltersState } from 'state/threads'; export enum FEEDBACKS { ALL = 0, @@ -20,7 +20,7 @@ export enum FEEDBACKS { } export default function FeedbackSelect() { - const [filters, setFilters] = useRecoilState(conversationsFiltersState); + const [filters, setFilters] = useRecoilState(threadsFiltersState); const [anchorEl, setAnchorEl] = useState(null); const handleChange = (feedback: number) => { diff --git a/frontend/src/components/organisms/conversationsHistory/sidebar/filters/SearchBar.tsx b/frontend/src/components/organisms/threadHistory/sidebar/filters/SearchBar.tsx similarity index 93% rename from frontend/src/components/organisms/conversationsHistory/sidebar/filters/SearchBar.tsx rename to frontend/src/components/organisms/threadHistory/sidebar/filters/SearchBar.tsx index a6fe862ded..10137e9751 100644 --- a/frontend/src/components/organisms/conversationsHistory/sidebar/filters/SearchBar.tsx +++ b/frontend/src/components/organisms/threadHistory/sidebar/filters/SearchBar.tsx @@ -10,10 +10,10 @@ import TextField from '@mui/material/TextField'; import { grey } from '@chainlit/react-components'; -import { conversationsFiltersState } from 'state/conversations'; +import { threadsFiltersState } from 'state/threads'; export default function SearchBar() { - const [filters, setFilters] = useRecoilState(conversationsFiltersState); + const [filters, setFilters] = useRecoilState(threadsFiltersState); const handleChange = (value: string) => { value = value.trim(); diff --git a/frontend/src/components/organisms/conversationsHistory/sidebar/filters/index.tsx b/frontend/src/components/organisms/threadHistory/sidebar/filters/index.tsx similarity index 100% rename from frontend/src/components/organisms/conversationsHistory/sidebar/filters/index.tsx rename to frontend/src/components/organisms/threadHistory/sidebar/filters/index.tsx diff --git a/frontend/src/components/organisms/conversationsHistory/sidebar/index.tsx b/frontend/src/components/organisms/threadHistory/sidebar/index.tsx similarity index 73% rename from frontend/src/components/organisms/conversationsHistory/sidebar/index.tsx rename to frontend/src/components/organisms/threadHistory/sidebar/index.tsx index 496e637bb8..92369674d3 100644 --- a/frontend/src/components/organisms/conversationsHistory/sidebar/index.tsx +++ b/frontend/src/components/organisms/threadHistory/sidebar/index.tsx @@ -13,16 +13,16 @@ import Typography from '@mui/material/Typography'; import useMediaQuery from '@mui/material/useMediaQuery'; import { - IConversationsFilters, + IThreadFilters, accessTokenState, - conversationsHistoryState + threadHistoryState } from '@chainlit/react-client'; -import { conversationsFiltersState } from 'state/conversations'; import { projectSettingsState } from 'state/project'; import { settingsState } from 'state/settings'; +import { threadsFiltersState } from 'state/threads'; -import { ConversationsHistoryList } from './ConversationsHistoryList'; +import { ThreadList } from './ThreadList'; import Filters from './filters'; const DRAWER_WIDTH = 260; @@ -30,20 +30,17 @@ const BATCH_SIZE = 20; let _scrollTop = 0; -const _ConversationsHistorySidebar = () => { +const _ThreadHistorySideBar = () => { const isMobile = useMediaQuery('(max-width:66rem)'); - const [conversations, setConversations] = useRecoilState( - conversationsHistoryState - ); + const [threadHistory, setThreadHistory] = useRecoilState(threadHistoryState); const accessToken = useRecoilValue(accessTokenState); - const filters = useRecoilValue(conversationsFiltersState); + const filters = useRecoilValue(threadsFiltersState); const [settings, setSettings] = useRecoilState(settingsState); const [shouldLoadMore, setShouldLoadMore] = useState(false); const [error, setError] = useState(undefined); - const [prevFilters, setPrevFilters] = - useState(filters); + const [prevFilters, setPrevFilters] = useState(filters); const [isLoadingMore, setIsLoadingMore] = useState(false); const [isFetching, setIsFetching] = useState(false); @@ -62,32 +59,32 @@ const _ConversationsHistorySidebar = () => { setShouldLoadMore(atBottom); }; - const fetchConversations = async (cursor?: string | number) => { + const fetchThreads = async (cursor?: string | number) => { try { if (cursor) { setIsLoadingMore(true); } else { setIsFetching(true); } - const { pageInfo, data } = await apiClient.getConversations( + const { pageInfo, data } = await apiClient.listThreads( { first: BATCH_SIZE, cursor }, filters, accessToken ); setError(undefined); - // Prevent conversations to be duplicated - const allConversations = uniqBy( - // We should only concatenate conversations when we have a cursor indicating that we have loaded more items. - cursor ? conversations?.conversations?.concat(data) : data, + // Prevent threads to be duplicated + const allThreads = uniqBy( + // We should only concatenate threads when we have a cursor indicating that we have loaded more items. + cursor ? threadHistory?.threads?.concat(data) : data, 'id' ); - if (allConversations) { - setConversations((prev) => ({ + if (allThreads) { + setThreadHistory((prev) => ({ ...prev, pageInfo: pageInfo, - conversations: allConversations + threads: allThreads })); } } catch (error) { @@ -99,21 +96,21 @@ const _ConversationsHistorySidebar = () => { } }; - if (accessToken && !isFetching && !conversations?.conversations && !error) { - fetchConversations(); + if (accessToken && !isFetching && !threadHistory?.threads && !error) { + fetchThreads(); } - if (conversations?.pageInfo) { - const { hasNextPage, endCursor } = conversations.pageInfo; + if (threadHistory?.pageInfo) { + const { hasNextPage, endCursor } = threadHistory.pageInfo; if (shouldLoadMore && !isLoadingMore && hasNextPage && endCursor) { - fetchConversations(endCursor); + fetchThreads(endCursor); } } if (filtersHasChanged) { setPrevFilters(filters); - fetchConversations(); + fetchThreads(); } const setChatHistoryOpen = (open: boolean) => @@ -170,20 +167,20 @@ const _ConversationsHistorySidebar = () => { color: (theme) => theme.palette.text.primary }} > - Chat History + Past Chats setChatHistoryOpen(false)}> - {conversations ? ( - ) : null} @@ -191,7 +188,7 @@ const _ConversationsHistorySidebar = () => { ); }; -const ConversationsHistorySidebar = () => { +const ThreadHistorySideBar = () => { const { user } = useAuth(); const pSettings = useRecoilValue(projectSettingsState); @@ -199,7 +196,7 @@ const ConversationsHistorySidebar = () => { return null; } - return <_ConversationsHistorySidebar />; + return <_ThreadHistorySideBar />; }; -export { ConversationsHistorySidebar }; +export { ThreadHistorySideBar }; diff --git a/frontend/src/hooks/localChatHistory.ts b/frontend/src/hooks/localChatHistory.ts deleted file mode 100644 index a961c7d32e..0000000000 --- a/frontend/src/hooks/localChatHistory.ts +++ /dev/null @@ -1,42 +0,0 @@ -import { useCallback } from 'react'; - -import { MessageHistory } from '@chainlit/react-client'; - -const KEY = 'chatHistory'; -const MAX_SIZE = 50; - -export default function useLocalChatHistory() { - const getLocalChatHistory = useCallback(() => { - const messageHistory = localStorage.getItem(KEY); - if (messageHistory) { - return JSON.parse(messageHistory) as MessageHistory[]; - } - return []; - }, []); - - const persistChatLocally = useCallback((message: string) => { - const messageHistory: { messages: MessageHistory[] } = { - messages: [ - { - content: message, - createdAt: new Date().getTime() - } - ] - }; - - const chatHistory = getLocalChatHistory(); - - if (!chatHistory) { - localStorage.setItem(KEY, JSON.stringify([messageHistory])); - } else { - let curr = [messageHistory, ...chatHistory]; - if (curr.length > MAX_SIZE) { - curr = curr.slice(0, MAX_SIZE); - } - localStorage.setItem(KEY, JSON.stringify(curr)); - } - return []; - }, []); - - return { persistChatLocally, getLocalChatHistory }; -} diff --git a/frontend/src/hooks/useLLMProviders.ts b/frontend/src/hooks/useLLMProviders.ts index a0b0c23fc6..d38c370c35 100644 --- a/frontend/src/hooks/useLLMProviders.ts +++ b/frontend/src/hooks/useLLMProviders.ts @@ -1,7 +1,7 @@ import { apiClient } from 'api'; import { useEffect } from 'react'; -import toast from 'react-hot-toast'; import { useSetRecoilState } from 'recoil'; +import { toast } from 'sonner'; import { useApi } from '@chainlit/react-client'; import { IPlayground } from '@chainlit/react-components'; diff --git a/frontend/src/pages/Conversation.tsx b/frontend/src/pages/Conversation.tsx deleted file mode 100644 index 2ced403f45..0000000000 --- a/frontend/src/pages/Conversation.tsx +++ /dev/null @@ -1,63 +0,0 @@ -import { apiClient } from 'api'; -import { useEffect } from 'react'; -import { useParams } from 'react-router-dom'; -import { useRecoilState } from 'recoil'; - -import { Box } from '@mui/material'; - -import { - IConversation, - conversationsHistoryState, - useApi -} from '@chainlit/react-client'; - -import { Conversation } from 'components/organisms/conversationsHistory/Conversation'; - -import Page from './Page'; -import ResumeButton from './ResumeButton'; - -export default function ConversationPage() { - const { id } = useParams(); - const { data, error, isLoading } = useApi( - apiClient, - id ? `/project/conversation/${id}` : null, - { - revalidateOnFocus: false, - revalidateIfStale: false - } - ); - - const [conversations, setConversations] = useRecoilState( - conversationsHistoryState - ); - - useEffect(() => { - if (conversations?.currentConversationId !== id) { - setConversations((prev) => { - return { ...prev, currentConversationId: id }; - }); - } - }, [id]); - - return ( - - - - - - - - - ); -} diff --git a/frontend/src/pages/Element.tsx b/frontend/src/pages/Element.tsx index 092384205d..eb3e31c55b 100644 --- a/frontend/src/pages/Element.tsx +++ b/frontend/src/pages/Element.tsx @@ -1,6 +1,7 @@ import { apiClient } from 'api'; import { useEffect, useState } from 'react'; import { useNavigate, useParams } from 'react-router-dom'; +import { useRecoilValue } from 'recoil'; import Page from 'pages/Page'; @@ -9,34 +10,39 @@ import { ElementView } from '@chainlit/react-components'; import { useQuery } from 'hooks/query'; +import { projectSettingsState } from 'state/project'; + export default function Element() { const { id } = useParams(); const query = useQuery(); const { elements } = useChatData(); + const pSettings = useRecoilValue(projectSettingsState); const [element, setElement] = useState(null); const navigate = useNavigate(); - const conversationId = query.get('conversation'); + const threadId = query.get('thread'); + + const dataPersistence = pSettings?.dataPersistence; const { data, error } = useApi( apiClient, - id && conversationId - ? `/project/conversation/${conversationId}/element/${id}` + id && threadId && dataPersistence + ? `/project/thread/${threadId}/element/${id}` : null ); useEffect(() => { if (data) { setElement(data); - } else if (id && !conversationId && !element) { + } else if (id && !dataPersistence && !element) { const foundElement = elements.find((element) => element.id === id); if (foundElement) { setElement(foundElement); } } - }, [data, element, elements, id, conversationId]); + }, [data, element, elements, id, threadId]); if (!element || error) { return null; diff --git a/frontend/src/pages/Env.tsx b/frontend/src/pages/Env.tsx index 2dff0299f9..fb57f520f9 100644 --- a/frontend/src/pages/Env.tsx +++ b/frontend/src/pages/Env.tsx @@ -1,7 +1,7 @@ import { useFormik } from 'formik'; -import { toast } from 'react-hot-toast'; import { useNavigate } from 'react-router-dom'; import { useRecoilState, useRecoilValue } from 'recoil'; +import { toast } from 'sonner'; import * as yup from 'yup'; import { Alert, Box, Button, Typography } from '@mui/material'; diff --git a/frontend/src/pages/Page.tsx b/frontend/src/pages/Page.tsx index dc8d072832..f7101ae4fd 100644 --- a/frontend/src/pages/Page.tsx +++ b/frontend/src/pages/Page.tsx @@ -4,9 +4,9 @@ import { useRecoilValue } from 'recoil'; import { Alert, Box, Stack } from '@mui/material'; -import { ConversationsHistorySidebar } from 'components/organisms/conversationsHistory/sidebar'; -import OpenChatHistoryButton from 'components/organisms/conversationsHistory/sidebar/OpenChatHistoryButton'; import { Header } from 'components/organisms/header'; +import { ThreadHistorySideBar } from 'components/organisms/threadHistory/sidebar'; +import OpenChatHistoryButton from 'components/organisms/threadHistory/sidebar/OpenThreadListButton'; import { projectSettingsState } from 'state/project'; import { userEnvState } from 'state/user'; @@ -43,7 +43,7 @@ const Page = ({ children }: Props) => { You are not part of this project. ) : ( - + {children} diff --git a/frontend/src/pages/ResumeButton.tsx b/frontend/src/pages/ResumeButton.tsx index 4d199cc99a..93426dcc29 100644 --- a/frontend/src/pages/ResumeButton.tsx +++ b/frontend/src/pages/ResumeButton.tsx @@ -1,6 +1,6 @@ -import toast from 'react-hot-toast'; import { useNavigate } from 'react-router-dom'; import { useRecoilValue } from 'recoil'; +import { toast } from 'sonner'; import { Box, Button } from '@mui/material'; @@ -11,22 +11,22 @@ import WaterMark from 'components/organisms/chat/inputBox/waterMark'; import { projectSettingsState } from 'state/project'; interface Props { - conversationId?: string; + threadId?: string; } -export default function ResumeButton({ conversationId }: Props) { +export default function ResumeButton({ threadId }: Props) { const navigate = useNavigate(); const pSettings = useRecoilValue(projectSettingsState); const { clear, setIdToResume } = useChatInteract(); - if (!conversationId || !pSettings?.conversationResumable) { + if (!threadId || !pSettings?.threadResumable) { return; } const onClick = () => { clear(); - setIdToResume(conversationId!); - toast.success('Conversation resumed!'); + setIdToResume(threadId!); + toast.success('Chat resumed!'); navigate('/'); }; @@ -44,8 +44,8 @@ export default function ResumeButton({ conversationId }: Props) { justifyContent: 'center' }} > - diff --git a/frontend/src/pages/Thread.tsx b/frontend/src/pages/Thread.tsx new file mode 100644 index 0000000000..60a0338ddb --- /dev/null +++ b/frontend/src/pages/Thread.tsx @@ -0,0 +1,53 @@ +import { apiClient } from 'api'; +import { useEffect } from 'react'; +import { useParams } from 'react-router-dom'; +import { useRecoilState } from 'recoil'; + +import { Box } from '@mui/material'; + +import { IThread, threadHistoryState, useApi } from '@chainlit/react-client'; + +import { Thread } from 'components/organisms/threadHistory/Thread'; + +import Page from './Page'; +import ResumeButton from './ResumeButton'; + +export default function ThreadPage() { + const { id } = useParams(); + const { data, error, isLoading } = useApi( + apiClient, + id ? `/project/thread/${id}` : null, + { + revalidateOnFocus: false, + revalidateIfStale: false + } + ); + + const [threadHistory, setThreadHistory] = useRecoilState(threadHistoryState); + + useEffect(() => { + if (threadHistory?.currentThreadId !== id) { + setThreadHistory((prev) => { + return { ...prev, currentThreadId: id }; + }); + } + }, [id]); + + return ( + + + + + + + + + ); +} diff --git a/frontend/src/router.tsx b/frontend/src/router.tsx index b51ca19fdf..b25a6767b5 100644 --- a/frontend/src/router.tsx +++ b/frontend/src/router.tsx @@ -1,13 +1,13 @@ import { Navigate, createBrowserRouter } from 'react-router-dom'; import AuthCallback from 'pages/AuthCallback'; -import Conversation from 'pages/Conversation'; import Design from 'pages/Design'; import Element from 'pages/Element'; import Env from 'pages/Env'; import Home from 'pages/Home'; import Login from 'pages/Login'; import Readme from 'pages/Readme'; +import Thread from 'pages/Thread'; export const router = createBrowserRouter([ { @@ -23,8 +23,8 @@ export const router = createBrowserRouter([ element: }, { - path: '/conversation/:id?', - element: + path: '/thread/:id?', + element: }, { path: '/element/:id', diff --git a/frontend/src/state/chat.ts b/frontend/src/state/chat.ts index ff44d3689b..97d1aa49d5 100644 --- a/frontend/src/state/chat.ts +++ b/frontend/src/state/chat.ts @@ -1,8 +1,18 @@ import { atom } from 'recoil'; -import { IFileElement } from '@chainlit/react-client'; +export interface IAttachment { + id: string; + serverId?: string; + name: string; + size: number; + type: string; + uploadProgress?: number; + uploaded?: boolean; + cancel?: () => void; + remove?: () => void; +} -export const attachmentsState = atom({ +export const attachmentsState = atom({ key: 'Attachments', default: [] }); diff --git a/frontend/src/state/conversations.ts b/frontend/src/state/conversations.ts deleted file mode 100644 index 5c0b87e7aa..0000000000 --- a/frontend/src/state/conversations.ts +++ /dev/null @@ -1,8 +0,0 @@ -import { atom } from 'recoil'; - -import { IConversationsFilters } from '@chainlit/react-client'; - -export const conversationsFiltersState = atom({ - key: 'ConversationsFilters', - default: {} -}); diff --git a/frontend/src/state/project.ts b/frontend/src/state/project.ts index 27963c2342..a98d8c89d6 100644 --- a/frontend/src/state/project.ts +++ b/frontend/src/state/project.ts @@ -1,6 +1,6 @@ import { atom } from 'recoil'; -import { IMessage, IMessageElement } from '@chainlit/react-client'; +import { IMessageElement, IStep } from '@chainlit/react-client'; export interface ChatProfile { icon: string; @@ -30,7 +30,7 @@ export interface IProjectSettings { }; userEnv: string[]; dataPersistence: boolean; - conversationResumable: boolean; + threadResumable: boolean; chatProfiles: ChatProfile[]; } @@ -44,7 +44,7 @@ export const sideViewState = atom({ default: undefined }); -export const highlightMessage = atom({ +export const highlightMessage = atom({ key: 'HighlightMessage', default: null }); diff --git a/frontend/src/state/threads.ts b/frontend/src/state/threads.ts new file mode 100644 index 0000000000..72ad24f840 --- /dev/null +++ b/frontend/src/state/threads.ts @@ -0,0 +1,8 @@ +import { atom } from 'recoil'; + +import { IThreadFilters } from '@chainlit/react-client'; + +export const threadsFiltersState = atom({ + key: 'ThreadsFilters', + default: {} +}); diff --git a/frontend/src/state/chatHistory.ts b/frontend/src/state/userInputHistory.ts similarity index 66% rename from frontend/src/state/chatHistory.ts rename to frontend/src/state/userInputHistory.ts index dfebbfb4e1..fe53083afe 100644 --- a/frontend/src/state/chatHistory.ts +++ b/frontend/src/state/userInputHistory.ts @@ -1,8 +1,8 @@ import { atom } from 'recoil'; -import { MessageHistory } from '@chainlit/react-client'; +import { UserInput } from '@chainlit/react-client'; -const KEY = 'chat_history'; +const KEY = 'input_history'; const localStorageEffect = (key: string) => @@ -12,7 +12,7 @@ const localStorageEffect = setSelf(JSON.parse(savedValue)); } - onSet((newValue: MessageHistory, _: any, isReset: boolean) => { + onSet((newValue: UserInput, _: any, isReset: boolean) => { if (isReset) { localStorage.removeItem(key); } else { @@ -21,14 +21,14 @@ const localStorageEffect = }); }; -export const chatHistoryState = atom<{ +export const inputHistoryState = atom<{ open: boolean; - messages: MessageHistory[]; + inputs: UserInput[]; }>({ - key: 'ChatHistory', + key: 'UserInputHistory', default: { open: false, - messages: [] + inputs: [] }, effects: [localStorageEffect(KEY)] }); diff --git a/libs/react-client/README.md b/libs/react-client/README.md index af8668b9bf..5c46baa39a 100644 --- a/libs/react-client/README.md +++ b/libs/react-client/README.md @@ -41,7 +41,11 @@ This hook is responsible for managing the chat session's connection to the WebSo #### Example ```jsx -import { useChatSession } from '@chainlit/react-client'; +import { ChainlitAPI, useChatSession } from '@chainlit/react-client'; + +const CHAINLIT_SERVER_URL = 'http://localhost:8000'; + +const apiClient = new ChainlitAPI(CHAINLIT_SERVER_URL); const ChatComponent = () => { const { connect, disconnect, chatProfile, setChatProfile } = useChatSession(); @@ -49,7 +53,7 @@ const ChatComponent = () => { // Connect to the WebSocket server useEffect(() => { connect({ - wsEndpoint: 'YOUR_WEBSOCKET_ENDPOINT', // Your Chainlit server url + client: apiClient, userEnv: { /* user environment variables */ }, @@ -140,7 +144,7 @@ This hook provides methods to interact with the chat, such as sending messages, - `replyMessage`: Replies to a message. - `sendMessage`: Sends a message. - `stopTask`: Stops the current task. -- `setIdToResume`: Sets the ID to resume a conversation. +- `setIdToResume`: Sets the ID to resume a thread. - `updateChatSettings`: Updates the chat settings. #### Example diff --git a/libs/react-client/src/api/hooks/auth.ts b/libs/react-client/src/api/hooks/auth.ts index 23e97a649f..a9613d0849 100644 --- a/libs/react-client/src/api/hooks/auth.ts +++ b/libs/react-client/src/api/hooks/auth.ts @@ -1,12 +1,8 @@ import jwt_decode from 'jwt-decode'; import { useEffect } from 'react'; import { useRecoilState, useSetRecoilState } from 'recoil'; -import { - accessTokenState, - conversationsHistoryState, - userState -} from 'src/state'; -import { IAppUser } from 'src/types'; +import { accessTokenState, threadHistoryState, userState } from 'src/state'; +import { IUser } from 'src/types'; import { getToken, removeToken, setToken } from 'src/utils/token'; import { ChainlitAPI } from '..'; @@ -20,7 +16,7 @@ export const useAuth = (apiClient: ChainlitAPI) => { oauthProviders: string[]; }>(apiClient, '/auth/config'); const [accessToken, setAccessToken] = useRecoilState(accessTokenState); - const setConversationsHistory = useSetRecoilState(conversationsHistoryState); + const setThreadHistory = useSetRecoilState(threadHistoryState); const [user, setUser] = useRecoilState(userState); const isReady = !!(!isLoading && data); @@ -29,7 +25,7 @@ export const useAuth = (apiClient: ChainlitAPI) => { setUser(null); removeToken(); setAccessToken(''); - setConversationsHistory(undefined); + setThreadHistory(undefined); }; const saveAndSetToken = (token: string | null | undefined) => { @@ -38,10 +34,10 @@ export const useAuth = (apiClient: ChainlitAPI) => { return; } try { - const { exp, ...AppUser } = jwt_decode(token) as any; + const { exp, ...User } = jwt_decode(token) as any; setToken(token); setAccessToken(`Bearer ${token}`); - setUser(AppUser as IAppUser); + setUser(User as IUser); } catch (e) { console.error( 'Invalid token, clearing token from local storage', @@ -66,7 +62,6 @@ export const useAuth = (apiClient: ChainlitAPI) => { return { data, user: null, - role: 'ANONYMOUS', isReady, isAuthenticated: true, accessToken: '', @@ -78,7 +73,6 @@ export const useAuth = (apiClient: ChainlitAPI) => { return { data, user: user, - role: user?.role, isAuthenticated, isReady, accessToken: accessToken, diff --git a/libs/react-client/src/api/index.tsx b/libs/react-client/src/api/index.tsx index da8b8a19f3..6d5225314f 100644 --- a/libs/react-client/src/api/index.tsx +++ b/libs/react-client/src/api/index.tsx @@ -1,11 +1,12 @@ -import { IConversation, IPrompt } from 'src/types'; +import { IGeneration, IThread } from 'src/types'; import { removeToken } from 'src/utils/token'; +import { IFeedback } from 'src/types/feedback'; + export * from './hooks/auth'; export * from './hooks/api'; -export interface IConversationsFilters { - authorEmail?: string; +export interface IThreadFilters { search?: string; feedback?: number; } @@ -147,16 +148,22 @@ export class ChainlitAPI extends APIBase { return res.json(); } - async getCompletion( - prompt: IPrompt, + async getGeneration( + generation: IGeneration, userEnv = {}, controller: AbortController, accessToken?: string, tokenCb?: (done: boolean, token: string) => void ) { + const payload = { userEnv }; + if (generation.type === 'CHAT') { + payload['chatGeneration'] = generation; + } else { + payload['completionGeneration'] = generation; + } const response = await this.post( - `/completion`, - { prompt, userEnv }, + `/generation`, + payload, accessToken, controller.signal ); @@ -192,29 +199,24 @@ export class ChainlitAPI extends APIBase { return stream; } - async setHumanFeedback( - messageId: string, - feedback: number, - feedbackComment?: string, + async setFeedback( + feedback: IFeedback, accessToken?: string - ) { - await this.put( - `/message/feedback`, - { messageId, feedback, feedbackComment }, - accessToken - ); + ): Promise<{ success: boolean; feedbackId: string }> { + const res = await this.put(`/feedback`, { feedback }, accessToken); + return res.json(); } - async getConversations( + async listThreads( pagination: IPagination, - filter: IConversationsFilters, + filter: IThreadFilters, accessToken?: string ): Promise<{ pageInfo: IPageInfo; - data: IConversation[]; + data: IThread[]; }> { const res = await this.post( - `/project/conversations`, + `/project/threads`, { pagination, filter }, accessToken ); @@ -222,16 +224,71 @@ export class ChainlitAPI extends APIBase { return res.json(); } - async deleteConversation(conversationId: string, accessToken?: string) { - const res = await this.delete( - `/project/conversation`, - { conversationId }, - accessToken - ); + async deleteThread(threadId: string, accessToken?: string) { + const res = await this.delete(`/project/thread`, { threadId }, accessToken); return res.json(); } + uploadFile( + file: File, + onProgress: (progress: number) => void, + sessionId: string, + token?: string + ) { + const xhr = new XMLHttpRequest(); + + const promise = new Promise<{ id: string }>((resolve, reject) => { + const formData = new FormData(); + formData.append('file', file); + + xhr.open( + 'POST', + this.buildEndpoint(`/project/file?session_id=${sessionId}`), + true + ); + + if (token) { + xhr.setRequestHeader('Authorization', this.checkToken(token)); + } + + // Track the progress of the upload + xhr.upload.onprogress = function (event) { + if (event.lengthComputable) { + const percentage = (event.loaded / event.total) * 100; + onProgress(percentage); + } + }; + + xhr.onload = function () { + if (xhr.status === 200) { + const response = JSON.parse(xhr.responseText); + resolve(response); + } else { + reject('Upload failed'); + } + }; + + xhr.onerror = function () { + reject('Upload error'); + }; + + xhr.send(formData); + }); + + return { xhr, promise }; + } + + getElementUrl(id: string, sessionId: string, accessToken?: string) { + let tokenParam = ''; + if (accessToken) { + tokenParam = `?token=${accessToken}`; + } + return this.buildEndpoint( + `/project/file/${id}?session_id=${sessionId}${tokenParam}` + ); + } + getLogoEndpoint(theme: string) { return this.buildEndpoint(`/logo?theme=${theme}`); } diff --git a/libs/react-client/src/state.ts b/libs/react-client/src/state.ts index 7e32da575f..a297f4d3be 100644 --- a/libs/react-client/src/state.ts +++ b/libs/react-client/src/state.ts @@ -4,15 +4,14 @@ import { Socket } from 'socket.io-client'; import { v4 as uuidv4 } from 'uuid'; import { - ConversationsHistory, IAction, - IAppUser, IAsk, IAvatarElement, - IMessage, IMessageElement, + IStep, ITasklistElement, - Role + IUser, + ThreadHistory } from './types'; import { groupByDate } from './utils/group'; @@ -21,8 +20,8 @@ export interface ISession { error?: boolean; } -export const conversationIdToResumeState = atom({ - key: 'ConversationIdToResume', +export const threadIdToResumeState = atom({ + key: 'ThreadIdToResume', default: undefined }); @@ -54,7 +53,7 @@ export const actionState = atom({ default: [] }); -export const messagesState = atom({ +export const messagesState = atom({ key: 'Messages', dangerouslyAllowMutability: true, default: [] @@ -113,7 +112,7 @@ export const tasklistState = atom({ default: [] }); -export const firstUserMessageState = atom({ +export const firstUserMessageState = atom({ key: 'FirstUserMessage', default: undefined }); @@ -123,48 +122,40 @@ export const accessTokenState = atom({ default: undefined }); -export const roleState = atom({ - key: 'Role', - default: undefined -}); - -export const userState = atom({ +export const userState = atom({ key: 'User', default: null }); -export const conversationsHistoryState = atom( - { - key: 'ConversationsHistory', - default: { - conversations: undefined, - currentConversationId: undefined, - groupedConversations: undefined, - pageInfo: undefined - }, - effects: [ - ({ setSelf, onSet }: { setSelf: any; onSet: any }) => { - onSet( - ( - newValue: ConversationsHistory | undefined, - oldValue: ConversationsHistory | undefined - ) => { - let groupedConversations = newValue?.groupedConversations; - - if ( - newValue?.conversations && - !isEqual(newValue.conversations, oldValue?.groupedConversations) - ) { - groupedConversations = groupByDate(newValue.conversations); - } - - setSelf({ - ...newValue, - groupedConversations - }); +export const threadHistoryState = atom({ + key: 'ThreadHistory', + default: { + threads: undefined, + currentThreadId: undefined, + timeGroupedThreads: undefined, + pageInfo: undefined + }, + effects: [ + ({ setSelf, onSet }: { setSelf: any; onSet: any }) => { + onSet( + ( + newValue: ThreadHistory | undefined, + oldValue: ThreadHistory | undefined + ) => { + let timeGroupedThreads = newValue?.timeGroupedThreads; + if ( + newValue?.threads && + !isEqual(newValue.threads, oldValue?.timeGroupedThreads) + ) { + timeGroupedThreads = groupByDate(newValue.threads); } - ); - } - ] - } -); + + setSelf({ + ...newValue, + timeGroupedThreads + }); + } + ); + } + ] +}); diff --git a/libs/react-client/src/types/chatHistory.ts b/libs/react-client/src/types/chatHistory.ts deleted file mode 100644 index 8b2e19f3d2..0000000000 --- a/libs/react-client/src/types/chatHistory.ts +++ /dev/null @@ -1,15 +0,0 @@ -import { IConversation } from 'src/types'; - -import { IPageInfo } from '..'; - -export type MessageHistory = { - content: string; - createdAt: number; -}; - -export type ConversationsHistory = { - conversations?: IConversation[]; - currentConversationId?: string; - groupedConversations?: { [key: string]: IConversation[] }; - pageInfo?: IPageInfo; -}; diff --git a/libs/react-client/src/types/conversation.ts b/libs/react-client/src/types/conversation.ts deleted file mode 100644 index dd2318db27..0000000000 --- a/libs/react-client/src/types/conversation.ts +++ /dev/null @@ -1,12 +0,0 @@ -import { IElement } from './element'; -import { IMessage } from './message'; -import { IAppUser } from './user'; - -export interface IConversation { - id: string; - createdAt: number | string; - appUser?: IAppUser; - metadata?: Record; - messages: IMessage[]; - elements: IElement[]; -} diff --git a/libs/react-client/src/types/element.ts b/libs/react-client/src/types/element.ts index 397258e0ab..0bbe6dd24a 100644 --- a/libs/react-client/src/types/element.ts +++ b/libs/react-client/src/types/element.ts @@ -24,9 +24,11 @@ export type IElementSize = 'small' | 'medium' | 'large'; interface TElement { id: string; type: T; - conversationId?: string; - forIds?: string[]; + threadId?: string; + forId: string; + mime?: string; url?: string; + chainlitKey?: string; } interface TMessageElement extends TElement { @@ -35,43 +37,29 @@ interface TMessageElement extends TElement { } export interface IImageElement extends TMessageElement<'image'> { - content?: ArrayBuffer; size?: IElementSize; } export interface IAvatarElement extends TElement<'avatar'> { name: string; - content?: ArrayBuffer; } export interface ITextElement extends TMessageElement<'text'> { - content?: string; language?: string; } -export interface IPdfElement extends TMessageElement<'pdf'> { - content?: string; -} +export interface IPdfElement extends TMessageElement<'pdf'> {} -export interface IAudioElement extends TMessageElement<'audio'> { - content?: ArrayBuffer; -} +export interface IAudioElement extends TMessageElement<'audio'> {} export interface IVideoElement extends TMessageElement<'video'> { - content?: ArrayBuffer; size?: IElementSize; } export interface IFileElement extends TMessageElement<'file'> { type: 'file'; - mime?: string; - content?: ArrayBuffer; } -export interface IPlotlyElement extends TMessageElement<'plotly'> { - content?: string; -} +export interface IPlotlyElement extends TMessageElement<'plotly'> {} -export interface ITasklistElement extends TElement<'tasklist'> { - content?: string; -} +export interface ITasklistElement extends TElement<'tasklist'> {} diff --git a/libs/react-client/src/types/feedback.ts b/libs/react-client/src/types/feedback.ts new file mode 100644 index 0000000000..7c8b764875 --- /dev/null +++ b/libs/react-client/src/types/feedback.ts @@ -0,0 +1,7 @@ +export interface IFeedback { + id?: string; + forId?: string; + comment?: string; + strategy: 'BINARY'; + value: number; +} diff --git a/libs/react-client/src/types/file.ts b/libs/react-client/src/types/file.ts index e25d18b4e5..af3fde1eba 100644 --- a/libs/react-client/src/types/file.ts +++ b/libs/react-client/src/types/file.ts @@ -1,5 +1,5 @@ import { IAction } from './action'; -import { IMessage } from './message'; +import { IStep } from './step'; export interface FileSpec { accept?: string[] | Record; @@ -11,16 +11,12 @@ export interface ActionSpec { keys?: string[]; } -export interface IFileResponse { - name: string; - path?: string; - size: number; - type: string; - content: ArrayBuffer; +export interface IFileRef { + id: string; } export interface IAsk { - callback: (payload: IMessage | IFileResponse[] | IAction) => void; + callback: (payload: IStep | IFileRef[] | IAction) => void; spec: { type: 'text' | 'file' | 'action'; timeout: number; diff --git a/libs/react-client/src/types/generation.ts b/libs/react-client/src/types/generation.ts new file mode 100644 index 0000000000..f782fcc849 --- /dev/null +++ b/libs/react-client/src/types/generation.ts @@ -0,0 +1,53 @@ +export type GenerationMessageRole = + | 'system' + | 'assistant' + | 'user' + | 'function' + | 'tool'; +export type ILLMSettings = Record; + +export interface IGenerationMessage { + template?: string; + formatted?: string; + templateFormat: string; + role: GenerationMessageRole; + name?: string; +} + +export interface IFunction { + name: string; + description: string; + parameters: { + required: string[]; + properties: Record; + }; +} + +export interface ITool { + type: string; + function: IFunction; +} + +export interface IBaseGeneration { + provider: string; + id?: string; + inputs?: Record; + completion?: string; + settings?: ILLMSettings; + functions?: IFunction[]; + tokenCount?: number; +} + +export interface ICompletionGeneration extends IBaseGeneration { + type: 'COMPLETION'; + template?: string; + formatted?: string; + templateFormat: string; +} + +export interface IChatGeneration extends IBaseGeneration { + type: 'CHAT'; + messages?: IGenerationMessage[]; +} + +export type IGeneration = ICompletionGeneration | IChatGeneration; diff --git a/libs/react-client/src/types/history.ts b/libs/react-client/src/types/history.ts new file mode 100644 index 0000000000..c243a879d7 --- /dev/null +++ b/libs/react-client/src/types/history.ts @@ -0,0 +1,15 @@ +import { IThread } from 'src/types'; + +import { IPageInfo } from '..'; + +export type UserInput = { + content: string; + createdAt: number; +}; + +export type ThreadHistory = { + threads?: IThread[]; + currentThreadId?: string; + timeGroupedThreads?: { [key: string]: IThread[] }; + pageInfo?: IPageInfo; +}; diff --git a/libs/react-client/src/types/index.ts b/libs/react-client/src/types/index.ts index 46a4819763..378c223de0 100644 --- a/libs/react-client/src/types/index.ts +++ b/libs/react-client/src/types/index.ts @@ -1,7 +1,9 @@ export * from './action'; export * from './element'; export * from './file'; -export * from './message'; +export * from './feedback'; +export * from './step'; export * from './user'; -export * from './conversation'; -export * from './chatHistory'; +export * from './thread'; +export * from './generation'; +export * from './history'; diff --git a/libs/react-client/src/types/message.ts b/libs/react-client/src/types/message.ts deleted file mode 100644 index 7ccac19201..0000000000 --- a/libs/react-client/src/types/message.ts +++ /dev/null @@ -1,59 +0,0 @@ -import { IFileElement } from './element'; - -interface IBaseTemplate { - template?: string; - formatted?: string; - template_format: string; -} - -export type PromptMessageRole = 'system' | 'assistant' | 'user' | 'function'; -export type ILLMSettings = Record; - -export interface IPromptMessage extends IBaseTemplate { - role: PromptMessageRole; - name?: string; -} - -export interface IFunction { - name: string; - description: string; - parameters: { - required: string[]; - properties: Record; - }; -} - -export interface ITool { - type: string; - function: IFunction; -} - -export interface IPrompt extends IBaseTemplate { - provider: string; - id?: string; - inputs?: Record; - completion?: string; - settings?: ILLMSettings; - functions?: IFunction[]; - messages?: IPromptMessage[]; -} - -export interface IMessage { - author: string; - authorIsUser?: boolean; - content?: string; - createdAt: number | string; - disableHumanFeedback?: boolean; - elements?: IFileElement[]; - humanFeedback?: number; - humanFeedbackComment?: string; - id: string; - indent?: number; - isError?: boolean; - language?: string; - parentId?: string; - prompt?: IPrompt; - streaming?: boolean; - waitForAnswer?: boolean; - subMessages?: IMessage[]; -} diff --git a/libs/react-client/src/types/step.ts b/libs/react-client/src/types/step.ts new file mode 100644 index 0000000000..87cde8723c --- /dev/null +++ b/libs/react-client/src/types/step.ts @@ -0,0 +1,38 @@ +import { IFeedback } from './feedback'; +import { IGeneration } from './generation'; + +type StepType = + | 'assistant_message' + | 'user_message' + | 'system_message' + | 'run' + | 'tool' + | 'llm' + | 'embedding' + | 'retrieval' + | 'rerank' + | 'undefined'; + +export interface IStep { + id: string; + name: string; + type: StepType; + threadId?: string; + parentId?: string; + isError?: boolean; + showInput?: boolean | string; + waitForAnswer?: boolean; + input?: string; + output: string; + createdAt: number | string; + start?: number | string; + end?: number | string; + disableFeedback?: boolean; + feedback?: IFeedback; + language?: string; + streaming?: boolean; + generation?: IGeneration; + steps?: IStep[]; + //legacy + indent?: number; +} diff --git a/libs/react-client/src/types/thread.ts b/libs/react-client/src/types/thread.ts new file mode 100644 index 0000000000..c829fb1f34 --- /dev/null +++ b/libs/react-client/src/types/thread.ts @@ -0,0 +1,12 @@ +import { IElement } from './element'; +import { IStep } from './step'; +import { IUser } from './user'; + +export interface IThread { + id: string; + createdAt: number | string; + user?: IUser; + metadata?: Record; + steps: IStep[]; + elements?: IElement[]; +} diff --git a/libs/react-client/src/types/user.ts b/libs/react-client/src/types/user.ts index cedfd1d838..ad349c9e14 100644 --- a/libs/react-client/src/types/user.ts +++ b/libs/react-client/src/types/user.ts @@ -1,17 +1,18 @@ -export type Role = 'USER' | 'ADMIN' | 'OWNER' | 'ANONYMOUS'; - -export type AppUserProvider = +export type AuthProvider = | 'credentials' | 'header' | 'github' | 'google' | 'azure-ad'; -export interface IAppUser { - id: string; - username: string; - role: Role; +export interface IUserMetadata extends Record { tags?: string[]; image?: string; - provider?: AppUserProvider; + provider?: AuthProvider; +} + +export interface IUser { + id: string; + identifier: string; + metadata: IUserMetadata; } diff --git a/libs/react-client/src/useChatData.ts b/libs/react-client/src/useChatData.ts index addc5a689e..7a5b496395 100644 --- a/libs/react-client/src/useChatData.ts +++ b/libs/react-client/src/useChatData.ts @@ -1,5 +1,4 @@ import { useRecoilValue } from 'recoil'; -import { IMessage } from 'src/types'; import { actionState, @@ -14,10 +13,6 @@ import { tasklistState } from './state'; -export interface IMessageUpdate extends IMessage { - newId?: string; -} - export interface IToken { id: number | string; token: string; diff --git a/libs/react-client/src/useChatInteract.ts b/libs/react-client/src/useChatInteract.ts index eebd537dfe..483ffa1697 100644 --- a/libs/react-client/src/useChatInteract.ts +++ b/libs/react-client/src/useChatInteract.ts @@ -1,12 +1,12 @@ import { useCallback } from 'react'; import { useRecoilValue, useResetRecoilState, useSetRecoilState } from 'recoil'; import { + accessTokenState, actionState, askUserState, avatarState, chatSettingsInputsState, chatSettingsValueState, - conversationIdToResumeState, elementState, firstUserMessageState, loadingState, @@ -14,14 +14,19 @@ import { sessionIdState, sessionState, tasklistState, + threadIdToResumeState, tokenCountState } from 'src/state'; -import { IAction, IFileElement, IMessage } from 'src/types'; +import { IAction, IFileRef, IStep } from 'src/types'; import { addMessage } from 'src/utils/message'; +import { ChainlitAPI } from './api'; + const useChatInteract = () => { + const accessToken = useRecoilValue(accessTokenState); const session = useRecoilValue(sessionState); const askUser = useRecoilValue(askUserState); + const sessionId = useRecoilValue(sessionIdState); const resetChatSettings = useResetRecoilState(chatSettingsInputsState); const resetSessionId = useResetRecoilState(sessionIdState); @@ -35,7 +40,7 @@ const useChatInteract = () => { const setTasklists = useSetRecoilState(tasklistState); const setActions = useSetRecoilState(actionState); const setTokenCount = useSetRecoilState(tokenCountState); - const setIdToResume = useSetRecoilState(conversationIdToResumeState); + const setIdToResume = useSetRecoilState(threadIdToResumeState); const clear = useCallback(() => { session?.socket.emit('clear_session'); @@ -54,44 +59,75 @@ const useChatInteract = () => { }, [session]); const sendMessage = useCallback( - (message: IMessage, files?: IFileElement[]) => { + (message: IStep, fileReferences?: IFileRef[]) => { setMessages((oldMessages) => addMessage(oldMessages, message)); - session?.socket.emit('ui_message', { message, files }); + session?.socket.emit('ui_message', { message, fileReferences }); }, - [session] + [session?.socket] ); const replyMessage = useCallback( - (message: IMessage) => { + (message: IStep) => { if (askUser) { setMessages((oldMessages) => addMessage(oldMessages, message)); askUser.callback(message); } }, - [askUser, session] + [askUser] ); const updateChatSettings = useCallback( (values: object) => { session?.socket.emit('chat_settings_change', values); }, - [session] + [session?.socket] ); const stopTask = useCallback(() => { setLoading(false); session?.socket.emit('stop'); - }, [session]); + }, [session?.socket]); const callAction = useCallback( (action: IAction) => { - session?.socket.emit('action_call', action); + const socket = session?.socket; + if (!socket) return; + + const promise = new Promise<{ + id: string; + status: boolean; + response?: string; + }>((resolve, reject) => { + socket.once('action_response', (response) => { + if (response.status) { + resolve(response); + } else { + reject(response); + } + }); + }); + + socket.emit('action_call', action); + + return promise; + }, + [session?.socket] + ); + + const uploadFile = useCallback( + ( + client: ChainlitAPI, + file: File, + onProgress: (progress: number) => void + ) => { + return client.uploadFile(file, onProgress, sessionId, accessToken); }, - [session] + [sessionId, accessToken] ); return { + uploadFile, callAction, clear, replyMessage, diff --git a/libs/react-client/src/useChatSession.ts b/libs/react-client/src/useChatSession.ts index 489c272a9c..65f1cf47c9 100644 --- a/libs/react-client/src/useChatSession.ts +++ b/libs/react-client/src/useChatSession.ts @@ -14,7 +14,6 @@ import { chatProfileState, chatSettingsInputsState, chatSettingsValueState, - conversationIdToResumeState, elementState, firstUserMessageState, loadingState, @@ -22,16 +21,17 @@ import { sessionIdState, sessionState, tasklistState, + threadIdToResumeState, tokenCountState } from 'src/state'; import { IAction, IAvatarElement, - IConversation, IElement, - IMessage, IMessageElement, - ITasklistElement + IStep, + ITasklistElement, + IThread } from 'src/types'; import { addMessage, @@ -40,7 +40,8 @@ import { updateMessageContentById } from 'src/utils/message'; -import type { IMessageUpdate, IToken } from './useChatData'; +import { ChainlitAPI } from './api'; +import type { IToken } from './useChatData'; const useChatSession = () => { const sessionId = useRecoilValue(sessionIdState); @@ -59,24 +60,24 @@ const useChatSession = () => { const setChatSettingsInputs = useSetRecoilState(chatSettingsInputsState); const setTokenCount = useSetRecoilState(tokenCountState); const [chatProfile, setChatProfile] = useRecoilState(chatProfileState); - const idToResume = useRecoilValue(conversationIdToResumeState); + const idToResume = useRecoilValue(threadIdToResumeState); const _connect = useCallback( ({ - wsEndpoint, + client, userEnv, accessToken }: { - wsEndpoint: string; + client: ChainlitAPI; userEnv: Record; accessToken?: string; }) => { - const socket = io(wsEndpoint, { + const socket = io(client.httpEndpoint, { path: '/ws/socket.io', extraHeaders: { Authorization: accessToken || '', 'X-Chainlit-Session-Id': sessionId, - 'X-Chainlit-Conversation-Id': idToResume || '', + 'X-Chainlit-Thread-Id': idToResume || '', 'user-env': JSON.stringify(userEnv), 'X-Chainlit-Chat-Profile': chatProfile || '' } @@ -111,53 +112,50 @@ const useChatSession = () => { window.location.reload(); }); - socket.on('resume_conversation', (conversation: IConversation) => { - let messages: IMessage[] = []; - for (const message of conversation.messages) { - messages = addMessage(messages, message); + socket.on('resume_thread', (thread: IThread) => { + let messages: IStep[] = []; + for (const step of thread.steps) { + messages = addMessage(messages, step); } - if (conversation.metadata?.chat_profile) { - setChatProfile(conversation.metadata?.chat_profile); + if (thread.metadata?.chat_profile) { + setChatProfile(thread.metadata?.chat_profile); } setMessages(messages); + const elements = thread.elements || []; setAvatars( - (conversation.elements as IAvatarElement[]).filter( - (e) => e.type === 'avatar' - ) + (elements as IAvatarElement[]).filter((e) => e.type === 'avatar') ); setTasklists( - (conversation.elements as ITasklistElement[]).filter( - (e) => e.type === 'tasklist' - ) + (elements as ITasklistElement[]).filter((e) => e.type === 'tasklist') ); setElements( - (conversation.elements as IMessageElement[]).filter( + (elements as IMessageElement[]).filter( (e) => ['avatar', 'tasklist'].indexOf(e.type) === -1 ) ); }); - socket.on('new_message', (message: IMessage) => { + socket.on('new_message', (message: IStep) => { setMessages((oldMessages) => addMessage(oldMessages, message)); }); - socket.on('init_conversation', (message: IMessage) => { + socket.on('init_thread', (message: IStep) => { setFirstUserMessage(message); }); - socket.on('update_message', (message: IMessageUpdate) => { + socket.on('update_message', (message: IStep) => { setMessages((oldMessages) => updateMessageById(oldMessages, message.id, message) ); }); - socket.on('delete_message', (message: IMessage) => { + socket.on('delete_message', (message: IStep) => { setMessages((oldMessages) => deleteMessageById(oldMessages, message.id) ); }); - socket.on('stream_start', (message: IMessage) => { + socket.on('stream_start', (message: IStep) => { setMessages((oldMessages) => addMessage(oldMessages, message)); }); @@ -189,31 +187,43 @@ const useChatSession = () => { }); socket.on('element', (element: IElement) => { + if (!element.url && element.chainlitKey) { + element.url = client.getElementUrl( + element.chainlitKey, + sessionId, + accessToken + ); + } + if (element.type === 'avatar') { - setAvatars((old) => [...old, element]); + setAvatars((old) => { + const index = old.findIndex((e) => e.id === element.id); + if (index === -1) { + return [...old, element]; + } else { + return [...old.slice(0, index), element, ...old.slice(index + 1)]; + } + }); } else if (element.type === 'tasklist') { - setTasklists((old) => [...old, element]); + setTasklists((old) => { + const index = old.findIndex((e) => e.id === element.id); + if (index === -1) { + return [...old, element]; + } else { + return [...old.slice(0, index), element, ...old.slice(index + 1)]; + } + }); } else { - setElements((old) => [...old, element]); - } - }); - - socket.on( - 'update_element', - (update: { id: string; forIds: string[] }) => { setElements((old) => { - const index = old.findIndex((e) => e.id === update.id); - if (index === -1) return old; - const element = old[index]; - const newElement = { ...element, forIds: update.forIds }; - return [ - ...old.slice(0, index), - newElement, - ...old.slice(index + 1) - ]; + const index = old.findIndex((e) => e.id === element.id); + if (index === -1) { + return [...old, element]; + } else { + return [...old.slice(0, index), element, ...old.slice(index + 1)]; + } }); } - ); + }); socket.on('remove_element', (remove: { id: string }) => { setElements((old) => { diff --git a/libs/react-client/src/utils/group.ts b/libs/react-client/src/utils/group.ts index 7e216e8857..ae1f32f8a4 100644 --- a/libs/react-client/src/utils/group.ts +++ b/libs/react-client/src/utils/group.ts @@ -1,7 +1,7 @@ -import { IConversation } from 'src/types'; +import { IThread } from 'src/types'; -export const groupByDate = (data: IConversation[]) => { - const groupedData: { [key: string]: IConversation[] } = {}; +export const groupByDate = (data: IThread[]) => { + const groupedData: { [key: string]: IThread[] } = {}; const today = new Date(); const yesterday = new Date(); diff --git a/libs/react-client/src/utils/message.ts b/libs/react-client/src/utils/message.ts index 59a2512da6..ef6b3d79bf 100644 --- a/libs/react-client/src/utils/message.ts +++ b/libs/react-client/src/utils/message.ts @@ -1,8 +1,9 @@ import isEqual from 'lodash/isEqual'; -import { IMessage } from 'src/types'; -const nestMessages = (messages: IMessage[]): IMessage[] => { - let nestedMessages: IMessage[] = []; +import { IStep } from '..'; + +const nestMessages = (messages: IStep[]): IStep[] => { + let nestedMessages: IStep[] = []; for (const message of messages) { nestedMessages = addMessage(nestedMessages, message); @@ -11,7 +12,7 @@ const nestMessages = (messages: IMessage[]): IMessage[] => { return nestedMessages; }; -const isLastMessage = (messages: IMessage[], index: number) => { +const isLastMessage = (messages: IStep[], index: number) => { if (messages.length - 1 === index) { return true; } @@ -29,12 +30,12 @@ const isLastMessage = (messages: IMessage[], index: number) => { // Nested messages utils -const addMessage = (messages: IMessage[], message: IMessage): IMessage[] => { +const addMessage = (messages: IStep[], message: IStep): IStep[] => { if (hasMessageById(messages, message.id)) { return updateMessageById(messages, message.id, message); - } else if (message.parentId) { + } else if ('parentId' in message && message.parentId) { return addMessageToParent(messages, message.parentId, message); - } else if (message.indent && message.indent > 0) { + } else if ('indent' in message && message.indent && message.indent > 0) { return addIndentMessage(messages, message.indent, message); } else { return [...messages, message]; @@ -42,11 +43,11 @@ const addMessage = (messages: IMessage[], message: IMessage): IMessage[] => { }; const addIndentMessage = ( - messages: IMessage[], + messages: IStep[], indent: number, - newMessage: IMessage, + newMessage: IStep, currentIndentation: number = 0 -): IMessage[] => { +): IStep[] => { const nextMessages = [...messages]; if (nextMessages.length === 0) { @@ -54,16 +55,16 @@ const addIndentMessage = ( } else { const index = nextMessages.length - 1; const msg = nextMessages[index]; - msg.subMessages = msg.subMessages || []; + msg.steps = msg.steps || []; if (currentIndentation + 1 === indent) { - msg.subMessages = [...msg.subMessages, newMessage]; + msg.steps = [...msg.steps, newMessage]; nextMessages[index] = { ...msg }; return nextMessages; } else { - msg.subMessages = addIndentMessage( - msg.subMessages, + msg.steps = addIndentMessage( + msg.steps, indent, newMessage, currentIndentation + 1 @@ -76,26 +77,20 @@ const addIndentMessage = ( }; const addMessageToParent = ( - messages: IMessage[], + messages: IStep[], parentId: string, - newMessage: IMessage -): IMessage[] => { + newMessage: IStep +): IStep[] => { const nextMessages = [...messages]; for (let index = 0; index < nextMessages.length; index++) { const msg = nextMessages[index]; if (isEqual(msg.id, parentId)) { - msg.subMessages = msg.subMessages - ? [...msg.subMessages, newMessage] - : [newMessage]; + msg.steps = msg.steps ? [...msg.steps, newMessage] : [newMessage]; nextMessages[index] = { ...msg }; - } else if (hasMessageById(nextMessages, parentId) && msg.subMessages) { - msg.subMessages = addMessageToParent( - msg.subMessages, - parentId, - newMessage - ); + } else if (hasMessageById(nextMessages, parentId) && msg.steps) { + msg.steps = addMessageToParent(msg.steps, parentId, newMessage); nextMessages[index] = { ...msg }; } } @@ -103,12 +98,12 @@ const addMessageToParent = ( return nextMessages; }; -const hasMessageById = (messages: IMessage[], messageId: string) => { +const hasMessageById = (messages: IStep[], messageId: string) => { for (const message of messages) { if (isEqual(message.id, messageId)) { return true; - } else if (message.subMessages && message.subMessages.length > 0) { - if (hasMessageById(message.subMessages, messageId)) { + } else if (message.steps && message.steps.length > 0) { + if (hasMessageById(message.steps, messageId)) { return true; } } @@ -117,23 +112,19 @@ const hasMessageById = (messages: IMessage[], messageId: string) => { }; const updateMessageById = ( - messages: IMessage[], + messages: IStep[], messageId: string, - updatedMessage: IMessage -): IMessage[] => { + updatedMessage: IStep +): IStep[] => { const nextMessages = [...messages]; for (let index = 0; index < nextMessages.length; index++) { const msg = nextMessages[index]; if (isEqual(msg.id, messageId)) { - nextMessages[index] = { subMessages: msg.subMessages, ...updatedMessage }; - } else if (hasMessageById(nextMessages, messageId) && msg.subMessages) { - msg.subMessages = updateMessageById( - msg.subMessages, - messageId, - updatedMessage - ); + nextMessages[index] = { steps: msg.steps, ...updatedMessage }; + } else if (hasMessageById(nextMessages, messageId) && msg.steps) { + msg.steps = updateMessageById(msg.steps, messageId, updatedMessage); nextMessages[index] = { ...msg }; } } @@ -141,7 +132,7 @@ const updateMessageById = ( return nextMessages; }; -const deleteMessageById = (messages: IMessage[], messageId: string) => { +const deleteMessageById = (messages: IStep[], messageId: string) => { let nextMessages = [...messages]; for (let index = 0; index < nextMessages.length; index++) { @@ -152,8 +143,8 @@ const deleteMessageById = (messages: IMessage[], messageId: string) => { ...nextMessages.slice(0, index), ...nextMessages.slice(index + 1) ]; - } else if (hasMessageById(nextMessages, messageId) && msg.subMessages) { - msg.subMessages = deleteMessageById(msg.subMessages, messageId); + } else if (hasMessageById(nextMessages, messageId) && msg.steps) { + msg.steps = deleteMessageById(msg.steps, messageId); nextMessages[index] = { ...msg }; } } @@ -162,27 +153,37 @@ const deleteMessageById = (messages: IMessage[], messageId: string) => { }; const updateMessageContentById = ( - messages: IMessage[], + messages: IStep[], messageId: number | string, updatedContent: string, isSequence: boolean -): IMessage[] => { +): IStep[] => { const nextMessages = [...messages]; for (let index = 0; index < nextMessages.length; index++) { const msg = nextMessages[index]; if (isEqual(msg.id, messageId)) { - if (isSequence) { - msg.content = updatedContent; + if ('content' in msg && msg.content !== undefined) { + if (isSequence) { + msg.content = updatedContent; + } else { + msg.content += updatedContent; + } } else { - msg.content += updatedContent; + if ('output' in msg && msg.output !== undefined) { + if (isSequence) { + msg.output = updatedContent; + } else { + msg.output += updatedContent; + } + } } nextMessages[index] = { ...msg }; - } else if (msg.subMessages) { - msg.subMessages = updateMessageContentById( - msg.subMessages, + } else if (msg.steps) { + msg.steps = updateMessageContentById( + msg.steps, messageId, updatedContent, isSequence diff --git a/libs/react-components/hooks/useUpload.tsx b/libs/react-components/hooks/useUpload.tsx index ab1cc09744..910fa96311 100644 --- a/libs/react-components/hooks/useUpload.tsx +++ b/libs/react-components/hooks/useUpload.tsx @@ -1,4 +1,4 @@ -import { useCallback, useState } from 'react'; +import { useCallback } from 'react'; import { DropzoneOptions, FileRejection, @@ -6,18 +6,16 @@ import { useDropzone } from 'react-dropzone'; -import type { FileSpec, IFileResponse } from 'client-types/'; +import type { FileSpec } from 'client-types/'; interface useUploadProps { onError?: (error: string) => void; - onResolved: (payloads: IFileResponse[]) => void; + onResolved: (payloads: FileWithPath[]) => void; options?: DropzoneOptions; spec: FileSpec; } const useUpload = ({ onError, onResolved, options, spec }: useUploadProps) => { - const [uploading, setUploading] = useState(false); - const onDrop: DropzoneOptions['onDrop'] = useCallback( (acceptedFiles: FileWithPath[], fileRejections: FileRejection[]) => { if (fileRejections.length > 0) { @@ -26,39 +24,7 @@ const useUpload = ({ onError, onResolved, options, spec }: useUploadProps) => { } if (!acceptedFiles.length) return; - setUploading(true); - - const promises = acceptedFiles.map((file) => { - return new Promise((resolve, reject) => { - const reader = new FileReader(); - reader.onload = function (e) { - const rawData = e.target?.result; - const payload: IFileResponse = { - path: file.path, - name: file.name, - size: file.size, - type: file.type, - content: rawData as ArrayBuffer - }; - resolve(payload); - }; - reader.onerror = function () { - if (!reader.error) return; - reject(reader.error.message); - }; - reader.readAsArrayBuffer(file); - }); - }); - - Promise.all(promises) - .then((payloads) => { - onResolved(payloads); - setUploading(false); - }) - .catch((err) => { - onError && onError(err); - setUploading(false); - }); + return onResolved(acceptedFiles); }, [spec] ); @@ -84,7 +50,7 @@ const useUpload = ({ onError, onResolved, options, spec }: useUploadProps) => { ...options }); - return { getInputProps, getRootProps, isDragActive, uploading }; + return { getInputProps, getRootProps, isDragActive }; }; export { useUpload }; diff --git a/libs/react-components/src/Attachment.tsx b/libs/react-components/src/Attachment.tsx new file mode 100644 index 0000000000..b41b2702ef --- /dev/null +++ b/libs/react-components/src/Attachment.tsx @@ -0,0 +1,64 @@ +import { DefaultExtensionType, FileIcon, defaultStyles } from 'react-file-icon'; + +import Box from '@mui/material/Box'; +import Stack from '@mui/material/Stack'; +import Typography from '@mui/material/Typography'; + +interface Props { + name: string; + mime: string; + children?: React.ReactNode; +} + +const Attachment = ({ name, mime, children }: Props) => { + const extension = ( + mime ? mime.split('/').pop() : 'txt' + ) as DefaultExtensionType; + + return ( + + {children} + `1px solid ${theme.palette.primary.main}`, + color: (theme) => + theme.palette.mode === 'light' + ? theme.palette.primary.main + : theme.palette.text.primary, + background: (theme) => + theme.palette.mode === 'light' + ? theme.palette.primary.light + : theme.palette.primary.dark + }} + > + + + + + {name} + + + + ); +}; +export { Attachment }; diff --git a/libs/react-components/src/Attachments.tsx b/libs/react-components/src/Attachments.tsx deleted file mode 100644 index 9e9079249a..0000000000 --- a/libs/react-components/src/Attachments.tsx +++ /dev/null @@ -1,49 +0,0 @@ -import React from 'react'; - -import Stack from '@mui/material/Stack'; - -import type { IFileElement } from 'client-types/'; - -import { FileElement } from './elements'; - -interface AttachmentsProps { - fileElements: IFileElement[]; - setFileElements: React.Dispatch>; -} - -const Attachments = ({ - fileElements, - setFileElements -}: AttachmentsProps): JSX.Element => { - if (fileElements.length === 0) return <>; - - const onRemove = (index: number) => { - setFileElements((prev) => - prev.filter((_, prevIndex) => index !== prevIndex) - ); - }; - - return ( - - {fileElements.map((fileElement, index) => { - return ( - onRemove(index)} - /> - ); - })} - - ); -}; - -export { Attachments }; diff --git a/libs/react-components/src/elements/Audio.tsx b/libs/react-components/src/elements/Audio.tsx index 1d98479784..54e4f0a403 100644 --- a/libs/react-components/src/elements/Audio.tsx +++ b/libs/react-components/src/elements/Audio.tsx @@ -3,12 +3,12 @@ import { grey } from 'theme/palette'; import Box from '@mui/material/Box'; import useTheme from '@mui/material/styles/useTheme'; -import type { IAudioElement } from 'client-types/'; +import { type IAudioElement } from 'client-types/'; const AudioElement = ({ element }: { element: IAudioElement }) => { const theme = useTheme(); - if (!element.url && !element.content) { + if (!element.url) { return null; } @@ -24,15 +24,7 @@ const AudioElement = ({ element }: { element: IAudioElement }) => { > {element.name} - + ); }; diff --git a/libs/react-components/src/elements/Avatar.tsx b/libs/react-components/src/elements/Avatar.tsx index 96b8565972..91cf33206c 100644 --- a/libs/react-components/src/elements/Avatar.tsx +++ b/libs/react-components/src/elements/Avatar.tsx @@ -1,22 +1,25 @@ import Avatar from '@mui/material/Avatar'; import Tooltip from '@mui/material/Tooltip'; -import type { IAvatarElement } from 'client-types/'; +import { type IAvatarElement } from 'client-types/'; interface Props { element: IAvatarElement; author: string; } -const AvatarElement = ({ element, author }: Props) => ( - - - - - -); +const AvatarElement = ({ element, author }: Props) => { + if (!element.url) { + return null; + } + + return ( + + + + + + ); +}; export { AvatarElement }; diff --git a/libs/react-components/src/elements/File.tsx b/libs/react-components/src/elements/File.tsx index fcbaf4b355..d607256528 100644 --- a/libs/react-components/src/elements/File.tsx +++ b/libs/react-components/src/elements/File.tsx @@ -1,138 +1,27 @@ -import { useState } from 'react'; -import { FileIcon, defaultStyles } from 'react-file-icon'; +import { Attachment } from 'src/Attachment'; -import Close from '@mui/icons-material/Close'; -import Box from '@mui/material/Box'; -import IconButton from '@mui/material/IconButton'; import Link from '@mui/material/Link'; -import Stack from '@mui/material/Stack'; -import Typography from '@mui/material/Typography'; -import type { IFileElement } from 'client-types/'; +import { type IFileElement } from 'client-types/'; -const FileElement = ({ - element, - onRemove -}: { - element: IFileElement; - onRemove?: () => void; -}) => { - const [isHovered, setIsHovered] = useState(false); - - if (!element.url && !element.content) { +const FileElement = ({ element }: { element: IFileElement }) => { + if (!element.url) { return null; } - let children; - const mime = element.mime ? element.mime.split('/').pop()! : 'file'; - - if (element.mime?.includes('image') && !element.mime?.includes('svg')) { - children = ( - - ); - } else { - children = ( - `1px solid ${theme.palette.primary.main}`, - color: (theme) => - theme.palette.mode === 'light' - ? theme.palette.primary.main - : theme.palette.text.primary, - background: (theme) => - theme.palette.mode === 'light' - ? theme.palette.primary.light - : theme.palette.primary.dark - }} - > - svg': { - height: '30px' - } - }} - > - - - - {element.name} - - - ); - } - - const fileElement = ( - setIsHovered(true)} - onMouseLeave={() => setIsHovered(false)} - height={50} + return ( + - {isHovered && onRemove ? ( - `1px solid ${theme.palette.divider}`, - '&:hover': { - backgroundColor: 'background.default' - } - }} - onClick={onRemove} - > - - - ) : null} - {children} - + + ); - - if (!onRemove) { - return ( - - {fileElement} - - ); - } else { - return fileElement; - } }; export { FileElement }; diff --git a/libs/react-components/src/elements/Image.tsx b/libs/react-components/src/elements/Image.tsx index 05015fa8d9..c43aa6eed4 100644 --- a/libs/react-components/src/elements/Image.tsx +++ b/libs/react-components/src/elements/Image.tsx @@ -1,4 +1,4 @@ -import type { IImageElement } from 'client-types/'; +import { type IImageElement } from 'client-types/'; import { FrameElement } from './Frame'; @@ -18,17 +18,19 @@ const handleImageClick = (name: string, src: string) => { }; const ImageElement = ({ element }: Props) => { - const src = element.url || URL.createObjectURL(new Blob([element.content!])); + if (!element.url) { + return null; + } return ( { if (element.display === 'inline') { const name = `${element.name}.png`; - handleImageClick(name, src); + handleImageClick(name, element.url!); } }} style={{ diff --git a/libs/react-components/src/elements/InlinedImageList.tsx b/libs/react-components/src/elements/InlinedImageList.tsx index e3fc4924ea..e474dabe0b 100644 --- a/libs/react-components/src/elements/InlinedImageList.tsx +++ b/libs/react-components/src/elements/InlinedImageList.tsx @@ -8,7 +8,10 @@ interface Props { } const InlinedImageList = ({ items }: Props) => ( - elements={items} renderElement={ImageElement} /> + + elements={items} + renderElement={(ctx) => } + /> ); export { InlinedImageList }; diff --git a/libs/react-components/src/elements/InlinedVideoList.tsx b/libs/react-components/src/elements/InlinedVideoList.tsx index 1f550b3871..6cfd6bc783 100644 --- a/libs/react-components/src/elements/InlinedVideoList.tsx +++ b/libs/react-components/src/elements/InlinedVideoList.tsx @@ -8,7 +8,10 @@ interface Props { } const InlinedVideoList = ({ items }: Props) => ( - elements={items} renderElement={VideoElement} /> + + elements={items} + renderElement={(ctx) => } + /> ); export { InlinedVideoList }; diff --git a/libs/react-components/src/elements/PDF.tsx b/libs/react-components/src/elements/PDF.tsx index 46e8c31960..0be4e2c466 100644 --- a/libs/react-components/src/elements/PDF.tsx +++ b/libs/react-components/src/elements/PDF.tsx @@ -1,23 +1,18 @@ -import type { IPdfElement } from 'client-types/'; +import { type IPdfElement } from 'client-types/'; interface Props { element: IPdfElement; } const PDFElement = ({ element }: Props) => { - if (!element.url && !element.content) { + if (!element.url) { return null; } return (