Skip to content

Commit

Permalink
feat: enhance socket.io configuration and CORS handling (#1575)
Browse files Browse the repository at this point in the history
This improves Socket.IO configuration and CORS handling for cross-origin use cases,
particularly focusing on scenarios where Chainlit is embedded in other websites (copilot, closes #1279) or deployed behind load balancers:

- Make Socket.IO client transports configurable via `project.transports` setting
- Move connection parameters from headers to Socket.IO auth object for better
  websocket compatibility
- Update CORS handling to properly support authenticated cross-origin requests. Closes #1359.
- Remove unnecessary task start/end events from window_message handler

BREAKING CHANGE: For cross-origin deployments, wildcard "*" is no longer supported
in allow_origins when using authenticated connections. Specific origins must be
listed in the project config.

Possibly resolves #719, #1339, #1369, #1407, #1492, #1507.
  • Loading branch information
willydouhard authored Dec 12, 2024
1 parent 8b2d4ba commit d531633
Show file tree
Hide file tree
Showing 7 changed files with 36 additions and 46 deletions.
2 changes: 2 additions & 0 deletions backend/chainlit/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,8 @@ class CodeSettings:
@dataclass()
class ProjectSettings(DataClassJsonMixin):
allow_origins: List[str] = Field(default_factory=lambda: ["*"])
# Socket.io client transports option
transports: Optional[List[str]] = None
enable_telemetry: bool = True
# List of environment variables to be provided by each user to use the app. If empty, no environment variables will be asked to the user.
user_env: Optional[List[str]] = None
Expand Down
5 changes: 4 additions & 1 deletion backend/chainlit/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,10 @@ def get_html_template():
<meta property="og:url" content="{url}">
<meta property="og:root_path" content="{ROOT_PATH}">"""

js = f"""<script>{f"window.theme = {json.dumps(config.ui.theme.to_dict())}; " if config.ui.theme else ""}</script>"""
js = f"""<script>
{f"window.theme = {json.dumps(config.ui.theme.to_dict())}; " if config.ui.theme else ""}
{f"window.transports = {json.dumps(config.project.transports)}; " if config.project.transports else "undefined"}
</script>"""

css = None
if config.ui.custom_css:
Expand Down
45 changes: 12 additions & 33 deletions backend/chainlit/socket.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import asyncio
import json
import time
import uuid
from typing import Any, Dict, Literal
from urllib.parse import unquote

Expand Down Expand Up @@ -77,24 +76,8 @@ def load_user_env(user_env):
return user_env


def build_anon_user_identifier(environ):
scope = environ.get("asgi.scope", {})
client_ip, _ = scope.get("client")
ip = environ.get("HTTP_X_FORWARDED_FOR", client_ip)

try:
headers = scope.get("headers", {})
user_agent = next(
(v.decode("utf-8") for k, v in headers if k.decode("utf-8") == "user-agent")
)
return str(uuid.uuid5(uuid.NAMESPACE_DNS, user_agent + ip))

except StopIteration:
return str(uuid.uuid5(uuid.NAMESPACE_DNS, ip))


@sio.on("connect")
async def connect(sid, environ):
async def connect(sid, environ, auth):
if (
not config.code.on_chat_start
and not config.code.on_message
Expand All @@ -110,8 +93,8 @@ async def connect(sid, environ):
try:
# Check if the authentication is required
if login_required:
authorization_header = environ.get("HTTP_AUTHORIZATION")
token = authorization_header.split(" ")[1] if authorization_header else None
token = auth.get("token")
token = token.split(" ")[1] if token else None
user = await get_current_user(token=token)
except Exception:
logger.info("Authentication failed")
Expand All @@ -125,16 +108,16 @@ def emit_fn(event, data):
def emit_call_fn(event: Literal["ask", "call_fn"], data, timeout):
return sio.call(event, data, timeout=timeout, to=sid)

session_id = environ.get("HTTP_X_CHAINLIT_SESSION_ID")
session_id = auth.get("sessionId")
if restore_existing_session(sid, session_id, emit_fn, emit_call_fn):
return True

user_env_string = environ.get("HTTP_USER_ENV")
user_env_string = auth.get("userEnv")
user_env = load_user_env(user_env_string)

client_type = environ.get("HTTP_X_CHAINLIT_CLIENT_TYPE")
client_type = auth.get("clientType")
http_referer = environ.get("HTTP_REFERER")
url_encoded_chat_profile = environ.get("HTTP_X_CHAINLIT_CHAT_PROFILE")
url_encoded_chat_profile = auth.get("chatProfile")
chat_profile = (
unquote(url_encoded_chat_profile) if url_encoded_chat_profile else None
)
Expand All @@ -149,7 +132,7 @@ def emit_call_fn(event: Literal["ask", "call_fn"], data, timeout):
user=user,
token=token,
chat_profile=chat_profile,
thread_id=environ.get("HTTP_X_CHAINLIT_THREAD_ID"),
thread_id=auth.get("threadId"),
languages=environ.get("HTTP_ACCEPT_LANGUAGE"),
http_referer=http_referer,
)
Expand All @@ -162,13 +145,13 @@ def emit_call_fn(event: Literal["ask", "call_fn"], data, timeout):
async def connection_successful(sid):
context = init_ws_context(sid)

if context.session.restored:
return

await context.emitter.task_end()
await context.emitter.clear("clear_ask")
await context.emitter.clear("clear_call_fn")

if context.session.restored:
return

if context.session.thread_id_to_resume and config.code.on_chat_resume:
thread = await resume_thread(context.session)
if thread:
Expand Down Expand Up @@ -312,17 +295,13 @@ async def message(sid, payload: MessagePayload):
async def window_message(sid, data):
"""Handle a message send by the host window."""
session = WebsocketSession.require(sid)
context = init_ws_context(session)

await context.emitter.task_start()
init_ws_context(session)

if config.code.on_window_message:
try:
await config.code.on_window_message(data)
except asyncio.CancelledError:
pass
finally:
await context.emitter.task_end()


@sio.on("audio_start")
Expand Down
2 changes: 1 addition & 1 deletion cypress/e2e/copilot/.chainlit/config.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ session_timeout = 3600
cache = false

# Authorized origins
allow_origins = ["*"]
allow_origins = ["http://127.0.0.1:8000"]

# Follow symlink for asset mount (see https://github.com/Chainlit/chainlit/issues/317)
# follow_symlink = false
Expand Down
2 changes: 2 additions & 0 deletions frontend/src/App.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ declare global {
light?: ThemOverride;
dark?: ThemOverride;
};
transports?: string[]
}
}

Expand Down Expand Up @@ -99,6 +100,7 @@ function App() {
return;
} else {
connect({
transports: window.transports,
userEnv,
accessToken
});
Expand Down
1 change: 1 addition & 0 deletions libs/copilot/src/chat/index.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ export default function ChatWrapper() {
useEffect(() => {
if (session?.socket?.connected) return;
connect({
transports: window.transports,
userEnv: {},
accessToken: `Bearer ${accessToken}`
});
Expand Down
25 changes: 14 additions & 11 deletions libs/react-client/src/useChatSession.ts
Original file line number Diff line number Diff line change
Expand Up @@ -78,16 +78,18 @@ const useChatSession = () => {
// Use currentThreadId as thread id in websocket header
useEffect(() => {
if (session?.socket) {
session.socket.io.opts.extraHeaders!['X-Chainlit-Thread-Id'] =
session.socket.auth["threadId"] =
currentThreadId || '';
}
}, [currentThreadId]);

const _connect = useCallback(
({
transports,
userEnv,
accessToken
}: {
transports?: string[]
userEnv: Record<string, string>;
accessToken?: string;
}) => {
Expand All @@ -100,16 +102,17 @@ const useChatSession = () => {

const socket = io(uri, {
path,
extraHeaders: {
Authorization: accessToken || '',
'X-Chainlit-Client-Type': client.type,
'X-Chainlit-Session-Id': sessionId,
'X-Chainlit-Thread-Id': idToResume || '',
'user-env': JSON.stringify(userEnv),
'X-Chainlit-Chat-Profile': chatProfile
? encodeURIComponent(chatProfile)
: ''
}
withCredentials: true,
transports,
auth: {
token: accessToken,
clientType: client.type,
sessionId,
threadId: idToResume || '',
userEnv: JSON.stringify(userEnv),
chatProfile: chatProfile ? encodeURIComponent(chatProfile) : ''
}

});
setSession((old) => {
old?.socket?.removeAllListeners();
Expand Down

0 comments on commit d531633

Please sign in to comment.