WIP: draft create statiques
jmaupetit committed Apr 29, 2024
1 parent 099654c commit 6604781
Showing 9 changed files with 313 additions and 24 deletions.
4 changes: 3 additions & 1 deletion bin/kc-init
@@ -18,6 +18,7 @@ function kcadm() {
declare realm="qualicharge"
declare user="johndoe"
declare password="pass"
declare email="[email protected]"
declare client_id="api"
declare client_secret="super-secret"

@@ -48,9 +49,10 @@ echo "⚙️ Creating John Doe user…"
kcadm create users \
--target-realm "${realm}" \
--set username="${user}" \
--set email="${email}" \
--set enabled=true \
--output \
--fields id,username
--fields id,username,email

# Set test user password
echo "⚙️ Setting John Doe user password…"
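For context, a minimal sketch (not part of this commit) of authenticating as the John Doe test user created above. The Keycloak base URL is an assumption, and the sketch assumes the "api" client allows the password grant; realm, client, and credential values mirror the variables declared in bin/kc-init.

import httpx

# Assumed local Keycloak endpoint; realm, client and credentials mirror bin/kc-init.
token_url = "http://localhost:8080/realms/qualicharge/protocol/openid-connect/token"
response = httpx.post(
    token_url,
    data={
        "grant_type": "password",
        "client_id": "api",
        "client_secret": "super-secret",
        "username": "johndoe",
        "password": "pass",
    },
)
access_token = response.json()["access_token"]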
4 changes: 3 additions & 1 deletion src/api/qualicharge/api/v1/routers/static.py
@@ -128,4 +128,6 @@ async def bulk(
statiques: BulkStatiqueList, session: Session = Depends(get_session)
) -> StatiqueItemsCreatedResponse:
"""Create a set of statique items."""
return StatiqueItemsCreatedResponse(items=save_statiques(session, statiques))
return StatiqueItemsCreatedResponse(
items=[statique for statique in save_statiques(session, statiques)]
)
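As a usage sketch (not part of this commit; the host, any mounted URL prefix, and the bearer token are assumptions), the bulk endpoint takes a JSON array of Statique objects and, with this change, echoes the created items back in the response, as the updated tests at the bottom of this diff verify.

import httpx

payload = [...]  # placeholder: a list of Statique objects serialized to JSON dicts
access_token = "..."  # bearer token, e.g. obtained as in the kc-init sketch above
response = httpx.post(
    "http://localhost:8000/statique/bulk",  # assumed host; route taken from the router above
    json=payload,
    headers={"Authorization": f"Bearer {access_token}"},
)
body = response.json()
# Per the updated tests: body["message"] == "Statique items created",
# body["size"] == len(payload), and body["items"] lists each created Statique.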
3 changes: 3 additions & 0 deletions src/api/qualicharge/exceptions.py
@@ -23,3 +23,6 @@ class ModelSerializerException(QualiChargeExceptionMixin, Exception):

class DatabaseQueryException(QualiChargeExceptionMixin, Exception):
"""Raised when a database query does not provide expected results."""

class DuplicateEntriesSubmitted(QualiChargeExceptionMixin, Exception):
"""Raised when submitted batch contains duplicated entries."""
New file: Alembic migration 8580168c2cef (add location address unique constraint)
@@ -0,0 +1,30 @@
"""add location address unique constraint
Revision ID: 8580168c2cef
Revises: da896549e09c
Create Date: 2024-04-29 17:23:43.423327
"""
from typing import Sequence, Union

from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql

# revision identifiers, used by Alembic.
revision: str = '8580168c2cef'
down_revision: Union[str, None] = 'da896549e09c'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.create_unique_constraint(None, 'localisation', ['adresse_station'])
# ### end Alembic commands ###


def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_constraint(None, 'localisation', type_='unique')
# ### end Alembic commands ###
New file: Alembic migration da896549e09c (remove location unique together constraint)
@@ -0,0 +1,30 @@
"""remove location unique together constraint
Revision ID: da896549e09c
Revises: b7d33b01adac
Create Date: 2024-04-29 17:20:56.671209
"""
from typing import Sequence, Union

from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql

# revision identifiers, used by Alembic.
revision: str = 'da896549e09c'
down_revision: Union[str, None] = 'b7d33b01adac'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_constraint('localisation_adresse_station_coordonneesXY_key', 'localisation', type_='unique')
# ### end Alembic commands ###


def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.create_unique_constraint('localisation_adresse_station_coordonneesXY_key', 'localisation', ['adresse_station', 'coordonneesXY'])
# ### end Alembic commands ###
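To apply or revert the two migrations above outside the project's usual tooling, a hedged sketch using Alembic's Python API (the location of the Alembic configuration file is an assumption):

from alembic import command
from alembic.config import Config

config = Config("alembic.ini")  # assumed path to the project's Alembic configuration
command.upgrade(config, "head")  # applies da896549e09c, then 8580168c2cef
# command.downgrade(config, "b7d33b01adac")  # would revert both revisions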
36 changes: 31 additions & 5 deletions src/api/qualicharge/schemas/static.py
@@ -50,6 +50,11 @@ class Amenageur(BaseTimestampedSQLModel, table=True):
# Relationships
stations: List["Station"] = Relationship(back_populates="amenageur")

def __eq__(self, other) -> bool:
"""Assess instances equality given uniqueness criterions."""
fields = ("nom_amenageur", "siren_amenageur", "contact_amenageur")
return all(getattr(self, field) == getattr(other, field) for field in fields)


class Operateur(BaseTimestampedSQLModel, table=True):
"""Operateur table."""
@@ -68,6 +73,11 @@ class Operateur(BaseTimestampedSQLModel, table=True):
# Relationships
stations: List["Station"] = Relationship(back_populates="operateur")

def __eq__(self, other) -> bool:
"""Assess instances equality given uniqueness criterions."""
fields = ("nom_operateur", "contact_operateur", "telephone_operateur")
return all(getattr(self, field) == getattr(other, field) for field in fields)


class Enseigne(BaseTimestampedSQLModel, table=True):
"""Enseigne table."""
@@ -80,20 +90,21 @@ class Enseigne(BaseTimestampedSQLModel, table=True):
# Relationships
stations: List["Station"] = Relationship(back_populates="enseigne")

def __eq__(self, other) -> bool:
"""Assess instances equality given uniqueness criterions."""
fields = ("nom_enseigne",)
return all(getattr(self, field) == getattr(other, field) for field in fields)


class Localisation(BaseTimestampedSQLModel, table=True):
"""Localisation table."""

__table_args__ = BaseTimestampedSQLModel.__table_args__ + (
UniqueConstraint("adresse_station", "coordonneesXY"),
)

model_config = SQLModelConfig(
validate_assignment=True, arbitrary_types_allowed=True
)

id: Optional[UUID] = Field(default_factory=lambda: uuid4().hex, primary_key=True)
adresse_station: str
adresse_station: str = Field(unique=True)
code_insee_commune: Optional[str] = Field(regex=r"^([013-9]\d|2[AB1-9])\d{3}$")
coordonneesXY: DataGouvCoordinate = Field(
sa_type=Geometry(
@@ -107,6 +118,11 @@ class Localisation(BaseTimestampedSQLModel, table=True):
# Relationships
stations: List["Station"] = Relationship(back_populates="localisation")

def __eq__(self, other) -> bool:
"""Assess instances equality given uniqueness criterions."""
fields = ("adresse_station",)
return all(getattr(self, field) == getattr(other, field) for field in fields)

@staticmethod
def _coordinates_to_geometry_point(value: Coordinate):
"""Convert coordinate to Geometry point."""
@@ -183,6 +199,11 @@ class Station(BaseTimestampedSQLModel, table=True):

points_de_charge: List["PointDeCharge"] = Relationship(back_populates="station")

def __eq__(self, other) -> bool:
"""Assess instances equality given uniqueness criterions."""
fields = ("id_station_itinerance",)
return all(getattr(self, field) == getattr(other, field) for field in fields)


class PointDeCharge(BaseTimestampedSQLModel, table=True):
"""PointDeCharge table."""
@@ -211,6 +232,11 @@ class PointDeCharge(BaseTimestampedSQLModel, table=True):
observations: Optional[str]
cable_t2_attache: Optional[bool]

def __eq__(self, other) -> bool:
"""Assess instances equality given uniqueness criterions."""
fields = ("id_pdc_itinerance",)
return all(getattr(self, field) == getattr(other, field) for field in fields)

# Relationships
station_id: Optional[UUID] = Field(default=None, foreign_key="station.id")
station: Station = Relationship(back_populates="points_de_charge")
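The __eq__ overrides added above compare instances on their uniqueness criteria only, which is what lets save_statiques (in schemas/utils.py below) deduplicate schema entries with a plain membership test. A small illustrative sketch, with made-up field values:

from qualicharge.schemas.static import Amenageur

a = Amenageur(
    nom_amenageur="ACME Charge",  # made-up values, for illustration only
    siren_amenageur="123456789",
    contact_amenageur="contact@acme.example",
)
b = Amenageur(
    nom_amenageur="ACME Charge",
    siren_amenageur="123456789",
    contact_amenageur="contact@acme.example",
)
assert a == b        # equal on the uniqueness criteria...
assert a is not b    # ...although they are distinct, unsaved instances
assert a in [b]      # the membership test used for deduplication in save_statiques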
113 changes: 103 additions & 10 deletions src/api/qualicharge/schemas/utils.py
@@ -1,12 +1,13 @@
"""QualiCharge schemas utilities."""

from collections import namedtuple
import logging
from typing import Generator, List, Optional, Set, Tuple, Type, Union
from typing import Generator, List, Optional, Set, Tuple, Type

from sqlalchemy.exc import MultipleResultsFound
from sqlmodel import Session, SQLModel, select

from ..exceptions import DatabaseQueryException
from ..exceptions import DatabaseQueryException, DuplicateEntriesSubmitted
from ..models.static import Statique
from .static import (
Amenageur,
@@ -23,7 +24,7 @@


def get_or_create(
session: Session, entry: SQLModel, fields: Optional[Set] = None
session: Session, entry: SQLModel, fields: Optional[Set] = None, commit: bool = True
) -> Tuple[bool, SQLModel]:
"""Get or create schema instance.
@@ -32,6 +33,7 @@ def get_or_create(
entry: SQLModel schema instance
fields: entry fields used in database query to select target entry.
Defaults to None (use all fields).
commit: should we commit transation to database?
Returns:
A (bool, entry) tuple. The boolean states on the entry creation.
@@ -40,8 +42,9 @@ def get_or_create(
DatabaseQueryException: Found multiple entries given input fields.
"""
# Try to get entry from selected fields
statement = select(entry.__class__).filter_by(**entry.model_dump(include=fields))
logger.debug(f"{statement=}")
statement = select(entry.__class__).filter_by(
**entry.model_dump(include=fields, exclude=DB_TO_STATIC_EXCLUDED_FIELDS)
)
try:
db_entry = session.exec(statement).one_or_none()
except MultipleResultsFound as err:
@@ -55,8 +58,9 @@

# Create a new entry
session.add(entry)
session.commit()
session.refresh(entry)
if commit:
session.commit()
session.refresh(entry)
return True, entry
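A short sketch (not part of this commit; session, amenageur and operateur are assumed) of the new commit flag: several get_or_create calls can now share a single transaction and be committed once, which is how save_statiques batches its inserts below.

from qualicharge.schemas.utils import get_or_create

# `session`, `amenageur` and `operateur` are assumed to exist in the calling scope.
_, db_amenageur = get_or_create(session, amenageur, commit=False)
_, db_operateur = get_or_create(session, operateur, commit=False)
session.commit()  # persist both pending inserts in one transaction
session.refresh(db_amenageur)
session.refresh(db_operateur)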


@@ -140,7 +144,9 @@ def save_statique(
session,
Localisation,
statique,
fields={"adresse_station", "coordonneesXY"},
fields={
"adresse_station",
},
update=update,
)

@@ -159,8 +165,90 @@ def save_statique(
return pdc_to_statique(pdc)


def save_statiques(session: Session, statiques: List[Statique]) -> List[Statique]:
def save_statiques(
session: Session, statiques: List[Statique]
) -> Generator[Statique, None, None]:
"""Save Statique instances to database in an efficient way."""
points_de_charge: List[PointDeCharge] = []
stations: List[Station] = []
amenageurs: List[Amenageur] = []
operateurs: List[Operateur] = []
enseignes: List[Enseigne] = []
localisations: List[Localisation] = []

StatiqueSchemasEntryIndex = namedtuple(
"StatiqueSchemasEntryIndex",
["pdc", "station", "amenageur", "operateur", "enseigne", "localisation"],
)

# Collect unique entries list per model and add references to those for each
# statique
statiques_db_refs: List[StatiqueSchemasEntryIndex] = []
for statique in statiques:
pdc = PointDeCharge(**statique.get_fields_for_schema(PointDeCharge))
station = Station(**statique.get_fields_for_schema(Station))
amenageur = Amenageur(**statique.get_fields_for_schema(Amenageur))
operateur = Operateur(**statique.get_fields_for_schema(Operateur))
enseigne = Enseigne(**statique.get_fields_for_schema(Enseigne))
localisation = Localisation(**statique.get_fields_for_schema(Localisation))

indexes = ()
for entry, entries in (
(pdc, points_de_charge),
(station, stations),
(amenageur, amenageurs),
(operateur, operateurs),
(enseigne, enseignes),
(localisation, localisations),
):
if entry not in entries:
entries.append(entry)
indexes += (entries.index(entry),)
statiques_db_refs.append(StatiqueSchemasEntryIndex(*indexes))

if len(points_de_charge) != len(statiques):
raise DuplicateEntriesSubmitted("Found duplicated entries in submitted data")

# Create database entries for each schema
for entries, fields in (
(points_de_charge, {"id_pdc_itinerance"}),
(stations, {"id_station_itinerance"}),
(amenageurs, None),
(operateurs, None),
(enseignes, None),
(localisations, {"adresse_station",},),
):
for idx, entry in enumerate(entries):
_, db_entry = get_or_create(session, entry, fields, commit=False)
entries[idx] = db_entry

# Commit transaction so that all expected table rows are created
session.commit()

# Handle relationships
for (
pdc_idx,
station_idx,
amenageur_idx,
operateur_idx,
enseigne_idx,
localisation_idx,
) in statiques_db_refs:
points_de_charge[pdc_idx].station_id = stations[station_idx].id # type: ignore[attr-defined]
stations[station_idx].amenageur_id = amenageurs[amenageur_idx].id # type: ignore[attr-defined]
stations[station_idx].operateur_id = operateurs[operateur_idx].id # type: ignore[attr-defined]
stations[station_idx].enseigne_id = enseignes[enseigne_idx].id # type: ignore[attr-defined]
stations[station_idx].localisation_id = localisations[localisation_idx].id # type: ignore[attr-defined]

session.add(points_de_charge[pdc_idx])
session.add(stations[station_idx])

# Commit transaction for relationships
session.commit()

for pdc in points_de_charge:
session.refresh(pdc)
yield pdc_to_statique(pdc)


def build_statique(session: Session, id_pdc_itinerance: str) -> Statique:
@@ -179,5 +267,10 @@ def list_statique(
session: Session, offset: int = 0, limit: int = 50
) -> Generator[Statique, None, None]:
"""List Statique entries."""
for pdc in session.exec(select(PointDeCharge).offset(offset).limit(limit)).all():
for pdc in session.exec(
select(PointDeCharge)
.order_by(PointDeCharge.id_pdc_itinerance)
.offset(offset)
.limit(limit)
).all():
yield pdc_to_statique(pdc)
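A usage sketch (not part of this commit; the session and the Statique payloads are assumed) of the reworked helpers: save_statiques is now a generator, so the database writes happen when it is consumed, it raises DuplicateEntriesSubmitted when the same point of charge appears twice in the batch, and list_statique pages through entries ordered by id_pdc_itinerance.

from qualicharge.exceptions import DuplicateEntriesSubmitted
from qualicharge.schemas.utils import list_statique, save_statiques

try:
    # `session` is an open database session and `statiques` a list of Statique
    # models (both assumed here); materializing the generator triggers the writes.
    created = list(save_statiques(session, statiques))
except DuplicateEntriesSubmitted:
    ...  # the batch contained the same id_pdc_itinerance twice; reject it

for statique in list_statique(session, offset=0, limit=10):
    print(statique.id_pdc_itinerance)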
9 changes: 3 additions & 6 deletions src/api/tests/api/v1/routers/test_statique.py
@@ -14,7 +14,7 @@ def test_list(client_auth):
assert response.status_code == status.HTTP_200_OK
json_response = response.json()
expected_size = 0
assert len(json_response) == expected_size
assert len(json_response.get("items")) == expected_size


def test_create(client_auth):
@@ -34,10 +34,8 @@ def test_create(client_auth):

def test_bulk(client_auth):
"""Test the /statique/bulk create endpoint."""
id_pdc_itinerance = "ESZUNE1111ER1"
data = StatiqueFactory.batch(
size=2,
id_pdc_itinerance=id_pdc_itinerance,
)

payload = [json.loads(d.model_dump_json()) for d in data]
@@ -47,8 +45,8 @@ def test_bulk(client_auth):
json_response = response.json()
assert json_response["message"] == "Statique items created"
assert json_response["size"] == len(payload)
assert json_response["items"][0]["id_pdc_itinerance"] == id_pdc_itinerance
assert json_response["items"][1]["id_pdc_itinerance"] == id_pdc_itinerance
assert json_response["items"][0]["id_pdc_itinerance"] == data[0].id_pdc_itinerance
assert json_response["items"][1]["id_pdc_itinerance"] == data[1].id_pdc_itinerance


def test_bulk_with_outbound_sizes(client_auth):
Expand All @@ -65,7 +63,6 @@ def test_bulk_with_outbound_sizes(client_auth):

data = StatiqueFactory.batch(
size=settings.API_BULK_CREATE_MAX_SIZE + 1,
id_pdc_itinerance=id_pdc_itinerance,
)
payload = [json.loads(d.model_dump_json()) for d in data]
response = client_auth.post("/statique/bulk", json=payload)