Skip to content

Commit

Permalink
track user sessions (#620)
Browse files Browse the repository at this point in the history
* track user sessions

* interactive -> is_interactive

* update sdk
  • Loading branch information
willydouhard authored Jan 3, 2024
1 parent f2b396a commit e1948a7
Show file tree
Hide file tree
Showing 3 changed files with 99 additions and 18 deletions.
44 changes: 44 additions & 0 deletions backend/chainlit/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,23 @@ async def update_thread(
):
pass

async def create_user_session(
self,
id: str,
started_at: str,
anon_user_id: str,
user_id: Optional[str],
) -> Dict:
return {}

async def update_user_session(
self, id: str, is_interactive: bool, ended_at: Optional[str]
) -> Dict:
return {}

async def delete_user_session(self, id: str) -> bool:
return True


class ChainlitDataLayer:
def __init__(
Expand Down Expand Up @@ -377,6 +394,33 @@ async def update_thread(
tags=tags,
)

async def create_user_session(
self,
id: str,
started_at: str,
anon_user_id: str,
user_id: Optional[str],
) -> Dict:
session = await self.client.api.create_user_session(
id=id,
started_at=started_at,
participant_identifier=user_id,
anon_participant_identifier=anon_user_id,
)
return session

async def update_user_session(
self, id: str, is_interactive: bool, ended_at: Optional[str]
) -> Dict:
session = await self.client.api.update_user_session(
id=id, is_interactive=is_interactive, ended_at=ended_at
)
return session

async def delete_user_session(self, id: str) -> bool:
await self.client.api.delete_user_session(id=id)
return True


if api_key := os.environ.get("CHAINLIT_API_KEY"):
chainlit_server = os.environ.get("CHAINLIT_SERVER")
Expand Down
69 changes: 53 additions & 16 deletions backend/chainlit/socket.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import asyncio
import json
import uuid
from datetime import datetime
from typing import Any, Dict

from chainlit.action import Action
Expand Down Expand Up @@ -73,6 +75,22 @@ def load_user_env(user_env):
return user_env


def build_anon_user_identifier(environ):
scope = environ.get("asgi.scope", {})
client_ip, _ = scope.get("client")
ip = environ.get("HTTP_X_FORWARDED_FOR", client_ip)

try:
headers = scope.get("headers", {})
user_agent = next(
(v.decode("utf-8") for k, v in headers if k.decode("utf-8") == "user-agent")
)
return str(uuid.uuid5(uuid.NAMESPACE_DNS, user_agent + ip))

except StopIteration:
return str(uuid.uuid5(uuid.NAMESPACE_DNS, ip))


@socket.on("connect")
async def connect(sid, environ, auth):
if not config.code.on_chat_start and not config.code.on_message:
Expand All @@ -81,10 +99,12 @@ async def connect(sid, environ, auth):
)

user = None
anon_user_identifier = build_anon_user_identifier(environ)
token = None
login_required = require_login()
try:
# Check if the authentication is required
if require_login():
if login_required:
authorization_header = environ.get("HTTP_AUTHORIZATION")
token = authorization_header.split(" ")[1] if authorization_header else None
user = await get_current_user(token=token)
Expand Down Expand Up @@ -114,7 +134,7 @@ def ask_user_fn(data, timeout):

user_env_string = environ.get("HTTP_USER_ENV")
user_env = load_user_env(user_env_string)
WebsocketSession(
ws_session = WebsocketSession(
id=session_id,
socket_id=sid,
emit=emit_fn,
Expand All @@ -125,6 +145,17 @@ def ask_user_fn(data, timeout):
chat_profile=environ.get("HTTP_X_CHAINLIT_CHAT_PROFILE"),
thread_id=environ.get("HTTP_X_CHAINLIT_THREAD_ID"),
)

if data_layer := get_data_layer():
asyncio.create_task(
data_layer.create_user_session(
id=session_id,
started_at=datetime.utcnow().isoformat(),
anon_user_id=anon_user_identifier,
user_id=user.identifier if user else None,
)
)

trace_event("connection_successful")
return True

Expand Down Expand Up @@ -154,22 +185,22 @@ async def connection_successful(sid):

@socket.on("clear_session")
async def clean_session(sid):
if session := WebsocketSession.get(sid):
if config.code.on_chat_end:
init_ws_context(session)
await config.code.on_chat_end()
# Clean up the user session
if session.id in user_sessions:
user_sessions.pop(session.id)

# Clean up the session
session.delete()
await disconnect(sid, force_clear=True)


@socket.on("disconnect")
async def disconnect(sid):
async def disconnect(sid, force_clear=False):
session = WebsocketSession.get(sid)
if session:
if data_layer := get_data_layer():
asyncio.create_task(
data_layer.update_user_session(
id=session.id,
is_interactive=session.has_first_interaction,
ended_at=datetime.utcnow().isoformat(),
)
)

init_ws_context(session)

if config.code.on_chat_end and session:
Expand All @@ -178,16 +209,22 @@ async def disconnect(sid):
if session and session.thread_id and session.has_first_interaction:
await persist_user_session(session.thread_id, session.to_persistable())

async def disconnect_on_timeout(sid):
await asyncio.sleep(config.project.session_timeout)
def clear():
if session := WebsocketSession.get(sid):
# Clean up the user session
if session.id in user_sessions:
user_sessions.pop(session.id)
# Clean up the session
session.delete()

asyncio.ensure_future(disconnect_on_timeout(sid))
async def clear_on_timeout(sid):
await asyncio.sleep(config.project.session_timeout)
clear()

if force_clear:
clear()
else:
asyncio.ensure_future(clear_on_timeout(sid))


@socket.on("stop")
Expand Down
4 changes: 2 additions & 2 deletions backend/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,12 @@ chainlit = 'chainlit.cli:cli'

[tool.poetry.dependencies]
python = ">=3.8.1,<4.0.0"
chainlit_client = "0.1.0rc7"
httpx = ">=0.23.0,<0.25.0"
chainlit_client = "0.1.0rc10"
dataclasses_json = "^0.5.7"
uvicorn = "^0.23.2"
fastapi = "^0.100"
fastapi-socketio = "^0.0.10"
httpx = ">=0.23.0,<0.25.0"
aiofiles = "^23.1.0"
syncer = "^2.0.3"
asyncer = "^0.0.2"
Expand Down

0 comments on commit e1948a7

Please sign in to comment.