From e70c0be798cd5176e36c30e2e5d66bff21945d9f Mon Sep 17 00:00:00 2001 From: Julien Maupetit Date: Mon, 20 May 2024 10:59:50 +0200 Subject: [PATCH] =?UTF-8?q?=E2=9A=A1=EF=B8=8F(api)=20improve=20database=20?= =?UTF-8?q?session=20usage?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit We should rollback and properly close the session on server restart or database-related failures. --- src/api/qualicharge/api/v1/routers/static.py | 24 ++++++++++++++++---- src/api/qualicharge/db.py | 11 +++++---- 2 files changed, 27 insertions(+), 8 deletions(-) diff --git a/src/api/qualicharge/api/v1/routers/static.py b/src/api/qualicharge/api/v1/routers/static.py index f5c7832a..ffa0d7b1 100644 --- a/src/api/qualicharge/api/v1/routers/static.py +++ b/src/api/qualicharge/api/v1/routers/static.py @@ -141,14 +141,15 @@ async def update( try: update = update_statique(session, id_pdc_itinerance, statique) except IntegrityError as err: + session.rollback() raise HTTPException( status_code=status.HTTP_406_NOT_ACCEPTABLE, detail="id_pdc_itinerance does not match request body", ) from err except ObjectDoesNotExist as err: + session.rollback() raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="Statique to update does not exist", + status_code=status.HTTP_404_NOT_FOUND, detail=str(err) ) from err return update @@ -158,7 +159,15 @@ async def create( statique: Statique, session: Session = Depends(get_session) ) -> StatiqueItemsCreatedResponse: """Create a statique item.""" - return StatiqueItemsCreatedResponse(items=[save_statique(session, statique)]) + try: + db_statique = save_statique(session, statique) + except ObjectDoesNotExist as err: + session.rollback() + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail=str(err) + ) from err + + return StatiqueItemsCreatedResponse(items=[db_statique]) @router.post("/bulk", status_code=status.HTTP_201_CREATED) @@ -166,7 +175,14 @@ async def bulk( statiques: BulkStatiqueList, session: Session = Depends(get_session) ) -> StatiqueItemsCreatedResponse: """Create a set of statique items.""" - statiques = [statique for statique in save_statiques(session, statiques)] + try: + statiques = [statique for statique in save_statiques(session, statiques)] + except ObjectDoesNotExist as err: + session.rollback() + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail=str(err) + ) from err + if not len(statiques): raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, diff --git a/src/api/qualicharge/db.py b/src/api/qualicharge/db.py index 4a520136..c0193c21 100644 --- a/src/api/qualicharge/db.py +++ b/src/api/qualicharge/db.py @@ -1,7 +1,7 @@ """QualiCharge database connection.""" import logging -from typing import Optional +from typing import Generator, Optional from pydantic import PostgresDsn from sqlalchemy import Engine as SAEngine @@ -86,16 +86,19 @@ def get_engine() -> SAEngine: return Engine().get_engine(url=settings.DATABASE_URL, echo=settings.DEBUG) -def get_session() -> SMSession: +def get_session() -> Generator[SMSession, None, None]: """Get database session.""" session = Session().get_session(get_engine()) logger.debug("Getting session %s", session) - return session + try: + yield session + finally: + session.close() def is_alive() -> bool: """Check if database connection is alive.""" - session = get_session() + session = next(get_session()) try: session.execute(text("SELECT 1 as is_alive")) return True