From d5316335d515806d486532e24b1c579299682284 Mon Sep 17 00:00:00 2001 From: Willy Douhard Date: Thu, 12 Dec 2024 12:21:58 +0100 Subject: [PATCH] feat: enhance socket.io configuration and CORS handling (#1575) This improves Socket.IO configuration and CORS handling for cross-origin use cases, particularly focusing on scenarios where Chainlit is embedded in other websites (copilot, closes #1279) or deployed behind load balancers: - Make Socket.IO client transports configurable via `project.transports` setting - Move connection parameters from headers to Socket.IO auth object for better websocket compatibility - Update CORS handling to properly support authenticated cross-origin requests. Closes #1359. - Remove unnecessary task start/end events from window_message handler BREAKING CHANGE: For cross-origin deployments, wildcard "*" is no longer supported in allow_origins when using authenticated connections. Specific origins must be listed in the project config. Possibly resolves #719, #1339, #1369, #1407, #1492, #1507. --- backend/chainlit/config.py | 2 + backend/chainlit/server.py | 5 ++- backend/chainlit/socket.py | 45 ++++++----------------- cypress/e2e/copilot/.chainlit/config.toml | 2 +- frontend/src/App.tsx | 2 + libs/copilot/src/chat/index.tsx | 1 + libs/react-client/src/useChatSession.ts | 25 +++++++------ 7 files changed, 36 insertions(+), 46 deletions(-) diff --git a/backend/chainlit/config.py b/backend/chainlit/config.py index b90f162f07..18ee6be8db 100644 --- a/backend/chainlit/config.py +++ b/backend/chainlit/config.py @@ -311,6 +311,8 @@ class CodeSettings: @dataclass() class ProjectSettings(DataClassJsonMixin): allow_origins: List[str] = Field(default_factory=lambda: ["*"]) + # Socket.io client transports option + transports: Optional[List[str]] = None enable_telemetry: bool = True # List of environment variables to be provided by each user to use the app. If empty, no environment variables will be asked to the user. user_env: Optional[List[str]] = None diff --git a/backend/chainlit/server.py b/backend/chainlit/server.py index 5118f544a7..7aeabe5329 100644 --- a/backend/chainlit/server.py +++ b/backend/chainlit/server.py @@ -301,7 +301,10 @@ def get_html_template(): """ - js = f"""""" + js = f"""""" css = None if config.ui.custom_css: diff --git a/backend/chainlit/socket.py b/backend/chainlit/socket.py index d79c76c16e..5053262e2f 100644 --- a/backend/chainlit/socket.py +++ b/backend/chainlit/socket.py @@ -1,7 +1,6 @@ import asyncio import json import time -import uuid from typing import Any, Dict, Literal from urllib.parse import unquote @@ -77,24 +76,8 @@ def load_user_env(user_env): return user_env -def build_anon_user_identifier(environ): - scope = environ.get("asgi.scope", {}) - client_ip, _ = scope.get("client") - ip = environ.get("HTTP_X_FORWARDED_FOR", client_ip) - - try: - headers = scope.get("headers", {}) - user_agent = next( - (v.decode("utf-8") for k, v in headers if k.decode("utf-8") == "user-agent") - ) - return str(uuid.uuid5(uuid.NAMESPACE_DNS, user_agent + ip)) - - except StopIteration: - return str(uuid.uuid5(uuid.NAMESPACE_DNS, ip)) - - @sio.on("connect") -async def connect(sid, environ): +async def connect(sid, environ, auth): if ( not config.code.on_chat_start and not config.code.on_message @@ -110,8 +93,8 @@ async def connect(sid, environ): try: # Check if the authentication is required if login_required: - authorization_header = environ.get("HTTP_AUTHORIZATION") - token = authorization_header.split(" ")[1] if authorization_header else None + token = auth.get("token") + token = token.split(" ")[1] if token else None user = await get_current_user(token=token) except Exception: logger.info("Authentication failed") @@ -125,16 +108,16 @@ def emit_fn(event, data): def emit_call_fn(event: Literal["ask", "call_fn"], data, timeout): return sio.call(event, data, timeout=timeout, to=sid) - session_id = environ.get("HTTP_X_CHAINLIT_SESSION_ID") + session_id = auth.get("sessionId") if restore_existing_session(sid, session_id, emit_fn, emit_call_fn): return True - user_env_string = environ.get("HTTP_USER_ENV") + user_env_string = auth.get("userEnv") user_env = load_user_env(user_env_string) - client_type = environ.get("HTTP_X_CHAINLIT_CLIENT_TYPE") + client_type = auth.get("clientType") http_referer = environ.get("HTTP_REFERER") - url_encoded_chat_profile = environ.get("HTTP_X_CHAINLIT_CHAT_PROFILE") + url_encoded_chat_profile = auth.get("chatProfile") chat_profile = ( unquote(url_encoded_chat_profile) if url_encoded_chat_profile else None ) @@ -149,7 +132,7 @@ def emit_call_fn(event: Literal["ask", "call_fn"], data, timeout): user=user, token=token, chat_profile=chat_profile, - thread_id=environ.get("HTTP_X_CHAINLIT_THREAD_ID"), + thread_id=auth.get("threadId"), languages=environ.get("HTTP_ACCEPT_LANGUAGE"), http_referer=http_referer, ) @@ -162,13 +145,13 @@ def emit_call_fn(event: Literal["ask", "call_fn"], data, timeout): async def connection_successful(sid): context = init_ws_context(sid) - if context.session.restored: - return - await context.emitter.task_end() await context.emitter.clear("clear_ask") await context.emitter.clear("clear_call_fn") + if context.session.restored: + return + if context.session.thread_id_to_resume and config.code.on_chat_resume: thread = await resume_thread(context.session) if thread: @@ -312,17 +295,13 @@ async def message(sid, payload: MessagePayload): async def window_message(sid, data): """Handle a message send by the host window.""" session = WebsocketSession.require(sid) - context = init_ws_context(session) - - await context.emitter.task_start() + init_ws_context(session) if config.code.on_window_message: try: await config.code.on_window_message(data) except asyncio.CancelledError: pass - finally: - await context.emitter.task_end() @sio.on("audio_start") diff --git a/cypress/e2e/copilot/.chainlit/config.toml b/cypress/e2e/copilot/.chainlit/config.toml index e2a93af08f..9c42755715 100644 --- a/cypress/e2e/copilot/.chainlit/config.toml +++ b/cypress/e2e/copilot/.chainlit/config.toml @@ -13,7 +13,7 @@ session_timeout = 3600 cache = false # Authorized origins -allow_origins = ["*"] +allow_origins = ["http://127.0.0.1:8000"] # Follow symlink for asset mount (see https://github.com/Chainlit/chainlit/issues/317) # follow_symlink = false diff --git a/frontend/src/App.tsx b/frontend/src/App.tsx index cc80e03ac9..9238ca2519 100644 --- a/frontend/src/App.tsx +++ b/frontend/src/App.tsx @@ -42,6 +42,7 @@ declare global { light?: ThemOverride; dark?: ThemOverride; }; + transports?: string[] } } @@ -99,6 +100,7 @@ function App() { return; } else { connect({ + transports: window.transports, userEnv, accessToken }); diff --git a/libs/copilot/src/chat/index.tsx b/libs/copilot/src/chat/index.tsx index 5f0a0779e7..3cc4bd3289 100644 --- a/libs/copilot/src/chat/index.tsx +++ b/libs/copilot/src/chat/index.tsx @@ -12,6 +12,7 @@ export default function ChatWrapper() { useEffect(() => { if (session?.socket?.connected) return; connect({ + transports: window.transports, userEnv: {}, accessToken: `Bearer ${accessToken}` }); diff --git a/libs/react-client/src/useChatSession.ts b/libs/react-client/src/useChatSession.ts index 441e66d665..b1079179f0 100644 --- a/libs/react-client/src/useChatSession.ts +++ b/libs/react-client/src/useChatSession.ts @@ -78,16 +78,18 @@ const useChatSession = () => { // Use currentThreadId as thread id in websocket header useEffect(() => { if (session?.socket) { - session.socket.io.opts.extraHeaders!['X-Chainlit-Thread-Id'] = + session.socket.auth["threadId"] = currentThreadId || ''; } }, [currentThreadId]); const _connect = useCallback( ({ + transports, userEnv, accessToken }: { + transports?: string[] userEnv: Record; accessToken?: string; }) => { @@ -100,16 +102,17 @@ const useChatSession = () => { const socket = io(uri, { path, - extraHeaders: { - Authorization: accessToken || '', - 'X-Chainlit-Client-Type': client.type, - 'X-Chainlit-Session-Id': sessionId, - 'X-Chainlit-Thread-Id': idToResume || '', - 'user-env': JSON.stringify(userEnv), - 'X-Chainlit-Chat-Profile': chatProfile - ? encodeURIComponent(chatProfile) - : '' - } + withCredentials: true, + transports, + auth: { + token: accessToken, + clientType: client.type, + sessionId, + threadId: idToResume || '', + userEnv: JSON.stringify(userEnv), + chatProfile: chatProfile ? encodeURIComponent(chatProfile) : '' + } + }); setSession((old) => { old?.socket?.removeAllListeners();