diff --git a/src/api/qualicharge/api/v1/routers/static.py b/src/api/qualicharge/api/v1/routers/static.py index f5c7832a..ffa0d7b1 100644 --- a/src/api/qualicharge/api/v1/routers/static.py +++ b/src/api/qualicharge/api/v1/routers/static.py @@ -141,14 +141,15 @@ async def update( try: update = update_statique(session, id_pdc_itinerance, statique) except IntegrityError as err: + session.rollback() raise HTTPException( status_code=status.HTTP_406_NOT_ACCEPTABLE, detail="id_pdc_itinerance does not match request body", ) from err except ObjectDoesNotExist as err: + session.rollback() raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="Statique to update does not exist", + status_code=status.HTTP_404_NOT_FOUND, detail=str(err) ) from err return update @@ -158,7 +159,15 @@ async def create( statique: Statique, session: Session = Depends(get_session) ) -> StatiqueItemsCreatedResponse: """Create a statique item.""" - return StatiqueItemsCreatedResponse(items=[save_statique(session, statique)]) + try: + db_statique = save_statique(session, statique) + except ObjectDoesNotExist as err: + session.rollback() + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail=str(err) + ) from err + + return StatiqueItemsCreatedResponse(items=[db_statique]) @router.post("/bulk", status_code=status.HTTP_201_CREATED) @@ -166,7 +175,14 @@ async def bulk( statiques: BulkStatiqueList, session: Session = Depends(get_session) ) -> StatiqueItemsCreatedResponse: """Create a set of statique items.""" - statiques = [statique for statique in save_statiques(session, statiques)] + try: + statiques = [statique for statique in save_statiques(session, statiques)] + except ObjectDoesNotExist as err: + session.rollback() + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail=str(err) + ) from err + if not len(statiques): raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, diff --git a/src/api/qualicharge/db.py b/src/api/qualicharge/db.py index 4a520136..c0193c21 100644 --- a/src/api/qualicharge/db.py +++ b/src/api/qualicharge/db.py @@ -1,7 +1,7 @@ """QualiCharge database connection.""" import logging -from typing import Optional +from typing import Generator, Optional from pydantic import PostgresDsn from sqlalchemy import Engine as SAEngine @@ -86,16 +86,19 @@ def get_engine() -> SAEngine: return Engine().get_engine(url=settings.DATABASE_URL, echo=settings.DEBUG) -def get_session() -> SMSession: +def get_session() -> Generator[SMSession, None, None]: """Get database session.""" session = Session().get_session(get_engine()) logger.debug("Getting session %s", session) - return session + try: + yield session + finally: + session.close() def is_alive() -> bool: """Check if database connection is alive.""" - session = get_session() + session = next(get_session()) try: session.execute(text("SELECT 1 as is_alive")) return True