Skip to content

Commit

Permalink
WIP: add static to schema (and reverse) utilities
Browse files Browse the repository at this point in the history
  • Loading branch information
jmaupetit committed Apr 24, 2024
1 parent 270075a commit 91b8a5d
Show file tree
Hide file tree
Showing 4 changed files with 231 additions and 5 deletions.
4 changes: 4 additions & 0 deletions src/api/qualicharge/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,7 @@ class OIDCAuthenticationError(QualiChargeExceptionMixin, Exception):

class OIDCProviderException(QualiChargeExceptionMixin, Exception):
"""Raised when the OIDC provider does not behave as expected."""


class DatabaseQueryException(QualiChargeExceptionMixin, Exception):
"""Raised when a database query does not provide expected results."""
4 changes: 3 additions & 1 deletion src/api/qualicharge/models/utils.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
"""QualiCharge models utilities."""

from typing import Type

from sqlmodel import SQLModel


class ModelSchemaMixin:
"""A mixin that adds Pydantic to SQLModel helpers."""

def get_fields_for_schema(self, schema: SQLModel):
def get_fields_for_schema(self, schema: Type[SQLModel]):
"""Get input schema-related fields/values as a dict."""
return self.model_dump(include=set(schema.model_fields.keys())) # type: ignore[attr-defined]
160 changes: 156 additions & 4 deletions src/api/qualicharge/schemas/utils.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,165 @@
"""QualiCharge schemas utilities."""

from sqlmodel import SQLModel
import logging
from typing import Optional, Set, Tuple, Type

from sqlalchemy.exc import MultipleResultsFound
from sqlmodel import Session, SQLModel, select

from ..exceptions import DatabaseQueryException
from ..models.static import Statique
from .static import (
Amenageur,
Enseigne,
Localisation,
Operateur,
PointDeCharge,
Station,
)

logger = logging.getLogger(__name__)

DB_TO_STATIC_EXCLUDED_FIELDS = {"id", "created_at", "updated_at"}


def get_or_create(
session: Session, entry: SQLModel, fields: Optional[Set] = None
) -> Tuple[bool, SQLModel]:
"""Get or create schema instance.
Args:
session: SQLModel session
entry: SQLModel schema instance
fields: entry fields used in database query to select target entry.
Defaults to None (use all fields).
Returns:
A (bool, entry) tuple. The boolean states on the entry creation.
Raises:
DatabaseQueryException: Found multiple entries given input fields.
"""
# Try to get entry from selected fields
statement = select(entry.__class__).filter_by(**entry.model_dump(include=fields))
logger.debug(f"{statement=}")
try:
db_entry = session.exec(statement).one_or_none()
except MultipleResultsFound as err:
raise DatabaseQueryException(
f"Multiple results found for input fields {fields}"
) from err

if db_entry is not None:
logger.debug(f"Found database entry with id: {db_entry.id}") # type: ignore[attr-defined]
return False, db_entry

# Create a new entry
session.add(entry)
session.commit()
session.refresh(entry)
return True, entry


def get_or_create(schema: SQLModel, **kwargs):
"""Get or create schema instance."""
def save_schema_from_statique(
session: Session,
schema_klass: Type[SQLModel],
statique: Statique,
fields: Optional[Set] = None,
update: bool = True,
) -> Tuple[bool, SQLModel]:
"""Save schema to database from Statique instance.
Args:
session: SQLModel session
schema_klass: SQLModel schema class to save from Statique
statique: input static model definition
fields: entry fields used in database query to select target entry.
Defaults to None (use all fields).
update: set to True (default) to update entry
def save_statique(statique: Statique):
Returns:
A (bool, entry) tuple. The boolean states on the entry creation.
Raises:
DatabaseQueryException: Found multiple entries given input fields.
"""
# Is this a new entry?
entry = schema_klass(**statique.get_fields_for_schema(schema_klass))
created, entry_db = get_or_create(
session,
entry,
fields=fields,
)
if created or not update:
return created, entry_db

# Update
entry_db = entry_db.model_copy(
update=entry.model_dump(
exclude=fields if fields is None else fields | DB_TO_STATIC_EXCLUDED_FIELDS
)
)
session.add(entry_db)
session.commit()
return created, entry_db


def save_statique(
session: Session, statique: Statique, update: bool = True
) -> Statique:
"""Save Statique instance to database."""
# Core schemas
_, pdc = save_schema_from_statique(
session, PointDeCharge, statique, fields={"id_pdc_itinerance"}, update=update
)
_, station = save_schema_from_statique(
session, Station, statique, fields={"id_station_itinerance"}, update=update
)
_, amenageur = save_schema_from_statique(
session, Amenageur, statique, update=update
)
_, operateur = save_schema_from_statique(
session, Operateur, statique, update=update
)
_, enseigne = save_schema_from_statique(session, Enseigne, statique, update=update)
_, localisation = save_schema_from_statique(
session, Localisation, statique, update=update
)

# Relationships
pdc.station_id = station.id # type: ignore[attr-defined]
station.amenageur_id = amenageur.id # type: ignore[attr-defined]
station.operateur_id = operateur.id # type: ignore[attr-defined]
station.enseigne_id = enseigne.id # type: ignore[attr-defined]
station.localisation_id = localisation.id # type: ignore[attr-defined]

session.add(pdc)
session.add(station)
session.commit()

return Statique(
**pdc.model_dump(exclude=DB_TO_STATIC_EXCLUDED_FIELDS),
**station.model_dump(exclude=DB_TO_STATIC_EXCLUDED_FIELDS),
**operateur.model_dump(exclude=DB_TO_STATIC_EXCLUDED_FIELDS),
**amenageur.model_dump(exclude=DB_TO_STATIC_EXCLUDED_FIELDS),
**enseigne.model_dump(exclude=DB_TO_STATIC_EXCLUDED_FIELDS),
**localisation.model_dump(exclude=DB_TO_STATIC_EXCLUDED_FIELDS),
)


def build_statique(session: Session, id_pdc_itinerance: str) -> Statique:
"""Build Statique instance from database."""
pdc = session.exec(
select(PointDeCharge).where(
PointDeCharge.id_pdc_itinerance == id_pdc_itinerance
)
).one()

return Statique(
**pdc.model_dump(exclude=DB_TO_STATIC_EXCLUDED_FIELDS),
**pdc.station.model_dump(exclude=DB_TO_STATIC_EXCLUDED_FIELDS),
**pdc.station.operateur.model_dump(exclude=DB_TO_STATIC_EXCLUDED_FIELDS),
**pdc.station.amenageur.model_dump(exclude=DB_TO_STATIC_EXCLUDED_FIELDS),
**pdc.station.enseigne.model_dump(exclude=DB_TO_STATIC_EXCLUDED_FIELDS),
**pdc.station.localisation.model_dump(exclude=DB_TO_STATIC_EXCLUDED_FIELDS),
)
68 changes: 68 additions & 0 deletions src/api/tests/schemas/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
"""QualiCharge schemas utilities tests."""

import pytest

from qualicharge.exceptions import DatabaseQueryException
from qualicharge.factories.static import AmenageurFactory
from qualicharge.schemas.utils import get_or_create


def test_get_or_create(db_session):
"""Test get_or_create utility."""
amenageur = AmenageurFactory.build()

# Create case
created, db_entry = get_or_create(db_session, amenageur)
assert created is True
assert db_entry.id is not None
assert db_entry.id == amenageur.id
assert db_entry.nom_amenageur == amenageur.nom_amenageur
assert db_entry.siren_amenageur == amenageur.siren_amenageur
assert db_entry.contact_amenageur == amenageur.contact_amenageur

# Get case
created, db_entry = get_or_create(db_session, amenageur)
assert created is False
assert db_entry.id == amenageur.id
assert db_entry.nom_amenageur == amenageur.nom_amenageur
assert db_entry.siren_amenageur == amenageur.siren_amenageur
assert db_entry.contact_amenageur == amenageur.contact_amenageur


def test_get_or_create_with_explicit_fields_selection(db_session):
"""Test get_or_create utility using explicit fields selection."""
amenageur = AmenageurFactory.build()
fields = {
"contact_amenageur",
}

# Create case
created, db_entry = get_or_create(db_session, amenageur, fields=fields)
assert created is True
assert db_entry.id is not None
assert db_entry.id == amenageur.id
assert db_entry.nom_amenageur == amenageur.nom_amenageur
assert db_entry.siren_amenageur == amenageur.siren_amenageur
assert db_entry.contact_amenageur == amenageur.contact_amenageur

# Get case
created, db_entry = get_or_create(db_session, amenageur, fields=fields)
assert created is False
assert db_entry.id == amenageur.id
assert db_entry.nom_amenageur == amenageur.nom_amenageur
assert db_entry.siren_amenageur == amenageur.siren_amenageur
assert db_entry.contact_amenageur == amenageur.contact_amenageur


def test_get_or_create_with_multiple_existing_entries(db_session):
"""Test get_or_create utility when multiple entries exist."""
AmenageurFactory.__session__ = db_session

nom_amenageur = "ACME Inc."
amenageurs = AmenageurFactory.create_batch_sync(2, nom_amenageur=nom_amenageur)

with pytest.raises(
DatabaseQueryException,
match="Multiple results found for input fields {'nom_amenageur'}",
):
get_or_create(db_session, amenageurs[0], fields={"nom_amenageur"})

0 comments on commit 91b8a5d

Please sign in to comment.