diff --git a/bin/kc-init b/bin/kc-init index 7d50f49b..82cf9afe 100755 --- a/bin/kc-init +++ b/bin/kc-init @@ -18,6 +18,7 @@ function kcadm() { declare realm="qualicharge" declare user="johndoe" declare password="pass" +declare email="john@doe.com" declare client_id="api" declare client_secret="super-secret" @@ -48,9 +49,10 @@ echo "⚙️ Creating John Doe user…" kcadm create users \ --target-realm "${realm}" \ --set username="${user}" \ + --set email="${email}" \ --set enabled=true \ --output \ - --fields id,username + --fields id,username,email # Set test user password echo "⚙️ Setting John Doe user password…" diff --git a/src/api/qualicharge/api/v1/routers/static.py b/src/api/qualicharge/api/v1/routers/static.py index 91eec803..852ce955 100644 --- a/src/api/qualicharge/api/v1/routers/static.py +++ b/src/api/qualicharge/api/v1/routers/static.py @@ -128,4 +128,6 @@ async def bulk( statiques: BulkStatiqueList, session: Session = Depends(get_session) ) -> StatiqueItemsCreatedResponse: """Create a set of statique items.""" - return StatiqueItemsCreatedResponse(items=save_statiques(session, statiques)) + return StatiqueItemsCreatedResponse( + items=[statique for statique in save_statiques(session, statiques)] + ) diff --git a/src/api/qualicharge/exceptions.py b/src/api/qualicharge/exceptions.py index 433a8234..f87ce509 100644 --- a/src/api/qualicharge/exceptions.py +++ b/src/api/qualicharge/exceptions.py @@ -23,3 +23,6 @@ class ModelSerializerException(QualiChargeExceptionMixin, Exception): class DatabaseQueryException(QualiChargeExceptionMixin, Exception): """Raised when a database query does not provide expected results.""" + +class DuplicateEntriesSubmitted(QualiChargeExceptionMixin, Exception): + """Raised when submitted batch contains duplicated entries.""" diff --git a/src/api/qualicharge/migrations/versions/8580168c2cef_add_location_address_unique_constraint.py b/src/api/qualicharge/migrations/versions/8580168c2cef_add_location_address_unique_constraint.py new file mode 100644 index 00000000..cdd882f6 --- /dev/null +++ b/src/api/qualicharge/migrations/versions/8580168c2cef_add_location_address_unique_constraint.py @@ -0,0 +1,30 @@ +"""add location address unique constraint + +Revision ID: 8580168c2cef +Revises: da896549e09c +Create Date: 2024-04-29 17:23:43.423327 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision: str = '8580168c2cef' +down_revision: Union[str, None] = 'da896549e09c' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.create_unique_constraint(None, 'localisation', ['adresse_station']) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_constraint(None, 'localisation', type_='unique') + # ### end Alembic commands ### diff --git a/src/api/qualicharge/migrations/versions/da896549e09c_remove_location_unique_together_.py b/src/api/qualicharge/migrations/versions/da896549e09c_remove_location_unique_together_.py new file mode 100644 index 00000000..6eca3683 --- /dev/null +++ b/src/api/qualicharge/migrations/versions/da896549e09c_remove_location_unique_together_.py @@ -0,0 +1,30 @@ +"""remove location unique together constraint + +Revision ID: da896549e09c +Revises: b7d33b01adac +Create Date: 2024-04-29 17:20:56.671209 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision: str = 'da896549e09c' +down_revision: Union[str, None] = 'b7d33b01adac' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_constraint('localisation_adresse_station_coordonneesXY_key', 'localisation', type_='unique') + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.create_unique_constraint('localisation_adresse_station_coordonneesXY_key', 'localisation', ['adresse_station', 'coordonneesXY']) + # ### end Alembic commands ### diff --git a/src/api/qualicharge/schemas/static.py b/src/api/qualicharge/schemas/static.py index 8d86b341..f2568d23 100644 --- a/src/api/qualicharge/schemas/static.py +++ b/src/api/qualicharge/schemas/static.py @@ -50,6 +50,11 @@ class Amenageur(BaseTimestampedSQLModel, table=True): # Relationships stations: List["Station"] = Relationship(back_populates="amenageur") + def __eq__(self, other) -> bool: + """Assess instances equality given uniqueness criterions.""" + fields = ("nom_amenageur", "siren_amenageur", "contact_amenageur") + return all(getattr(self, field) == getattr(other, field) for field in fields) + class Operateur(BaseTimestampedSQLModel, table=True): """Operateur table.""" @@ -68,6 +73,11 @@ class Operateur(BaseTimestampedSQLModel, table=True): # Relationships stations: List["Station"] = Relationship(back_populates="operateur") + def __eq__(self, other) -> bool: + """Assess instances equality given uniqueness criterions.""" + fields = ("nom_operateur", "contact_operateur", "telephone_operateur") + return all(getattr(self, field) == getattr(other, field) for field in fields) + class Enseigne(BaseTimestampedSQLModel, table=True): """Enseigne table.""" @@ -80,20 +90,21 @@ class Enseigne(BaseTimestampedSQLModel, table=True): # Relationships stations: List["Station"] = Relationship(back_populates="enseigne") + def __eq__(self, other) -> bool: + """Assess instances equality given uniqueness criterions.""" + fields = ("nom_enseigne",) + return all(getattr(self, field) == getattr(other, field) for field in fields) + class Localisation(BaseTimestampedSQLModel, table=True): """Localisation table.""" - __table_args__ = BaseTimestampedSQLModel.__table_args__ + ( - UniqueConstraint("adresse_station", "coordonneesXY"), - ) - model_config = SQLModelConfig( validate_assignment=True, arbitrary_types_allowed=True ) id: Optional[UUID] = Field(default_factory=lambda: uuid4().hex, primary_key=True) - adresse_station: str + adresse_station: str = Field(unique=True) code_insee_commune: Optional[str] = Field(regex=r"^([013-9]\d|2[AB1-9])\d{3}$") coordonneesXY: DataGouvCoordinate = Field( sa_type=Geometry( @@ -107,6 +118,11 @@ class Localisation(BaseTimestampedSQLModel, table=True): # Relationships stations: List["Station"] = Relationship(back_populates="localisation") + def __eq__(self, other) -> bool: + """Assess instances equality given uniqueness criterions.""" + fields = ("adresse_station",) + return all(getattr(self, field) == getattr(other, field) for field in fields) + @staticmethod def _coordinates_to_geometry_point(value: Coordinate): """Convert coordinate to Geometry point.""" @@ -183,6 +199,11 @@ class Station(BaseTimestampedSQLModel, table=True): points_de_charge: List["PointDeCharge"] = Relationship(back_populates="station") + def __eq__(self, other) -> bool: + """Assess instances equality given uniqueness criterions.""" + fields = ("id_station_itinerance",) + return all(getattr(self, field) == getattr(other, field) for field in fields) + class PointDeCharge(BaseTimestampedSQLModel, table=True): """PointDeCharge table.""" @@ -211,6 +232,11 @@ class PointDeCharge(BaseTimestampedSQLModel, table=True): observations: Optional[str] cable_t2_attache: Optional[bool] + def __eq__(self, other) -> bool: + """Assess instances equality given uniqueness criterions.""" + fields = ("id_pdc_itinerance",) + return all(getattr(self, field) == getattr(other, field) for field in fields) + # Relationships station_id: Optional[UUID] = Field(default=None, foreign_key="station.id") station: Station = Relationship(back_populates="points_de_charge") diff --git a/src/api/qualicharge/schemas/utils.py b/src/api/qualicharge/schemas/utils.py index 241ec016..325c8d63 100644 --- a/src/api/qualicharge/schemas/utils.py +++ b/src/api/qualicharge/schemas/utils.py @@ -1,12 +1,13 @@ """QualiCharge schemas utilities.""" +from collections import namedtuple import logging -from typing import Generator, List, Optional, Set, Tuple, Type, Union +from typing import Generator, List, Optional, Set, Tuple, Type from sqlalchemy.exc import MultipleResultsFound from sqlmodel import Session, SQLModel, select -from ..exceptions import DatabaseQueryException +from ..exceptions import DatabaseQueryException, DuplicateEntriesSubmitted from ..models.static import Statique from .static import ( Amenageur, @@ -23,7 +24,7 @@ def get_or_create( - session: Session, entry: SQLModel, fields: Optional[Set] = None + session: Session, entry: SQLModel, fields: Optional[Set] = None, commit: bool = True ) -> Tuple[bool, SQLModel]: """Get or create schema instance. @@ -32,6 +33,7 @@ def get_or_create( entry: SQLModel schema instance fields: entry fields used in database query to select target entry. Defaults to None (use all fields). + commit: should we commit transation to database? Returns: A (bool, entry) tuple. The boolean states on the entry creation. @@ -40,8 +42,9 @@ def get_or_create( DatabaseQueryException: Found multiple entries given input fields. """ # Try to get entry from selected fields - statement = select(entry.__class__).filter_by(**entry.model_dump(include=fields)) - logger.debug(f"{statement=}") + statement = select(entry.__class__).filter_by( + **entry.model_dump(include=fields, exclude=DB_TO_STATIC_EXCLUDED_FIELDS) + ) try: db_entry = session.exec(statement).one_or_none() except MultipleResultsFound as err: @@ -55,8 +58,9 @@ def get_or_create( # Create a new entry session.add(entry) - session.commit() - session.refresh(entry) + if commit: + session.commit() + session.refresh(entry) return True, entry @@ -140,7 +144,9 @@ def save_statique( session, Localisation, statique, - fields={"adresse_station", "coordonneesXY"}, + fields={ + "adresse_station", + }, update=update, ) @@ -159,8 +165,90 @@ def save_statique( return pdc_to_statique(pdc) -def save_statiques(session: Session, statiques: List[Statique]) -> List[Statique]: +def save_statiques( + session: Session, statiques: List[Statique] +) -> Generator[Statique, None, None]: """Save Statique instances to database in an efficient way.""" + points_de_charge: List[PointDeCharge] = [] + stations: List[Station] = [] + amenageurs: List[Amenageur] = [] + operateurs: List[Operateur] = [] + enseignes: List[Enseigne] = [] + localisations: List[Localisation] = [] + + StatiqueSchemasEntryIndex = namedtuple( + "StatiqueSchemasEntryIndex", + ["pdc", "station", "amenageur", "operateur", "enseigne", "localisation"], + ) + + # Collect unique entries list per model and add references to those for each + # statique + statiques_db_refs: List[StatiqueSchemasEntryIndex] = [] + for statique in statiques: + pdc = PointDeCharge(**statique.get_fields_for_schema(PointDeCharge)) + station = Station(**statique.get_fields_for_schema(Station)) + amenageur = Amenageur(**statique.get_fields_for_schema(Amenageur)) + operateur = Operateur(**statique.get_fields_for_schema(Operateur)) + enseigne = Enseigne(**statique.get_fields_for_schema(Enseigne)) + localisation = Localisation(**statique.get_fields_for_schema(Localisation)) + + indexes = () + for entry, entries in ( + (pdc, points_de_charge), + (station, stations), + (amenageur, amenageurs), + (operateur, operateurs), + (enseigne, enseignes), + (localisation, localisations), + ): + if entry not in entries: + entries.append(entry) + indexes += (entries.index(entry),) + statiques_db_refs.append(StatiqueSchemasEntryIndex(*indexes)) + + if len(points_de_charge) != len(statiques): + raise DuplicateEntriesSubmitted("Found duplicated entries in submitted data") + + # Create database entries for each schema + for entries, fields in ( + (points_de_charge, {"id_pdc_itinerance"}), + (stations, {"id_station_itinerance"}), + (amenageurs, None), + (operateurs, None), + (enseignes, None), + (localisations, {"adresse_station",},), + ): + for idx, entry in enumerate(entries): + _, db_entry = get_or_create(session, entry, fields, commit=False) + entries[idx] = db_entry + + # Commit transaction so that all expected table rows are created + session.commit() + + # Handle relationships + for ( + pdc_idx, + station_idx, + amenageur_idx, + operateur_idx, + enseigne_idx, + localisation_idx, + ) in statiques_db_refs: + points_de_charge[pdc_idx].station_id = stations[station_idx].id # type: ignore[attr-defined] + stations[station_idx].amenageur_id = amenageurs[amenageur_idx].id # type: ignore[attr-defined] + stations[station_idx].operateur_id = operateurs[operateur_idx].id # type: ignore[attr-defined] + stations[station_idx].enseigne_id = enseignes[enseigne_idx].id # type: ignore[attr-defined] + stations[station_idx].localisation_id = localisations[localisation_idx].id # type: ignore[attr-defined] + + session.add(points_de_charge[pdc_idx]) + session.add(stations[station_idx]) + + # Commit transaction for relationships + session.commit() + + for pdc in points_de_charge: + session.refresh(pdc) + yield pdc_to_statique(pdc) def build_statique(session: Session, id_pdc_itinerance: str) -> Statique: @@ -179,5 +267,10 @@ def list_statique( session: Session, offset: int = 0, limit: int = 50 ) -> Generator[Statique, None, None]: """List Statique entries.""" - for pdc in session.exec(select(PointDeCharge).offset(offset).limit(limit)).all(): + for pdc in session.exec( + select(PointDeCharge) + .order_by(PointDeCharge.id_pdc_itinerance) + .offset(offset) + .limit(limit) + ).all(): yield pdc_to_statique(pdc) diff --git a/src/api/tests/api/v1/routers/test_statique.py b/src/api/tests/api/v1/routers/test_statique.py index d6823c12..2ee5e852 100644 --- a/src/api/tests/api/v1/routers/test_statique.py +++ b/src/api/tests/api/v1/routers/test_statique.py @@ -14,7 +14,7 @@ def test_list(client_auth): assert response.status_code == status.HTTP_200_OK json_response = response.json() expected_size = 0 - assert len(json_response) == expected_size + assert len(json_response.get("items")) == expected_size def test_create(client_auth): @@ -34,10 +34,8 @@ def test_create(client_auth): def test_bulk(client_auth): """Test the /statique/bulk create endpoint.""" - id_pdc_itinerance = "ESZUNE1111ER1" data = StatiqueFactory.batch( size=2, - id_pdc_itinerance=id_pdc_itinerance, ) payload = [json.loads(d.model_dump_json()) for d in data] @@ -47,8 +45,8 @@ def test_bulk(client_auth): json_response = response.json() assert json_response["message"] == "Statique items created" assert json_response["size"] == len(payload) - assert json_response["items"][0]["id_pdc_itinerance"] == id_pdc_itinerance - assert json_response["items"][1]["id_pdc_itinerance"] == id_pdc_itinerance + assert json_response["items"][0]["id_pdc_itinerance"] == data[0].id_pdc_itinerance + assert json_response["items"][1]["id_pdc_itinerance"] == data[1].id_pdc_itinerance def test_bulk_with_outbound_sizes(client_auth): @@ -65,7 +63,6 @@ def test_bulk_with_outbound_sizes(client_auth): data = StatiqueFactory.batch( size=settings.API_BULK_CREATE_MAX_SIZE + 1, - id_pdc_itinerance=id_pdc_itinerance, ) payload = [json.loads(d.model_dump_json()) for d in data] response = client_auth.post("/statique/bulk", json=payload) diff --git a/src/api/tests/schemas/test_utils.py b/src/api/tests/schemas/test_utils.py index c9193787..3476aa53 100644 --- a/src/api/tests/schemas/test_utils.py +++ b/src/api/tests/schemas/test_utils.py @@ -1,9 +1,10 @@ """QualiCharge schemas utilities tests.""" import pytest +from sqlalchemy import func from sqlmodel import select -from qualicharge.exceptions import DatabaseQueryException +from qualicharge.exceptions import DatabaseQueryException, DuplicateEntriesSubmitted from qualicharge.factories.static import AmenageurFactory, StatiqueFactory from qualicharge.schemas.static import ( Amenageur, @@ -16,8 +17,10 @@ from qualicharge.schemas.utils import ( build_statique, get_or_create, + list_statique, save_schema_from_statique, save_statique, + save_statiques, ) @@ -176,6 +179,88 @@ def test_save_statique(db_session): assert db_statique.cable_t2_attache == pdc.cable_t2_attache +def test_save_statiques(db_session): + """Test save_statiques utility.""" + statiques = StatiqueFactory.batch(2) + + db_statiques = list(save_statiques(db_session, statiques)) + assert db_statiques[0] == statiques[0] + assert db_statiques[1] == statiques[1] + + +def test_save_statiques_with_same_amenageur(db_session): + """Test save_statiques utility with the same amenageur.""" + statiques = StatiqueFactory.batch( + 2, + nom_amenageur="ACME Inc.", + siren_amenageur="123456789", + contact_amenageur="john.doe@acme.com", + ) + + db_statiques = list(save_statiques(db_session, statiques)) + assert db_statiques[0] == statiques[0] + assert db_statiques[1] == statiques[1] + + # We should only have created one Amenageur and two PointDeCharge + assert db_session.exec(select(func.count(Amenageur.siren_amenageur))).one() == 1 + assert db_session.exec( + select(func.count(PointDeCharge.id_pdc_itinerance)) + ).one() == len(statiques) + + +def test_save_statiques_with_same_localisation(db_session): + """Test save_statiques utility with the same localisation.""" + statiques = StatiqueFactory.batch( + 2, + adresse_station="221B Baker street, London", + code_insee_commune="21231", + coordonneesXY="[-3.129447,45.700327]", + ) + + db_statiques = list(save_statiques(db_session, statiques)) + assert db_statiques[0] == statiques[0] + assert db_statiques[1] == statiques[1] + + # We should only have created one Amenageur and two PointDeCharge + assert db_session.exec(select(func.count(Localisation.adresse_station))).one() == 1 + assert db_session.exec( + select(func.count(PointDeCharge.id_pdc_itinerance)) + ).one() == len(statiques) + + +def test_save_statiques_with_same_amenageur_twice(db_session): + """Test save_statiques utility with the same amenageur, twice.""" + statiques = StatiqueFactory.batch( + 2, + nom_amenageur="ACME Inc.", + siren_amenageur="123456789", + contact_amenageur="john.doe@acme.com", + ) + + db_statiques = list(save_statiques(db_session, statiques)) + db_statiques = list(save_statiques(db_session, statiques)) + assert db_statiques[0] == statiques[0] + assert db_statiques[1] == statiques[1] + + # We should only have created one Amenageur and two PointDeCharge + assert db_session.exec(select(func.count(Amenageur.siren_amenageur))).one() == 1 + assert db_session.exec( + select(func.count(PointDeCharge.id_pdc_itinerance)) + ).one() == len(statiques) + + +def test_save_statiques_with_same_entries(db_session): + """Test save_statiques utility with the same id_pdc_itinerance.""" + statiques = StatiqueFactory.batch(2) + statiques.append(statiques[0]) + + with pytest.raises( + DuplicateEntriesSubmitted, + match="Found duplicated entries in submitted data", + ): + list(save_statiques(db_session, statiques)) + + def test_build_statique(db_session): """Test build_statique utility.""" # Create a Statique instance and save it to database @@ -193,3 +278,24 @@ def test_build_statique(db_session): assert db_statique == statique db_another_statique = build_statique(db_session, another_statique.id_pdc_itinerance) assert db_another_statique == another_statique + + +def test_list_statique(db_session): + """Test list_statique utility.""" + # Create statiques in database + ids_pdc_itinerance = [] + n_statiques = 22 + for _ in range(n_statiques): + statique = StatiqueFactory.build() + ids_pdc_itinerance.append(statique.id_pdc_itinerance) + save_statique(db_session, statique, update=False) + ids_pdc_itinerance.sort() + + limit = 10 + for offset in range(0, 30, 10): + statiques = list(list_statique(db_session, offset=offset, limit=limit)) + size = limit if offset + limit < n_statiques else n_statiques - offset + assert len(statiques) == size + assert {statique.id_pdc_itinerance for statique in statiques} == set( + ids_pdc_itinerance[offset : offset + size] + )