Skip to content

Commit

Permalink
prevent the app to crash if the data layer is not reachable (#644)
Browse files Browse the repository at this point in the history
* prevent the app to crash if the data layer is not reachable

* make the app still usable if auth is enabled and data layer is down
  • Loading branch information
willydouhard authored Jan 10, 2024
1 parent 38528dc commit a63501f
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 10 deletions.
2 changes: 1 addition & 1 deletion backend/chainlit/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ async def authenticate_user(token: str = Depends(reuseable_oauth)):
try:
persisted_user = await data_layer.get_user(user.identifier)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
return user
if persisted_user == None:
raise HTTPException(status_code=401, detail="User does not exist")

Expand Down
14 changes: 9 additions & 5 deletions backend/chainlit/emitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from chainlit.data import get_data_layer
from chainlit.element import Element, File
from chainlit.logger import logger
from chainlit.message import Message
from chainlit.session import BaseSession, WebsocketSession
from chainlit.step import StepDict
Expand Down Expand Up @@ -172,11 +173,14 @@ async def flush_thread_queues(self, interaction: str):
user_id = self.session.user.id
else:
user_id = None
await data_layer.update_thread(
thread_id=self.session.thread_id,
user_id=user_id,
metadata={"name": interaction},
)
try:
await data_layer.update_thread(
thread_id=self.session.thread_id,
user_id=user_id,
metadata={"name": interaction},
)
except Exception as e:
logger.error(f"Error updating thread: {e}")
await self.session.flush_method_queue()

async def init_thread(self, interaction: str):
Expand Down
17 changes: 14 additions & 3 deletions backend/chainlit/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,11 @@ async def login(form_data: OAuth2PasswordRequestForm = Depends()):
)
access_token = create_jwt(user)
if data_layer := get_data_layer():
await data_layer.create_user(user)
try:
await data_layer.create_user(user)
except Exception as e:
logger.error(f"Error creating user: {e}")

return {
"access_token": access_token,
"token_type": "bearer",
Expand Down Expand Up @@ -298,7 +302,11 @@ async def header_auth(request: Request):

access_token = create_jwt(user)
if data_layer := get_data_layer():
await data_layer.create_user(user)
try:
await data_layer.create_user(user)
except Exception as e:
logger.error(f"Error creating user: {e}")

return {
"access_token": access_token,
"token_type": "bearer",
Expand Down Expand Up @@ -406,7 +414,10 @@ async def oauth_callback(
access_token = create_jwt(user)

if data_layer := get_data_layer():
await data_layer.create_user(user)
try:
await data_layer.create_user(user)
except Exception as e:
logger.error(f"Error creating user: {e}")

params = urllib.parse.urlencode(
{
Expand Down
6 changes: 5 additions & 1 deletion backend/chainlit/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import TYPE_CHECKING, Any, Callable, Deque, Dict, List, Optional, Union

import aiofiles
from chainlit.logger import logger

if TYPE_CHECKING:
from chainlit.message import Message
Expand Down Expand Up @@ -242,7 +243,10 @@ async def flush_method_queue(self):
for method_name, queue in self.thread_queues.items():
while queue:
method, self, args, kwargs = queue.popleft()
await method(self, *args, **kwargs)
try:
await method(self, *args, **kwargs)
except Exception as e:
logger.error(f"Error while flushing {method_name}: {e}")

@classmethod
def get(cls, socket_id: str):
Expand Down

0 comments on commit a63501f

Please sign in to comment.