diff --git a/cypress/e2e/action/main.py b/cypress/e2e/action/main.py index b128d3151c..b36ad18b8c 100644 --- a/cypress/e2e/action/main.py +++ b/cypress/e2e/action/main.py @@ -1,13 +1,21 @@ import chainlit as cl -@cl.action("action1") +@cl.action("test action") def on_action(): - cl.send_message("Executed action 1!") + cl.send_message("Executed test action!") + + +@cl.action("removable action") +def on_action(action: cl.Action): + cl.send_message("Executed removable action!") + action.remove() @cl.on_chat_start def main(): - cl.send_message("Hello, here is a clickable action!") - cl.send_action(name="action1", trigger="clickable action", - description="Click on this to run action1!") + actions = [ + cl.Action(name="test action", value="test"), + cl.Action(name="removable action", value="test"), + ] + cl.send_message("Hello, this is a test message!", actions=actions) diff --git a/cypress/e2e/action/spec.cy.ts b/cypress/e2e/action/spec.cy.ts index 34799f70e0..e73fdcc8bb 100644 --- a/cypress/e2e/action/spec.cy.ts +++ b/cypress/e2e/action/spec.cy.ts @@ -7,14 +7,31 @@ describe("Action", () => { cy.wait(["@settings"]); }); - it("should correcly execute the action", () => { - cy.get(".message").should("have.length", 1); - cy.get("#action-action1").should("exist"); - cy.get("#action-action1").click(); + it("should correcly execute actions", () => { + cy.get(".message") + .should("have.length", 1) + .eq(0) + .get("#action-test-action") + .should("exist") + .click(); cy.wait(["@action"]); - const messages = cy.get(".message"); - messages.should("have.length", 2); - messages.eq(1).should("contain", "Executed action 1!"); + cy.get(".message").should("have.length", 2); + cy.get(".message").eq(1).should("contain", "Executed test action!"); + + cy.get(".message") + .eq(0) + .get("#action-removable-action") + .should("exist") + .click(); + cy.wait(["@action"]); + + cy.get(".message").should("have.length", 3); + cy.get(".message").eq(2).should("contain", "Executed removable action!"); + + cy.get(".message") + .eq(0) + .get("#action-removable-action") + .should("not.exist"); }); }); diff --git a/cypress/e2e/elements/main.py b/cypress/e2e/elements/main.py deleted file mode 100644 index 9a31061e60..0000000000 --- a/cypress/e2e/elements/main.py +++ /dev/null @@ -1,13 +0,0 @@ -import chainlit as cl - - -@cl.on_chat_start -def start(): - cl.send_local_image(path="./cat.jpeg", name="image1", display="inline") - cl.send_text(text="Here is a side text document", name="text1", display="side") - cl.send_text(text="Here is a page text document", name="text2", display="page") - - msg = "Here is image1, a nice image of a cat! As well as text1 and text2!" - cl.send_message( - content=msg, - ) diff --git a/cypress/e2e/elements/.chainlit/config.toml b/cypress/e2e/global_elements/.chainlit/config.toml similarity index 100% rename from cypress/e2e/elements/.chainlit/config.toml rename to cypress/e2e/global_elements/.chainlit/config.toml diff --git a/cypress/e2e/elements/cat.jpeg b/cypress/e2e/global_elements/cat.jpeg similarity index 100% rename from cypress/e2e/elements/cat.jpeg rename to cypress/e2e/global_elements/cat.jpeg diff --git a/cypress/e2e/global_elements/main.py b/cypress/e2e/global_elements/main.py new file mode 100644 index 0000000000..a66d7d0b76 --- /dev/null +++ b/cypress/e2e/global_elements/main.py @@ -0,0 +1,12 @@ +import chainlit as cl + + +@cl.on_chat_start +def start(): + cl.LocalImage(path="./cat.jpeg", name="image1", display="inline").send() + cl.Text(text="Here is a side text document", name="text1", display="side").send() + cl.Text(text="Here is a page text document", name="text2", display="page").send() + + cl.send_message( + content="Here is image1, a nice image of a cat! As well as text1 and text2!", + ) diff --git a/cypress/e2e/elements/spec.cy.ts b/cypress/e2e/global_elements/spec.cy.ts similarity index 96% rename from cypress/e2e/elements/spec.cy.ts rename to cypress/e2e/global_elements/spec.cy.ts index 4f411bde42..e47a9e1ae0 100644 --- a/cypress/e2e/elements/spec.cy.ts +++ b/cypress/e2e/global_elements/spec.cy.ts @@ -1,4 +1,4 @@ -describe("Elements", () => { +describe("Global Elements", () => { before(() => { cy.intercept("/project/settings").as("settings"); cy.visit("http://127.0.0.1:8000"); diff --git a/cypress/e2e/langchain_postprocess/main.py b/cypress/e2e/langchain_postprocess/main.py index 766b710aa7..6002de10e4 100644 --- a/cypress/e2e/langchain_postprocess/main.py +++ b/cypress/e2e/langchain_postprocess/main.py @@ -13,4 +13,4 @@ def main(): @cl.langchain_postprocess def postprocess(output: str): - return "In the end it doesn't even matter." + cl.send_message("In the end it doesn't even matter.") diff --git a/cypress/e2e/langchain_run/main.py b/cypress/e2e/langchain_run/main.py index 2fbcd1eb7a..ad17804dd8 100644 --- a/cypress/e2e/langchain_run/main.py +++ b/cypress/e2e/langchain_run/main.py @@ -14,4 +14,4 @@ def main(): @cl.langchain_run def run(agent, input_str): res = agent("2+2") - return res["text"] + cl.send_message(res["text"]) diff --git a/cypress/e2e/scoped_elements/.chainlit/config.toml b/cypress/e2e/scoped_elements/.chainlit/config.toml new file mode 100644 index 0000000000..6bbaf53fc2 --- /dev/null +++ b/cypress/e2e/scoped_elements/.chainlit/config.toml @@ -0,0 +1,24 @@ +[project] +# Name of the app and chatbot. +name = "Chatbot" + +# If true (default), the app will be available to anonymous users (once deployed). +# If false, users will need to authenticate and be part of the project to use the app. +public = true + +# The project ID (found on https://cloud.chainlit.io). +# If provided, all the message data will be stored in the cloud. +# The project ID is required when public is set to false. +#id = "" + +# Whether to enable telemetry (default: true). No personal data is collected. +enable_telemetry = true + +# List of environment variables to be provided by each user to use the app. +user_env = [] + +# Hide the chain of thought details from the user in the UI. +hide_cot = false + +# Limit the number of requests per user. +#request_limit = "10 per day" diff --git a/cypress/e2e/scoped_elements/cat.jpeg b/cypress/e2e/scoped_elements/cat.jpeg new file mode 100644 index 0000000000..53834a3d82 Binary files /dev/null and b/cypress/e2e/scoped_elements/cat.jpeg differ diff --git a/cypress/e2e/scoped_elements/main.py b/cypress/e2e/scoped_elements/main.py new file mode 100644 index 0000000000..2b602ea7ff --- /dev/null +++ b/cypress/e2e/scoped_elements/main.py @@ -0,0 +1,20 @@ +import chainlit as cl + + +@cl.on_chat_start +def start(): + elements = [ + cl.LocalImage(path="./cat.jpeg", name="image1", display="inline"), + cl.Text(text="Here is a side text document", name="text1", display="side"), + cl.Text(text="Here is a page text document", name="text2", display="page"), + ] + + # Element should not be inlined or referenced + cl.send_message( + content="Here is image1, a nice image of a cat! As well as text1 and text2!", + ) + # Image should be inlined even if not referenced + cl.send_message( + content="Here a nice image of a cat! As well as text1 and text2!", + elements=elements, + ) diff --git a/cypress/e2e/scoped_elements/spec.cy.ts b/cypress/e2e/scoped_elements/spec.cy.ts new file mode 100644 index 0000000000..a2f304a083 --- /dev/null +++ b/cypress/e2e/scoped_elements/spec.cy.ts @@ -0,0 +1,17 @@ +describe("Scoped Elements", () => { + before(() => { + cy.intercept("/project/settings").as("settings"); + cy.visit("http://127.0.0.1:8000"); + cy.wait(["@settings"]); + }); + + it("should be able to display inlined, side and page elements", () => { + cy.get(".message").should("have.length", 2); + + cy.get(".message").eq(0).find(".inlined-image").should("have.length", 0); + cy.get(".message").eq(0).find(".element-link").should("have.length", 0); + + cy.get(".message").eq(1).find(".inlined-image").should("have.length", 1); + cy.get(".message").eq(1).find(".element-link").should("have.length", 2); + }); +}); diff --git a/src/chainlit/__init__.py b/src/chainlit/__init__.py index ba09b1fd9c..c1977767a1 100644 --- a/src/chainlit/__init__.py +++ b/src/chainlit/__init__.py @@ -8,21 +8,21 @@ monkey.patch() from chainlit.sdk import get_sdk -from chainlit.user_session import user_session from chainlit.config import config from chainlit.types import ( - ElementDisplay, LLMSettings, AskSpec, AskFileSpec, AskFileResponse, AskResponse, - Action, ) from chainlit.telemetry import trace from chainlit.version import __version__ from chainlit.logger import logger from chainlit.server import socketio +from chainlit.action import Action +from chainlit.element import LocalImage, RemoteImage, Text +from chainlit.user_session import user_session from typing import Callable, Any, List, Union from dotenv import load_dotenv import inspect @@ -73,55 +73,6 @@ def wrapper(*args): return wrapper -@trace -def send_text(text: str, name: str, display: ElementDisplay = "side"): - """ - Send a text element to the chatbot UI. - If a project ID is configured, the element will be uploaded to the cloud storage. - - Args: - text (str): The content of the text element. - name (str): The name of the text element to be displayed in the UI. - display (ElementDisplay, optional): Determines how the element should be displayed in the UI. - Choices are "side" (default) or "inline" or "page". - """ - sdk = get_sdk() - if sdk: - sdk.send_text(text, name, display) - - -@trace -def send_local_image(path: str, name: str, display: ElementDisplay = "side"): - """ - Send a local image to the chatbot UI. - If a project ID is configured, the image will be uploaded to the cloud storage. - - Args: - path (str): The local file path of the image. - name (str): The name of the image to be displayed in the UI. - display (ElementDisplay, optional): Determines how the image should be displayed in the UI. - Choices are "side" (default) or "inline" or "page". - """ - sdk = get_sdk() - if sdk: - sdk.send_local_image(path, name, display) - - -@trace -def send_image(url: str, name: str, display: ElementDisplay = "side"): - """ - Send an image to the chatbot UI. - Args: - url (str): The URL of the image. - name (str): The name of the image to be displayed in the UI. - display (ElementDisplay, optional): Determines how the image should be displayed in the UI. - Choices are "side" (default) or "inline" or "page". - """ - sdk = get_sdk() - if sdk: - sdk.send_image(url, name, display) - - @trace def send_message( content: str, @@ -131,6 +82,8 @@ def send_message( indent=0, llm_settings: LLMSettings = None, end_stream=False, + actions: List[Action] = [], + elements: List[Union[LocalImage, RemoteImage, Text]] = [], ): """ Send a message to the chatbot UI. @@ -144,10 +97,13 @@ def send_message( indent (int, optional): If positive, the message will be nested in the UI. llm_settings (LLMSettings, optional): Settings of the LLM used to generate the prompt. This is useful for debug purposes in the prompt playground. end_stream (bool, optional): Pass True if this message was streamed. + + Returns: + str: The message ID. """ sdk = get_sdk() if sdk: - sdk.send_message( + msg_id = sdk.send_message( author=author, content=content, prompt=prompt, @@ -157,6 +113,12 @@ def send_message( end_stream=end_stream, ) + for action in actions: + action.send(for_id=msg_id) + + for element in elements: + element.send(for_id=msg_id) + @trace def send_error_message(content: str, author=config.chatbot_name, indent=0): @@ -237,20 +199,6 @@ def ask_for_file( return None -@trace -def send_action(name: str, trigger: str, description=""): - """ - Send an action to the chatbot UI. - Args: - name (str): The name of the action to send. - trigger (str): The text that should trigger the action when clicked. - description (str, optional): The description of the action. Defaults to "". - """ - sdk = get_sdk() - if sdk: - sdk.send_action(name=name, trigger=trigger, description=description) - - @trace def start_stream( author=config.chatbot_name, @@ -301,7 +249,6 @@ def langchain_factory(func: Callable) -> Callable: Returns: Callable[[], Any]: The decorated factory function. """ - from chainlit.config import config config.lc_factory = wrap_user_function(func, with_task=True) return func @@ -311,7 +258,8 @@ def langchain_factory(func: Callable) -> Callable: def langchain_postprocess(func: Callable[[Any], str]) -> Callable: """ Useful to post process the response a LangChain object instantiated with @langchain_factory. - The decorated function takes the raw output of the LangChain object and return a string as the final response. + The decorated function takes the raw output of the LangChain object as input. + The response will NOT be automatically sent to the UI, you need to call send_message. Args: func (Callable[[Any], str]): The post-processing function to apply after generating a response. Takes the response as parameter. @@ -319,7 +267,6 @@ def langchain_postprocess(func: Callable[[Any], str]) -> Callable: Returns: Callable[[Any], str]: The decorated post-processing function. """ - from chainlit.config import config config.lc_postprocess = wrap_user_function(func) return func @@ -337,7 +284,6 @@ def on_message(func: Callable) -> Callable: Returns: Callable[[str], Any]: The decorated on_message function. """ - from chainlit.config import config config.on_message = wrap_user_function(func) return func @@ -348,14 +294,14 @@ def langchain_run(func: Callable[[Any, str], str]) -> Callable: """ Useful to override the default behavior of the LangChain object instantiated with @langchain_factory. Use when your agent run method has custom parameters. - This function should return a string as the final response. + Takes the LangChain agent and the user input as parameters. + The response will NOT be automatically sent to the UI, you need to call send_message. Args: func (Callable[[Any, str], str]): The function to be called when a new message is received. Takes the agent and user input as parameters and returns the output string. Returns: Callable[[Any, str], Any]: The decorated function. """ - from chainlit.config import config config.lc_run = wrap_user_function(func) return func @@ -371,7 +317,6 @@ def langchain_rename(func: Callable[[str], str]) -> Callable[[str], str]: Returns: Callable[[Any, str], Any]: The decorated function. """ - from chainlit.config import config config.lc_rename = wrap_user_function(func) return func @@ -388,7 +333,6 @@ def on_chat_start(func: Callable) -> Callable: Returns: Callable[], Any]: The decorated hook. """ - from chainlit.config import config config.on_chat_start = wrap_user_function(func, with_task=True) return func @@ -405,7 +349,6 @@ def on_stop(func: Callable) -> Callable: Returns: Callable[[], Any]: The decorated stop hook. """ - from chainlit.config import config config.on_stop = wrap_user_function(func) return func @@ -413,12 +356,13 @@ def on_stop(func: Callable) -> Callable: def action(name: str) -> Callable: """ - Callback to call when an action is triggered in the UI. + Callback to call when an action is clicked in the UI. + + Args: + func (Callable[[Action], Any]): The action callback to exexute. First parameter is the action. """ def decorator(func: Callable[[Action], Any]): - from chainlit.config import config - config.action_callbacks[name] = wrap_user_function(func, with_task=True) return func @@ -432,3 +376,26 @@ def sleep(duration: int): duration (int): The duration in seconds. """ return socketio.sleep(duration) + + +__all__ = [ + "user_session", + "Action", + "LocalImage", + "RemoteImage", + "Text", + "send_message", + "send_error_message", + "ask_for_input", + "ask_for_file", + "start_stream", + "send_token", + "langchain_factory", + "langchain_postprocess", + "langchain_run", + "langchain_rename", + "on_chat_start", + "on_stop", + "action", + "sleep", +] diff --git a/src/chainlit/action.py b/src/chainlit/action.py new file mode 100644 index 0000000000..b641f80b71 --- /dev/null +++ b/src/chainlit/action.py @@ -0,0 +1,29 @@ +from pydantic.dataclasses import dataclass +from dataclasses_json import dataclass_json +from chainlit.sdk import get_emit +from chainlit.telemetry import trace_event + + +@dataclass_json +@dataclass +class Action: + name: str + value: str + description: str = "" + forId: str = None + + def __post_init__(self) -> None: + trace_event(f"init {self.__class__.__name__}") + + def send(self, for_id: str): + emit = get_emit() + if emit: + trace_event(f"send {self.__class__.__name__}") + self.forId = for_id + emit("action", self.to_dict()) + + def remove(self): + emit = get_emit() + if emit: + trace_event(f"remove {self.__class__.__name__}") + emit("remove_action", self.to_dict()) diff --git a/src/chainlit/client.py b/src/chainlit/client.py index 7f67f9f928..3afd581850 100644 --- a/src/chainlit/client.py +++ b/src/chainlit/client.py @@ -20,13 +20,18 @@ def create_message(self, variables: Dict[str, Any]) -> int: pass @abstractmethod - def upload_element(self, ext: str, content: bytes) -> int: + def upload_element(self, content: bytes) -> int: pass @abstractmethod def create_element( - self, conversation_id: str, type: ElementType, url: str, name: str, display: str - ) -> int: + self, + type: ElementType, + url: str, + name: str, + display: str, + for_id: str = None, + ) -> Dict[str, Any]: pass @@ -95,18 +100,19 @@ def create_message(self, variables: Dict[str, Any]) -> int: return int(res["data"]["createMessage"]["id"]) def create_element( - self, type: ElementType, url: str, name: str, display: str + self, type: ElementType, url: str, name: str, display: str, for_id: str = None ) -> Dict[str, Any]: c_id = self.get_conversation_id() mutation = """ - mutation ($conversationId: ID!, $type: String!, $url: String!, $name: String!, $display: String!) { - createElement(conversationId: $conversationId, type: $type, url: $url, name: $name, display: $display) { + mutation ($conversationId: ID!, $type: String!, $url: String!, $name: String!, $display: String!, $forId: String) { + createElement(conversationId: $conversationId, type: $type, url: $url, name: $name, display: $display, forId: $forId) { id, type, url, name, - display + display, + forId } } """ @@ -116,12 +122,13 @@ def create_element( "url": url, "name": name, "display": display, + "forId": for_id, } res = self.mutation(mutation, variables) return res["data"]["createElement"] - def upload_element(self, ext: str, content: bytes) -> str: - id = f"{uuid.uuid4()}{ext}" + def upload_element(self, content: bytes) -> str: + id = f"{uuid.uuid4()}" url = f"{self.url}/api/upload/file" body = {"projectId": self.project_id, "fileName": id} diff --git a/src/chainlit/config.py b/src/chainlit/config.py index f0ba3b3143..da607d4472 100644 --- a/src/chainlit/config.py +++ b/src/chainlit/config.py @@ -1,12 +1,14 @@ import os import sys -from typing import Optional, Literal, Any, Callable, List, Dict +from typing import Optional, Literal, Any, Callable, List, Dict, TYPE_CHECKING import tomli -from chainlit.types import Action from pydantic.dataclasses import dataclass from importlib import machinery from chainlit.logger import logger +if TYPE_CHECKING: + from chainlit.action import Action + # Get the directory the script is running from root = os.getcwd() @@ -71,7 +73,7 @@ class ChainlitConfig: # Path to the local langchain cache database lc_cache_path: str # Developer defined callbacks for each action. Key is the action name, value is the callback function. - action_callbacks: Dict[str, Callable[[Action], Any]] + action_callbacks: Dict[str, Callable[["Action"], Any]] # Directory where the Chainlit project is located root = root # Link to your github repo. This will add a github button in the UI's header. diff --git a/src/chainlit/element.py b/src/chainlit/element.py new file mode 100644 index 0000000000..99d4115945 --- /dev/null +++ b/src/chainlit/element.py @@ -0,0 +1,122 @@ +from pydantic.dataclasses import dataclass +from dataclasses_json import dataclass_json +from typing import Dict +from abc import ABC, abstractmethod +from chainlit.sdk import get_sdk, BaseClient +from chainlit.telemetry import trace_event +from chainlit.types import ElementType, ElementDisplay + + +@dataclass_json +@dataclass +class Element(ABC): + name: str + type: ElementType + display: ElementDisplay = "side" + forId: str = None + + def __post_init__(self) -> None: + trace_event(f"init {self.__class__.__name__}") + + @abstractmethod + def persist(self, client: BaseClient, for_id: str = None) -> Dict: + pass + + def before_emit(self, element: Dict) -> Dict: + return element + + def send(self, for_id: str = None): + sdk = get_sdk() + + # Cloud is enabled, upload the element to S3 + if sdk.client: + element = self.persist(sdk.client, for_id) + else: + element = self.to_dict() + if for_id: + element["forId"] = for_id + + if sdk.emit and element: + trace_event(f"send {self.__class__.__name__}") + element = self.before_emit(element) + sdk.emit("element", element) + + +@dataclass +class LocalElementBase: + content: bytes + + +@dataclass +class LocalElement(Element, LocalElementBase): + def persist(self, client: BaseClient, for_id: str = None): + url = client.upload_element(content=self.content) + if url: + element = client.create_element( + name=self.name, + url=url, + type=self.type, + display=self.display, + for_id=for_id, + ) + return element + + +@dataclass +class RemoteElementBase: + url: str + + +@dataclass +class RemoteElement(Element, RemoteElementBase): + def persist(self, client: BaseClient, for_id: str = None): + element = client.create_element( + name=self.name, + url=self.url, + type=self.type, + display=self.display, + for_id=for_id, + ) + return element + + +class LocalImage(LocalElement): + def __init__( + self, + name: str, + display: ElementDisplay = "side", + path: str = None, + content: bytes = None, + ): + if path: + with open(path, "rb") as f: + self.content = f.read() + elif content: + self.content = content + else: + raise ValueError("Must provide either path or content") + + self.name = name + self.display = display + self.type = "image" + + +class RemoteImage(RemoteElement): + def __init__(self, name: str, url: str, display: ElementDisplay = "side"): + self.name = name + self.display = display + self.type = "image" + self.url = url + + +class Text(LocalElement): + def __init__(self, name: str, text: str, display: ElementDisplay = "side"): + self.name = name + self.display = display + self.type = "text" + self.content = bytes(text, "utf-8") + + def before_emit(self, text_element): + if "content" in text_element and isinstance(text_element["content"], bytes): + text_element["content"] = text_element["content"].decode("utf-8") + return text_element diff --git a/src/chainlit/frontend/src/components/chat/message/actionRef.tsx b/src/chainlit/frontend/src/components/chat/message/actionRef.tsx index 7d15403d82..4048875ffd 100644 --- a/src/chainlit/frontend/src/components/chat/message/actionRef.tsx +++ b/src/chainlit/frontend/src/components/chat/message/actionRef.tsx @@ -30,13 +30,11 @@ export default function ActionRef({ action }: Props) { } }, [session]); + const formattedName = action.name.trim().toLowerCase().replaceAll(' ', '-'); + const id = `action-${formattedName}`; const button = ( - - {action.trigger} + + {action.name} ); if (action.description) { diff --git a/src/chainlit/frontend/src/components/chat/message/container.tsx b/src/chainlit/frontend/src/components/chat/message/container.tsx index 407a04741c..5874a7ef21 100644 --- a/src/chainlit/frontend/src/components/chat/message/container.tsx +++ b/src/chainlit/frontend/src/components/chat/message/container.tsx @@ -3,12 +3,12 @@ import { useEffect, useRef } from 'react'; import { IMessage, INestedMessage } from 'state/chat'; import { IElements } from 'state/element'; import Messages from './messages'; -import { IActions } from 'state/action'; +import { IAction } from 'state/action'; interface Props { messages: IMessage[]; elements: IElements; - actions: IActions; + actions: IAction[]; autoScroll?: boolean; setAutoSroll?: (autoScroll: boolean) => void; } diff --git a/src/chainlit/frontend/src/components/chat/message/content.tsx b/src/chainlit/frontend/src/components/chat/message/content.tsx index 3d03db867c..f6cfca088d 100644 --- a/src/chainlit/frontend/src/components/chat/message/content.tsx +++ b/src/chainlit/frontend/src/components/chat/message/content.tsx @@ -4,45 +4,50 @@ import { Typography, Link, Stack } from '@mui/material'; import { IElements } from 'state/element'; import InlinedElements from '../../element/inlined'; import { memo } from 'react'; -import { IActions } from 'state/action'; +import { IAction } from 'state/action'; import ElementRef from './elementRef'; -import ActionRef from './actionRef'; import Code from 'components/Code'; interface Props { + id?: string; content?: string; elements: IElements; - actions: IActions; + actions: IAction[]; language?: string; authorIsUser?: boolean; } -function prepareContent({ elements, actions, content, language }: Props) { - const elementNames = Object.keys(elements); +function prepareContent({ id, elements, actions, content, language }: Props) { + const filteredElements = elements.filter((e) => { + if (e.forId) { + return e.forId === id; + } + return true; + }); + + const elementNames = filteredElements.map((e) => e.name); + const elementRegexp = elementNames.length ? new RegExp(`(${elementNames.join('|')})`, 'g') : undefined; - const actionContents = Object.values(actions).map((a) => a.trigger); - const actionRegexp = actionContents.length - ? new RegExp(`(${actionContents.join('|')})`, 'g') - : undefined; + const filteredActions = actions.filter((a) => { + if (a.forId) { + return a.forId === id; + } + return true; + }); let preparedContent = content ? content.trim() : ''; - const inlinedElements: IElements = {}; + const inlinedElements: IElements = filteredElements.filter( + (e) => e.display === 'inline' + ); if (elementRegexp) { preparedContent = preparedContent.replaceAll(elementRegexp, (match) => { - if (elements[match].display === 'inline') { - inlinedElements[match] = elements[match]; - } - // spaces break markdown links. The address in the link is not used anyway - return `[${match}](${match.replaceAll(' ', '_')})`; - }); - } + const element = filteredElements.find((e) => e.name === match); + if (!element) return match; - if (actionRegexp) { - preparedContent = preparedContent.replaceAll(actionRegexp, (match) => { // spaces break markdown links. The address in the link is not used anyway return `[${match}](${match.replaceAll(' ', '_')})`; }); @@ -51,17 +56,29 @@ function prepareContent({ elements, actions, content, language }: Props) { if (language) { preparedContent = `\`\`\`${language}\n${preparedContent}\n\`\`\``; } - return { preparedContent, inlinedElements }; + return { + preparedContent, + inlinedElements, + filteredElements, + filteredActions + }; } export default memo(function MessageContent({ + id, content, elements, actions, language, authorIsUser }: Props) { - const { preparedContent, inlinedElements } = prepareContent({ + const { + preparedContent, + inlinedElements, + filteredActions, + filteredElements + } = prepareContent({ + id, content, language, elements, @@ -69,6 +86,7 @@ export default memo(function MessageContent({ }); if (!preparedContent) return null; + return ( a.trigger === name - ); + const element = filteredElements.find((e) => e.name === name); if (element) { return ; - } else if (action) { - return ; } else { return ( @@ -112,7 +125,7 @@ export default memo(function MessageContent({ {preparedContent} - + ); }); diff --git a/src/chainlit/frontend/src/components/chat/message/message.tsx b/src/chainlit/frontend/src/components/chat/message/message.tsx index f5094902a3..eae55ac8e6 100644 --- a/src/chainlit/frontend/src/components/chat/message/message.tsx +++ b/src/chainlit/frontend/src/components/chat/message/message.tsx @@ -6,14 +6,14 @@ import DetailsButton from 'components/chat/message/detailsButton'; import Messages from './messages'; import MessageContent from './content'; import UploadButton from './uploadButton'; -import { IActions } from 'state/action'; +import { IAction } from 'state/action'; import Author, { authorBoxWidth } from './author'; import Buttons from './buttons'; interface Props { message: INestedMessage; elements: IElements; - actions: IActions; + actions: IAction[]; indent: number; showAvatar?: boolean; showBorder?: boolean; @@ -66,6 +66,7 @@ const Message = ({ authorIsUser={message.authorIsUser} actions={actions} elements={elements} + id={message.id ? message.id.toString() : message.tempId} content={message.content} language={message.language} /> diff --git a/src/chainlit/frontend/src/components/chat/message/messages.tsx b/src/chainlit/frontend/src/components/chat/message/messages.tsx index 29a91bdd7e..7a881bdb03 100644 --- a/src/chainlit/frontend/src/components/chat/message/messages.tsx +++ b/src/chainlit/frontend/src/components/chat/message/messages.tsx @@ -2,12 +2,12 @@ import { INestedMessage, loadingState } from 'state/chat'; import Message from './message'; import { IElements } from 'state/element'; import { useRecoilValue } from 'recoil'; -import { IActions } from 'state/action'; +import { IAction } from 'state/action'; interface Props { messages: INestedMessage[]; elements: IElements; - actions: IActions; + actions: IAction[]; indent: number; isRunning?: boolean; } diff --git a/src/chainlit/frontend/src/components/element/inlined/action.tsx b/src/chainlit/frontend/src/components/element/inlined/action.tsx new file mode 100644 index 0000000000..dfc9a9cf8e --- /dev/null +++ b/src/chainlit/frontend/src/components/element/inlined/action.tsx @@ -0,0 +1,17 @@ +import { Stack } from '@mui/material'; +import ActionRef from 'components/chat/message/actionRef'; +import { IAction } from 'state/action'; + +interface Props { + actions: IAction[]; +} + +export default function InlinedActionList({ actions }: Props) { + return ( + + {actions.map((a) => { + return ; + })} + + ); +} diff --git a/src/chainlit/frontend/src/components/element/inlined/index.tsx b/src/chainlit/frontend/src/components/element/inlined/index.tsx index efa25883bf..78d3118440 100644 --- a/src/chainlit/frontend/src/components/element/inlined/index.tsx +++ b/src/chainlit/frontend/src/components/element/inlined/index.tsx @@ -2,36 +2,36 @@ import { ElementType, IElements } from 'state/element'; import InlinedImageList from './image'; import { Stack } from '@mui/material'; import InlinedTextList from './text'; +import { IAction } from 'state/action'; +import InlinedActionList from './action'; interface Props { - inlined: IElements; + elements: IElements; + actions: IAction[]; } -export default function InlinedElements({ inlined }: Props) { - if (!inlined || !Object.keys(inlined).length) { +export default function InlinedElements({ elements, actions }: Props) { + if (!elements.length && !actions.length) { return null; } - const images = Object.keys(inlined) - .filter((key) => inlined[key].type === ElementType.img) - .map((k) => { + const images = elements + .filter((el) => el.type === ElementType.img) + .map((el) => { return { - url: inlined[k].url, - src: - inlined[k].url || - URL.createObjectURL(new Blob([inlined[k].content!])), - title: inlined[k].name + url: el.url, + src: el.url || URL.createObjectURL(new Blob([el.content!])), + title: el.name }; }); - const texts = Object.fromEntries( - Object.entries(inlined).filter(([k, v]) => v.type === ElementType.txt) - ); + const texts = elements.filter((el) => el.type === ElementType.txt); return ( {images.length ? : null} {Object.keys(texts).length ? : null} + {actions.length ? : null} ); } diff --git a/src/chainlit/frontend/src/components/element/view.tsx b/src/chainlit/frontend/src/components/element/view.tsx index acbd8a480c..fa2be93724 100644 --- a/src/chainlit/frontend/src/components/element/view.tsx +++ b/src/chainlit/frontend/src/components/element/view.tsx @@ -20,7 +20,7 @@ const ElementView = () => { const { name } = useParams(); const elements = useRecoilValue(elementState); - const element = elements[name!]; + const element = elements.find((element) => element.name === name); if (!element) { return ; diff --git a/src/chainlit/frontend/src/components/socket.tsx b/src/chainlit/frontend/src/components/socket.tsx index e12ed84811..e68429c6a2 100644 --- a/src/chainlit/frontend/src/components/socket.tsx +++ b/src/chainlit/frontend/src/components/socket.tsx @@ -14,6 +14,7 @@ import { useAuth } from 'hooks/auth'; import io from 'socket.io-client'; import { IElement, elementState } from 'state/element'; import { IAction, actionState } from 'state/action'; +import { deepEqual } from 'helpers/object'; export default memo(function Socket() { const { accessToken, isAuthenticated, isLoading } = useAuth(); @@ -104,17 +105,19 @@ export default memo(function Socket() { }); socket.on('element', (element: IElement) => { - setElements((old) => ({ - ...old, - ...{ [element.name]: element } - })); + setElements((old) => [...old, element]); }); socket.on('action', (action: IAction) => { - setActions((old) => ({ - ...old, - ...{ [action.name]: action } - })); + setActions((old) => [...old, action]); + }); + + socket.on('remove_action', (action: IAction) => { + setActions((old) => { + const index = old.findIndex((a) => deepEqual(a, action)); + if (index === -1) return old; + return [...old.slice(0, index), ...old.slice(index + 1)]; + }); }); socket.on('token_usage', (count: number) => { diff --git a/src/chainlit/frontend/src/helpers/object.ts b/src/chainlit/frontend/src/helpers/object.ts new file mode 100644 index 0000000000..e18e2a06c6 --- /dev/null +++ b/src/chainlit/frontend/src/helpers/object.ts @@ -0,0 +1,29 @@ +function isObject(object: any) { + return object != null && typeof object === 'object'; +} + +export function deepEqual( + object1: Record, + object2: Record +) { + const keys1 = Object.keys(object1); + const keys2 = Object.keys(object2); + + if (keys1.length !== keys2.length) { + return false; + } + + for (const key of keys1) { + const val1 = object1[key]; + const val2 = object2[key]; + const areObjects = isObject(val1) && isObject(val2); + if ( + (areObjects && !deepEqual(val1, val2)) || + (!areObjects && val1 !== val2) + ) { + return false; + } + } + + return true; +} diff --git a/src/chainlit/frontend/src/hooks/clearChat.ts b/src/chainlit/frontend/src/hooks/clearChat.ts index a3cb14edbe..39a20d8092 100644 --- a/src/chainlit/frontend/src/hooks/clearChat.ts +++ b/src/chainlit/frontend/src/hooks/clearChat.ts @@ -1,10 +1,12 @@ import { useRecoilValue, useSetRecoilState } from 'recoil'; +import { actionState } from 'state/action'; import { messagesState, sessionState, tokenCountState } from 'state/chat'; import { sideViewState, elementState } from 'state/element'; export default function useClearChat() { const setMessages = useSetRecoilState(messagesState); const setElements = useSetRecoilState(elementState); + const setActions = useSetRecoilState(actionState); const setSideView = useSetRecoilState(sideViewState); const setTokenCount = useSetRecoilState(tokenCountState); const session = useRecoilValue(sessionState); @@ -13,7 +15,8 @@ export default function useClearChat() { session?.socket.disconnect(); session?.socket.connect(); setMessages([]); - setElements({}); + setElements([]); + setActions([]); setSideView(undefined); setTokenCount(0); }; diff --git a/src/chainlit/frontend/src/pages/Conversation.tsx b/src/chainlit/frontend/src/pages/Conversation.tsx index 3a368ac483..5b2e4100bb 100644 --- a/src/chainlit/frontend/src/pages/Conversation.tsx +++ b/src/chainlit/frontend/src/pages/Conversation.tsx @@ -6,7 +6,7 @@ import { gql, useQuery } from '@apollo/client'; import { IElements } from 'state/element'; import SideView from 'components/element/sideView'; import Playground from 'components/playground'; -import { IActions } from 'state/action'; +import { IAction } from 'state/action'; const ConversationQuery = gql` query ($id: ID!) { @@ -31,6 +31,7 @@ const ConversationQuery = gql` name url display + forId } } } @@ -48,10 +49,8 @@ export default function Conversation() { return null; } - const elements: IElements = {}; - data.conversation.elements.forEach((d: any) => (elements[d.name] = d)); - - const actions: IActions = {}; + const elements: IElements = data.conversation.elements; + const actions: IAction[] = []; return ( diff --git a/src/chainlit/frontend/src/state/action.ts b/src/chainlit/frontend/src/state/action.ts index 32105a8712..e0babacf83 100644 --- a/src/chainlit/frontend/src/state/action.ts +++ b/src/chainlit/frontend/src/state/action.ts @@ -2,13 +2,12 @@ import { atom } from 'recoil'; export interface IAction { name: string; - trigger: string; + value: string; + forId: string; description?: string; } -export type IActions = Record; - -export const actionState = atom({ +export const actionState = atom({ key: 'Actions', - default: {} + default: [] }); diff --git a/src/chainlit/frontend/src/state/chat.ts b/src/chainlit/frontend/src/state/chat.ts index c0509ff9da..81219fc93f 100644 --- a/src/chainlit/frontend/src/state/chat.ts +++ b/src/chainlit/frontend/src/state/chat.ts @@ -20,6 +20,7 @@ export interface IChat { export interface IMessage { id?: number; + tempId?: string; author: string; authorIsUser?: boolean; waitForAnswer?: boolean; diff --git a/src/chainlit/frontend/src/state/element.ts b/src/chainlit/frontend/src/state/element.ts index 134c518d7c..235c5cecd1 100644 --- a/src/chainlit/frontend/src/state/element.ts +++ b/src/chainlit/frontend/src/state/element.ts @@ -19,13 +19,14 @@ export interface IElement { content?: ValueOf; name: string; display: 'inline' | 'side' | 'page'; + forId?: string; } -export type IElements = Record; +export type IElements = IElement[]; export const elementState = atom({ key: 'Elements', - default: {} + default: [] }); export const sideViewState = atom({ diff --git a/src/chainlit/lc/utils.py b/src/chainlit/lc/utils.py index d7e128fa10..ab26883a2b 100644 --- a/src/chainlit/lc/utils.py +++ b/src/chainlit/lc/utils.py @@ -1,7 +1,6 @@ from typing import Any from chainlit.types import LLMSettings from typing import List, Optional -from langchain.llms.base import BaseLLM def run_langchain_agent(agent: Any, input_str: str): @@ -19,7 +18,7 @@ def run_langchain_agent(agent: Any, input_str: str): return raw_res, output_key -def get_llm_settings(llm: BaseLLM, stop: Optional[List[str]] = None): +def get_llm_settings(llm, stop: Optional[List[str]] = None): if llm.__class__.__name__ == "OpenAI": return LLMSettings( model_name=llm.model_name, diff --git a/src/chainlit/sdk.py b/src/chainlit/sdk.py index f20159bab2..62245ff223 100644 --- a/src/chainlit/sdk.py +++ b/src/chainlit/sdk.py @@ -1,8 +1,8 @@ from typing import Union -import os import time +import uuid from chainlit.session import Session -from chainlit.types import ElementDisplay, LLMSettings, ElementType, AskSpec +from chainlit.types import LLMSettings, AskSpec from chainlit.client import BaseClient from socketio.exceptions import TimeoutError import inspect @@ -46,89 +46,6 @@ def client(self) -> Union[BaseClient, None]: """Get the 'client' property from the session.""" return self._get_session_property("client") - def send_remote_element( - self, - url: str, - name: str, - type: ElementType, - display: ElementDisplay, - ): - """Send an element to the UI.""" - element = { - "name": name, - "url": url, - "type": type, - "display": display, - } - if self.emit and element: - self.emit("element", element) - - def send_local_element( - self, - ext: str, - content: bytes, - name: str, - type: ElementType, - display: ElementDisplay, - ): - """Send an element to the UI.""" - if self.client: - # Cloud is enabled, upload the element to S3 - url = self.client.upload_element(ext=ext, content=content) - if url: - element = self.client.create_element( - name=name, url=url, type=type, display=display - ) - else: - element = { - "name": name, - "content": content.decode("utf-8") if type == "text" else content, - "type": type, - "display": display, - } - if self.emit and element: - self.emit("element", element) - - def send_local_image(self, path: str, name: str, display: ElementDisplay = "side"): - """Send a local image to the UI.""" - if not self.emit: - return - - with open(path, "rb") as f: - _, ext = os.path.splitext(path) - type = "image" - image_data = f.read() - self.send_local_element(ext, image_data, name, type, display) - - def send_image(self, url: str, name: str, display: ElementDisplay = "side"): - """Send an image to the UI.""" - if not self.emit: - return - - type = "image" - self.send_remote_element(url, name, type, display) - - def send_text(self, text: str, name: str, display: ElementDisplay = "side"): - """Send a text element to the UI.""" - if not self.emit: - return - - type = "text" - ext = ".txt" - self.send_local_element(ext, bytes(text, "utf-8"), name, type, display) - - def send_action(self, name: str, trigger: str, description=""): - """Send an action to the UI.""" - if not self.emit: - return - - action = { - "name": name, - "trigger": trigger, - "description": description, - } - self.emit("action", action) - def send_message( self, author: str, @@ -162,6 +79,9 @@ def send_message( if self.client: message_id = self.client.create_message(msg) msg["id"] = message_id + else: + message_id = uuid.uuid4().hex + msg["tempId"] = message_id msg["createdAt"] = current_milli_time() @@ -170,6 +90,8 @@ def send_message( else: self.emit("message", msg) + return str(message_id) + def send_ask_timeout(self, author: str): """Send a prompt timeout message to the UI.""" self.send_message(author=author, content="Time out", is_error=True) @@ -295,3 +217,10 @@ def get_sdk() -> Union[Chainlit, None]: sdk = candidate break return sdk + + +def get_emit(): + sdk = get_sdk() + if sdk: + return sdk.emit + return None diff --git a/src/chainlit/server.py b/src/chainlit/server.py index b8de4f1cfb..f463d890c3 100644 --- a/src/chainlit/server.py +++ b/src/chainlit/server.py @@ -12,7 +12,7 @@ from chainlit.client import CloudClient from chainlit.sdk import Chainlit from chainlit.markdown import get_markdown_str -from chainlit.types import Action +from chainlit.action import Action from chainlit.telemetry import trace from chainlit.logger import logger @@ -243,14 +243,16 @@ def process_message(session: Session, author: str, input_str: str): # If a langchain agent is available, run it if config.lc_run: # If the developer provided a custom run function, use it - res = config.lc_run(langchain_agent, input_str) + config.lc_run(langchain_agent, input_str) + return else: # Otherwise, use the default run function raw_res, output_key = run_langchain_agent(langchain_agent, input_str) if config.lc_postprocess: # If the developer provided a custom postprocess function, use it - res = config.lc_postprocess(raw_res) + config.lc_postprocess(raw_res) + return elif output_key is not None: # Use the output key if provided res = raw_res[output_key] @@ -293,11 +295,11 @@ def message(): def process_action(session: Session, action: Action): __chainlit_sdk__ = Chainlit(session) - callback = config.action_callbacks.get(action["name"]) + callback = config.action_callbacks.get(action.name) if callback: callback(action) else: - logger.warning("No callback found for action %s", action["name"]) + logger.warning("No callback found for action %s", action.name) @app.route("/action", methods=["POST"]) @@ -308,7 +310,7 @@ def on_action(): body = request.json session_id = body["sessionId"] - action = body["action"] + action = Action(**body["action"]) session = need_session(session_id) diff --git a/src/chainlit/types.py b/src/chainlit/types.py index f555238fb1..6e9ab52fb6 100644 --- a/src/chainlit/types.py +++ b/src/chainlit/types.py @@ -38,12 +38,6 @@ class AskFileResponse: content: bytes -class Action(TypedDict): - name: str - trigger: str - description: str - - @dataclass_json @dataclass class LLMSettings: diff --git a/src/pyproject.toml b/src/pyproject.toml index eab36fe525..7ebea10240 100644 --- a/src/pyproject.toml +++ b/src/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "chainlit" -version = "0.1.103" +version = "0.2.0" keywords = ['LLM', 'Agents', 'gen ai', 'chat ui', 'chatbot ui', 'langchain'] description = "A faster way to build chatbot UIs." authors = ["Chainlit"]