Skip to content

Commit

Permalink
✨(api) add dynamic enpoint filters
Browse files Browse the repository at this point in the history
We are now able to filter:

- status list by:
    * station
    * pdc
    * lower limit horodatage date
- PDC status history by:
    * lower limit horodatage date

WIP: add from filter

WIP: add status read history from filter
  • Loading branch information
jmaupetit committed May 15, 2024
1 parent 1658e44 commit 6a7f773
Show file tree
Hide file tree
Showing 2 changed files with 278 additions and 14 deletions.
110 changes: 97 additions & 13 deletions src/api/qualicharge/api/v1/routers/dynamic.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,12 @@
from typing import Annotated, List, cast

from annotated_types import Len
from fastapi import APIRouter, Depends, HTTPException, Path
from fastapi import APIRouter, Depends, HTTPException, Path, Query
from fastapi import status as fa_status
from pydantic import PastDatetime, StringConstraints
from sqlalchemy import func
from sqlalchemy.schema import Column as SAColumn
from sqlmodel import Session, select
from sqlmodel import Session, join, select

from qualicharge.conf import settings
from qualicharge.db import get_session
Expand All @@ -17,7 +18,7 @@
StatusCreate,
StatusRead,
)
from qualicharge.schemas import PointDeCharge, Status
from qualicharge.schemas import PointDeCharge, Station, Status
from qualicharge.schemas import Session as QCSession

logger = logging.getLogger(__name__)
Expand All @@ -33,13 +34,74 @@
BulkSessionCreateList = Annotated[
List[SessionCreate], Len(2, settings.API_SESSION_BULK_CREATE_MAX_SIZE)
]
IdItinerance = Annotated[
str,
StringConstraints(pattern="(?:(?:^|,)(^[A-Z]{2}[A-Z0-9]{4,33}$|Non concerné))+$"),
]


@router.get("/status/", tags=["Status"])
async def list_statuses(
from_: Annotated[
PastDatetime | None,
Query(
alias="from",
title="Date/time from",
description="The datetime from when we want statuses to be collected",
),
] = None,
pdc: Annotated[
List[IdItinerance] | None,
Query(
title="Point de charge",
description=(
"Filter status by `id_pdc_itinerance` "
"(can be provided multiple times)"
),
),
] = None,
station: Annotated[
List[IdItinerance] | None,
Query(
title="Station",
description=(
"Filter status by `id_station_itinerance` "
"(can be provided multiple times)"
),
),
] = None,
session: Session = Depends(get_session),
) -> List[StatusRead]:
"""List last known point of charge statuses."""
"""List last known points of charge status."""
pdc_ids_filter = set()

# Filter by station
if station:
pdc_ids_filter = set(
session.exec(
select(PointDeCharge.id)
.select_from(
join(
PointDeCharge,
Station,
cast(SAColumn, PointDeCharge.station_id)
== cast(SAColumn, Station.id),
)
)
.filter(cast(SAColumn, Station.id_station_itinerance).in_(station))
).all()
)

# Filter by point of charge
if pdc:
pdc_ids_filter = pdc_ids_filter | set(
session.exec(
select(PointDeCharge.id).filter(
cast(SAColumn, PointDeCharge.id_pdc_itinerance).in_(pdc)
)
).all()
)

# Get latest status per point of charge
latest_db_statuses_stmt = (
select(
Expand All @@ -49,13 +111,23 @@ async def list_statuses(
.group_by(cast(SAColumn, Status.point_de_charge_id))
.subquery()
)
db_statuses = session.exec(
select(Status).join_from(
Status,
latest_db_statuses_stmt,
Status.id == latest_db_statuses_stmt.c.status_id, # type: ignore[arg-type]
db_statuses_stmt = select(Status)

if from_:
db_statuses_stmt = db_statuses_stmt.where(Status.horodatage >= from_)

if len(pdc_ids_filter):
db_statuses_stmt = db_statuses_stmt.filter(
cast(SAColumn, Status.point_de_charge_id).in_(pdc_ids_filter)
)
).all()

db_statuses_stmt = db_statuses_stmt.join_from(
Status,
latest_db_statuses_stmt,
Status.id == latest_db_statuses_stmt.c.status_id, # type: ignore[arg-type]
)
db_statuses = session.exec(db_statuses_stmt).all()

return [
StatusRead(
**s.model_dump(
Expand Down Expand Up @@ -138,6 +210,14 @@ async def read_status_history(
),
),
],
from_: Annotated[
PastDatetime | None,
Query(
alias="from",
title="Date/time from",
description="The datetime from when we want statuses to be collected",
),
] = None,
session: Session = Depends(get_session),
) -> List[StatusRead]:
"""Read point of charge status history."""
Expand All @@ -153,11 +233,15 @@ async def read_status_history(
)

# Get latest statuses
db_statuses_stmt = select(Status).where(Status.point_de_charge_id == pdc_id)

if from_:
db_statuses_stmt = db_statuses_stmt.where(Status.horodatage >= from_)

db_statuses = session.exec(
select(Status)
.where(Status.point_de_charge_id == pdc_id)
.order_by(cast(SAColumn, Status.horodatage))
db_statuses_stmt.order_by(cast(SAColumn, Status.horodatage))
).all()

if not len(db_statuses):
raise HTTPException(
status_code=fa_status.HTTP_404_NOT_FOUND,
Expand Down
182 changes: 181 additions & 1 deletion src/api/tests/api/v1/routers/test_dynamic.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import json
from typing import cast
from urllib.parse import quote_plus

from fastapi import status
from sqlalchemy import func
Expand All @@ -14,7 +15,11 @@
StatusCreateFactory,
StatusFactory,
)
from qualicharge.factories.static import StatiqueFactory
from qualicharge.factories.static import (
PointDeChargeFactory,
StationFactory,
StatiqueFactory,
)
from qualicharge.models.dynamic import StatusRead
from qualicharge.schemas import PointDeCharge, Session, Status
from qualicharge.schemas.utils import save_statique, save_statiques
Expand Down Expand Up @@ -74,6 +79,131 @@ def test_list_statuses(db_session, client_auth):
assert db_status.etat_prise_type_ef == response_status.etat_prise_type_ef


def test_list_statuses_filters(db_session, client_auth): # noqa: PLR0915
"""Test the /status/ get endpoint filters."""
StationFactory.__session__ = db_session
PointDeChargeFactory.__session__ = db_session
StatusFactory.__session__ = db_session

# Create stations, points of charge and statuses
n_station = 2
n_pdc_by_station = 2
n_status_by_pdc = 2
stations = StationFactory.create_batch_sync(n_station)
for station in stations:
PointDeChargeFactory.create_batch_sync(n_pdc_by_station, station_id=station.id)
pdcs = db_session.exec(select(PointDeCharge)).all()
assert len(pdcs) == n_station * n_pdc_by_station
for pdc in pdcs:
StatusFactory.create_batch_sync(n_status_by_pdc, point_de_charge_id=pdc.id)
assert db_session.exec(select(func.count(Status.id))).one() == (
n_station * n_pdc_by_station * n_status_by_pdc
)

# List all latest statuses by pdc
response = client_auth.get("/dynamique/status/")
assert response.status_code == status.HTTP_200_OK
statuses = [StatusRead(**s) for s in response.json()]
assert len(statuses) == n_station * n_pdc_by_station

# Filter with invalid PDC
response = client_auth.get("/dynamique/status/?pdc=foo")
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY

# Filter with one pdc
response = client_auth.get(f"/dynamique/status/?pdc={pdcs[0].id_pdc_itinerance}")
assert response.status_code == status.HTTP_200_OK
statuses = [StatusRead(**s) for s in response.json()]
assert len(statuses) == 1

# Filter with two pdcs
selected_pdc_indexes = (0, 1)
query = "&".join(
f"pdc={pdcs[idx].id_pdc_itinerance}" for idx in selected_pdc_indexes
)
response = client_auth.get(f"/dynamique/status/?{query}")
assert response.status_code == status.HTTP_200_OK
statuses = [StatusRead(**s) for s in response.json()]
assert len(statuses) == len(selected_pdc_indexes)

# Filter with invalid station
response = client_auth.get("/dynamique/status/?station=foo")
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY

# Filter with one station
response = client_auth.get(
f"/dynamique/status/?station={stations[0].id_station_itinerance}"
)
assert response.status_code == status.HTTP_200_OK
statuses = [StatusRead(**s) for s in response.json()]
assert len(statuses) == n_pdc_by_station
assert {s.id_pdc_itinerance for s in statuses} == {
p.id_pdc_itinerance for p in stations[0].points_de_charge
}

# Filter with two stations
query = "&".join(f"station={station.id_station_itinerance}" for station in stations)
response = client_auth.get(f"/dynamique/status/?{query}")
assert response.status_code == status.HTTP_200_OK
statuses = [StatusRead(**s) for s in response.json()]
assert len(statuses) == n_station * n_pdc_by_station

# Filter with one station and one pdc from another station
query = (
f"station={stations[0].id_station_itinerance}&"
f"pdc={stations[1].points_de_charge[0].id_pdc_itinerance}"
)
response = client_auth.get(f"/dynamique/status/?{query}")
assert response.status_code == status.HTTP_200_OK
statuses = [StatusRead(**s) for s in response.json()]
# 2 for the station and 1 for the pdc (from a different station)
expected_statuses = 3
assert len(statuses) == expected_statuses

# Filter with one station and one pdc from the same station
query = (
f"station={stations[0].id_station_itinerance}&"
f"pdc={stations[0].points_de_charge[0].id_pdc_itinerance}"
)
response = client_auth.get(f"/dynamique/status/?{query}")
assert response.status_code == status.HTTP_200_OK
statuses = [StatusRead(**s) for s in response.json()]
# 2 for the station (as the extra pdc is from the same station)
expected_statuses = 2
assert len(statuses) == expected_statuses
assert {s.id_pdc_itinerance for s in statuses} == {
p.id_pdc_itinerance for p in stations[0].points_de_charge
}

# Filter with invalid from date time
response = client_auth.get("/dynamique/status/?from=foo")
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY

# Filter with only latest status datetime
statuses = db_session.exec(select(Status).order_by(Status.horodatage)).all()
from_ = quote_plus(statuses[-1].horodatage.isoformat())
response = client_auth.get(f"/dynamique/status/?from={from_}")
assert response.status_code == status.HTTP_200_OK
statuses = [StatusRead(**s) for s in response.json()]
assert len(statuses) == 1

# Filter with only latest status datetime and pdc
response = client_auth.get(
f"/dynamique/status/?from={from_}&pdc={statuses[-1].id_pdc_itinerance}"
)
assert response.status_code == status.HTTP_200_OK
statuses = [StatusRead(**s) for s in response.json()]
assert len(statuses) == 1

# Filter with the oldest status datetime
statuses = db_session.exec(select(Status).order_by(Status.horodatage)).all()
from_ = quote_plus(statuses[0].horodatage.isoformat())
response = client_auth.get(f"/dynamique/status/?from={from_}")
assert response.status_code == status.HTTP_200_OK
statuses = [StatusRead(**s) for s in response.json()]
assert len(statuses) == n_station * n_pdc_by_station


def test_read_status_for_non_existing_point_of_charge(client_auth):
"""Test the /status/{id_pdc_itinerance} endpoint for unknown point of charge."""
response = client_auth.get("/dynamique/status/ESZUNE1111ER1")
Expand Down Expand Up @@ -220,6 +350,56 @@ def test_read_status_history(db_session, client_auth):
assert expected_status.etat_prise_type_ef == response_status.etat_prise_type_ef


def test_read_status_history_filters(db_session, client_auth):
"""Test the /status/{id_pdc_itinerance}/history endpoint filters."""
StatusFactory.__session__ = db_session

# Create the PointDeCharge
id_pdc_itinerance = "ESZUNE1111ER1"
save_statique(
db_session, StatiqueFactory.build(id_pdc_itinerance=id_pdc_itinerance)
)
pdc = db_session.exec(
select(PointDeCharge).where(
PointDeCharge.id_pdc_itinerance == id_pdc_itinerance
)
).one()

# Create 20 attached statuses
n_statuses = 20
StatusFactory.create_batch_sync(n_statuses, point_de_charge_id=pdc.id)
assert (
db_session.exec(
select(func.count(Status.id)).where(Status.point_de_charge_id == pdc.id)
).one()
== n_statuses
)
# All statuses
db_statuses = db_session.exec(
select(Status)
.where(Status.point_de_charge_id == pdc.id)
.order_by(Status.horodatage)
).all()

# Get latest status
from_ = quote_plus(db_statuses[-1].horodatage.isoformat())
response = client_auth.get(
f"/dynamique/status/{id_pdc_itinerance}/history?from={from_}"
)
assert response.status_code == status.HTTP_200_OK
response_statuses = [StatusRead(**s) for s in response.json()]
assert len(response_statuses) == 1

# Filter with the oldest status datetime
from_ = quote_plus(db_statuses[0].horodatage.isoformat())
response = client_auth.get(
f"/dynamique/status/{id_pdc_itinerance}/history?from={from_}"
)
assert response.status_code == status.HTTP_200_OK
statuses = [StatusRead(**s) for s in response.json()]
assert len(statuses) == n_statuses


def test_create_status_for_non_existing_point_of_charge(client_auth):
"""Test the /status/ create endpoint for non existing point of charge."""
id_pdc_itinerance = "ESZUNE1111ER1"
Expand Down

0 comments on commit 6a7f773

Please sign in to comment.