-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
WIP: add static to schema (and reverse) utilities
- Loading branch information
Showing
4 changed files
with
231 additions
and
5 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,11 +1,13 @@ | ||
"""QualiCharge models utilities.""" | ||
|
||
from typing import Type | ||
|
||
from sqlmodel import SQLModel | ||
|
||
|
||
class ModelSchemaMixin: | ||
"""A mixin that adds Pydantic to SQLModel helpers.""" | ||
|
||
def get_fields_for_schema(self, schema: SQLModel): | ||
def get_fields_for_schema(self, schema: Type[SQLModel]): | ||
"""Get input schema-related fields/values as a dict.""" | ||
return self.model_dump(include=set(schema.model_fields.keys())) # type: ignore[attr-defined] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,13 +1,165 @@ | ||
"""QualiCharge schemas utilities.""" | ||
|
||
from sqlmodel import SQLModel | ||
import logging | ||
from typing import Optional, Set, Tuple, Type | ||
|
||
from sqlalchemy.exc import MultipleResultsFound | ||
from sqlmodel import Session, SQLModel, select | ||
|
||
from ..exceptions import DatabaseQueryException | ||
from ..models.static import Statique | ||
from .static import ( | ||
Amenageur, | ||
Enseigne, | ||
Localisation, | ||
Operateur, | ||
PointDeCharge, | ||
Station, | ||
) | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
DB_TO_STATIC_EXCLUDED_FIELDS = {"id", "created_at", "updated_at"} | ||
|
||
|
||
def get_or_create( | ||
session: Session, entry: SQLModel, fields: Optional[Set] = None | ||
) -> Tuple[bool, SQLModel]: | ||
"""Get or create schema instance. | ||
Args: | ||
session: SQLModel session | ||
entry: SQLModel schema instance | ||
fields: entry fields used in database query to select target entry. | ||
Defaults to None (use all fields). | ||
Returns: | ||
A (bool, entry) tuple. The boolean states on the entry creation. | ||
Raises: | ||
DatabaseQueryException: Found multiple entries given input fields. | ||
""" | ||
# Try to get entry from selected fields | ||
statement = select(entry.__class__).filter_by(**entry.model_dump(include=fields)) | ||
logger.debug(f"{statement=}") | ||
try: | ||
db_entry = session.exec(statement).one_or_none() | ||
except MultipleResultsFound as err: | ||
raise DatabaseQueryException( | ||
f"Multiple results found for input fields {fields}" | ||
) from err | ||
|
||
if db_entry is not None: | ||
logger.debug(f"Found database entry with id: {db_entry.id}") # type: ignore[attr-defined] | ||
return False, db_entry | ||
|
||
# Create a new entry | ||
session.add(entry) | ||
session.commit() | ||
session.refresh(entry) | ||
return True, entry | ||
|
||
|
||
def get_or_create(schema: SQLModel, **kwargs): | ||
"""Get or create schema instance.""" | ||
def save_schema_from_statique( | ||
session: Session, | ||
schema_klass: Type[SQLModel], | ||
statique: Statique, | ||
fields: Optional[Set] = None, | ||
update: bool = True, | ||
) -> Tuple[bool, SQLModel]: | ||
"""Save schema to database from Statique instance. | ||
Args: | ||
session: SQLModel session | ||
schema_klass: SQLModel schema class to save from Statique | ||
statique: input static model definition | ||
fields: entry fields used in database query to select target entry. | ||
Defaults to None (use all fields). | ||
update: set to True (default) to update entry | ||
def save_statique(statique: Statique): | ||
Returns: | ||
A (bool, entry) tuple. The boolean states on the entry creation. | ||
Raises: | ||
DatabaseQueryException: Found multiple entries given input fields. | ||
""" | ||
# Is this a new entry? | ||
entry = schema_klass(**statique.get_fields_for_schema(schema_klass)) | ||
created, entry_db = get_or_create( | ||
session, | ||
entry, | ||
fields=fields, | ||
) | ||
if created or not update: | ||
return created, entry_db | ||
|
||
# Update | ||
entry_db = entry_db.model_copy( | ||
update=entry.model_dump( | ||
exclude=fields if fields is None else fields | DB_TO_STATIC_EXCLUDED_FIELDS | ||
) | ||
) | ||
session.add(entry_db) | ||
session.commit() | ||
return created, entry_db | ||
|
||
|
||
def save_statique( | ||
session: Session, statique: Statique, update: bool = True | ||
) -> Statique: | ||
"""Save Statique instance to database.""" | ||
# Core schemas | ||
_, pdc = save_schema_from_statique( | ||
session, PointDeCharge, statique, fields={"id_pdc_itinerance"}, update=update | ||
) | ||
_, station = save_schema_from_statique( | ||
session, Station, statique, fields={"id_station_itinerance"}, update=update | ||
) | ||
_, amenageur = save_schema_from_statique( | ||
session, Amenageur, statique, update=update | ||
) | ||
_, operateur = save_schema_from_statique( | ||
session, Operateur, statique, update=update | ||
) | ||
_, enseigne = save_schema_from_statique(session, Enseigne, statique, update=update) | ||
_, localisation = save_schema_from_statique( | ||
session, Localisation, statique, update=update | ||
) | ||
|
||
# Relationships | ||
pdc.station_id = station.id # type: ignore[attr-defined] | ||
station.amenageur_id = amenageur.id # type: ignore[attr-defined] | ||
station.operateur_id = operateur.id # type: ignore[attr-defined] | ||
station.enseigne_id = enseigne.id # type: ignore[attr-defined] | ||
station.localisation_id = localisation.id # type: ignore[attr-defined] | ||
|
||
session.add(pdc) | ||
session.add(station) | ||
session.commit() | ||
|
||
return Statique( | ||
**pdc.model_dump(exclude=DB_TO_STATIC_EXCLUDED_FIELDS), | ||
**station.model_dump(exclude=DB_TO_STATIC_EXCLUDED_FIELDS), | ||
**operateur.model_dump(exclude=DB_TO_STATIC_EXCLUDED_FIELDS), | ||
**amenageur.model_dump(exclude=DB_TO_STATIC_EXCLUDED_FIELDS), | ||
**enseigne.model_dump(exclude=DB_TO_STATIC_EXCLUDED_FIELDS), | ||
**localisation.model_dump(exclude=DB_TO_STATIC_EXCLUDED_FIELDS), | ||
) | ||
|
||
|
||
def build_statique(session: Session, id_pdc_itinerance: str) -> Statique: | ||
"""Build Statique instance from database.""" | ||
pdc = session.exec( | ||
select(PointDeCharge).where( | ||
PointDeCharge.id_pdc_itinerance == id_pdc_itinerance | ||
) | ||
).one() | ||
|
||
return Statique( | ||
**pdc.model_dump(exclude=DB_TO_STATIC_EXCLUDED_FIELDS), | ||
**pdc.station.model_dump(exclude=DB_TO_STATIC_EXCLUDED_FIELDS), | ||
**pdc.station.operateur.model_dump(exclude=DB_TO_STATIC_EXCLUDED_FIELDS), | ||
**pdc.station.amenageur.model_dump(exclude=DB_TO_STATIC_EXCLUDED_FIELDS), | ||
**pdc.station.enseigne.model_dump(exclude=DB_TO_STATIC_EXCLUDED_FIELDS), | ||
**pdc.station.localisation.model_dump(exclude=DB_TO_STATIC_EXCLUDED_FIELDS), | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
"""QualiCharge schemas utilities tests.""" | ||
|
||
import pytest | ||
|
||
from qualicharge.exceptions import DatabaseQueryException | ||
from qualicharge.factories.static import AmenageurFactory | ||
from qualicharge.schemas.utils import get_or_create | ||
|
||
|
||
def test_get_or_create(db_session): | ||
"""Test get_or_create utility.""" | ||
amenageur = AmenageurFactory.build() | ||
|
||
# Create case | ||
created, db_entry = get_or_create(db_session, amenageur) | ||
assert created is True | ||
assert db_entry.id is not None | ||
assert db_entry.id == amenageur.id | ||
assert db_entry.nom_amenageur == amenageur.nom_amenageur | ||
assert db_entry.siren_amenageur == amenageur.siren_amenageur | ||
assert db_entry.contact_amenageur == amenageur.contact_amenageur | ||
|
||
# Get case | ||
created, db_entry = get_or_create(db_session, amenageur) | ||
assert created is False | ||
assert db_entry.id == amenageur.id | ||
assert db_entry.nom_amenageur == amenageur.nom_amenageur | ||
assert db_entry.siren_amenageur == amenageur.siren_amenageur | ||
assert db_entry.contact_amenageur == amenageur.contact_amenageur | ||
|
||
|
||
def test_get_or_create_with_explicit_fields_selection(db_session): | ||
"""Test get_or_create utility using explicit fields selection.""" | ||
amenageur = AmenageurFactory.build() | ||
fields = { | ||
"contact_amenageur", | ||
} | ||
|
||
# Create case | ||
created, db_entry = get_or_create(db_session, amenageur, fields=fields) | ||
assert created is True | ||
assert db_entry.id is not None | ||
assert db_entry.id == amenageur.id | ||
assert db_entry.nom_amenageur == amenageur.nom_amenageur | ||
assert db_entry.siren_amenageur == amenageur.siren_amenageur | ||
assert db_entry.contact_amenageur == amenageur.contact_amenageur | ||
|
||
# Get case | ||
created, db_entry = get_or_create(db_session, amenageur, fields=fields) | ||
assert created is False | ||
assert db_entry.id == amenageur.id | ||
assert db_entry.nom_amenageur == amenageur.nom_amenageur | ||
assert db_entry.siren_amenageur == amenageur.siren_amenageur | ||
assert db_entry.contact_amenageur == amenageur.contact_amenageur | ||
|
||
|
||
def test_get_or_create_with_multiple_existing_entries(db_session): | ||
"""Test get_or_create utility when multiple entries exist.""" | ||
AmenageurFactory.__session__ = db_session | ||
|
||
nom_amenageur = "ACME Inc." | ||
amenageurs = AmenageurFactory.create_batch_sync(2, nom_amenageur=nom_amenageur) | ||
|
||
with pytest.raises( | ||
DatabaseQueryException, | ||
match="Multiple results found for input fields {'nom_amenageur'}", | ||
): | ||
get_or_create(db_session, amenageurs[0], fields={"nom_amenageur"}) |