-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
philippe thomy
committed
Jan 31, 2025
1 parent
820b91c
commit f18192d
Showing
5 changed files
with
369 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,191 @@ | ||
"""QualiCharge prefect indicators: usage. | ||
U6: session duration by power category. | ||
""" | ||
|
||
from string import Template | ||
from typing import List | ||
from uuid import UUID | ||
|
||
import numpy as np | ||
import pandas as pd # type: ignore | ||
from prefect import flow, runtime, task | ||
from prefect.futures import wait | ||
from prefect.task_runners import ThreadPoolTaskRunner | ||
from sqlalchemy.engine import Connection | ||
|
||
from ..conf import settings | ||
from ..models import Indicator, IndicatorTimeSpan, Level | ||
from ..utils import ( | ||
POWER_RANGE_CTE, | ||
export_indic, | ||
get_database_engine, | ||
get_num_for_level_query_params, | ||
get_targets_for_level, | ||
get_timespan_filter_query_params, | ||
) | ||
|
||
DURATION_FOR_LEVEL_QUERY_TEMPLATE = """ | ||
WITH | ||
$power_range, | ||
sessionf AS ( | ||
SELECT | ||
point_de_charge_id, | ||
sum(session.end -session.start) as duree_pdc | ||
FROM | ||
session | ||
WHERE | ||
$timespan | ||
GROUP BY | ||
point_de_charge_id | ||
) | ||
SELECT | ||
extract ('epoch' from sum(duree_pdc)) / 3600.0 AS value, | ||
category, | ||
$level_id AS level_id | ||
FROM | ||
sessionf | ||
INNER JOIN PointDeCharge ON sessionf.point_de_charge_id = PointDeCharge.id | ||
LEFT JOIN Station ON station_id = Station.id | ||
LEFT JOIN Localisation ON localisation_id = Localisation.id | ||
LEFT JOIN City ON City.code = code_insee_commune | ||
LEFT JOIN puissance ON puissance_nominale::numeric <@ category | ||
$join_extras | ||
WHERE | ||
$level_id IN ($indexes) | ||
GROUP BY | ||
$level_id, | ||
category | ||
""" | ||
|
||
QUERY_NATIONAL_TEMPLATE = """ | ||
WITH | ||
$power_range, | ||
sessionf AS ( | ||
SELECT | ||
point_de_charge_id, | ||
sum(session.end -session.start) as duree_pdc | ||
FROM | ||
session | ||
WHERE | ||
$timespan | ||
GROUP BY | ||
point_de_charge_id | ||
) | ||
SELECT | ||
extract ('epoch' from sum(duree_pdc)) / 3600.0 AS value, | ||
category | ||
FROM | ||
sessionf | ||
INNER JOIN PointDeCharge ON sessionf.point_de_charge_id = PointDeCharge.id | ||
LEFT JOIN puissance ON puissance_nominale::numeric <@ category | ||
GROUP BY | ||
category | ||
""" | ||
|
||
|
||
@task(task_run_name="values-for-target-{level:02d}") | ||
def get_values_for_targets( | ||
connection: Connection, | ||
level: Level, | ||
timespan: IndicatorTimeSpan, | ||
indexes: List[UUID], | ||
) -> pd.DataFrame: | ||
"""Fetch sessions given input level, timestamp and target index.""" | ||
query_template = Template(DURATION_FOR_LEVEL_QUERY_TEMPLATE) | ||
query_params = {"indexes": ",".join(f"'{i}'" for i in map(str, indexes))} | ||
query_params |= POWER_RANGE_CTE | ||
query_params |= get_num_for_level_query_params(level) | ||
query_params |= get_timespan_filter_query_params(timespan, session=True) | ||
return pd.read_sql_query(query_template.substitute(query_params), con=connection) | ||
|
||
|
||
@flow( | ||
task_runner=ThreadPoolTaskRunner(max_workers=settings.THREAD_POOL_MAX_WORKERS), | ||
flow_run_name="u6-{timespan.period.value}-{level:02d}-{timespan.start:%y-%m-%d}", | ||
) | ||
def u6_for_level( | ||
level: Level, | ||
timespan: IndicatorTimeSpan, | ||
chunk_size=settings.DEFAULT_CHUNK_SIZE, | ||
) -> pd.DataFrame: | ||
"""Calculate u6 for a level and a timestamp.""" | ||
if level == Level.NATIONAL: | ||
return u6_national(timespan) | ||
engine = get_database_engine() | ||
with engine.connect() as connection: | ||
targets = get_targets_for_level(connection, level) | ||
ids = targets["id"] | ||
chunks = ( | ||
np.array_split(ids, int(len(ids) / chunk_size)) | ||
if len(ids) > chunk_size | ||
else [ids.to_numpy()] | ||
) | ||
futures = [ | ||
get_values_for_targets.submit(connection, level, timespan, chunk) # type: ignore[call-overload] | ||
for chunk in chunks | ||
] | ||
wait(futures) | ||
|
||
# Concatenate results and serialize indicators | ||
results = pd.concat([future.result() for future in futures], ignore_index=True) | ||
merged = targets.merge(results, how="left", left_on="id", right_on="level_id") | ||
|
||
# Build result DataFrame | ||
indicators = { | ||
"target": merged["code"], | ||
"value": merged["value"].fillna(0), | ||
"code": "u6", | ||
"level": level, | ||
"period": timespan.period, | ||
"timestamp": timespan.start.isoformat(), | ||
"category": merged["category"].astype("str"), | ||
"extras": None, | ||
} | ||
return pd.DataFrame(indicators) | ||
|
||
|
||
@flow( | ||
task_runner=ThreadPoolTaskRunner(max_workers=settings.THREAD_POOL_MAX_WORKERS), | ||
flow_run_name="u6-{timespan.period.value}-00-{timespan.start:%y-%m-%d}", | ||
) | ||
def u6_national(timespan: IndicatorTimeSpan) -> pd.DataFrame: | ||
"""Calculate u6 at the national level.""" | ||
engine = get_database_engine() | ||
query_template = Template(QUERY_NATIONAL_TEMPLATE) | ||
query_params = get_timespan_filter_query_params(timespan, session=True) | ||
query_params |= POWER_RANGE_CTE | ||
with engine.connect() as connection: | ||
res = pd.read_sql_query(query_template.substitute(query_params), con=connection) | ||
indicators = { | ||
"target": None, | ||
"value": res["value"].fillna(0), | ||
"code": "u6", | ||
"level": Level.NATIONAL, | ||
"period": timespan.period, | ||
"timestamp": timespan.start.isoformat(), | ||
"category": res["category"].astype("str"), | ||
"extras": None, | ||
} | ||
return pd.DataFrame(indicators) | ||
|
||
|
||
@flow( | ||
task_runner=ThreadPoolTaskRunner(max_workers=settings.THREAD_POOL_MAX_WORKERS), | ||
flow_run_name="meta-u6-{timespan.period.value}", | ||
) | ||
def calculate( | ||
timespan: IndicatorTimeSpan, | ||
levels: List[Level], | ||
create_artifact: bool = False, | ||
chunk_size: int = 1000, | ||
format_pd: bool = False, | ||
) -> List[Indicator]: | ||
"""Run all u6 subflows.""" | ||
subflows_results = [ | ||
u6_for_level(level, timespan, chunk_size=chunk_size) for level in levels | ||
] | ||
indicators = pd.concat(subflows_results, ignore_index=True) | ||
description = f"u6 report at {timespan.start} (period: {timespan.period.value})" | ||
flow_name = runtime.flow_run.name | ||
return export_indic(indicators, create_artifact, flow_name, description, format_pd) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,149 @@ | ||
"""QualiCharge prefect indicators tests: usage. | ||
U6: session duration by power category. | ||
""" | ||
|
||
from datetime import datetime | ||
|
||
import pytest # type: ignore | ||
from sqlalchemy import text | ||
|
||
from indicators.models import IndicatorPeriod, IndicatorTimeSpan, Level # type: ignore | ||
from indicators.usage import u6 # type: ignore | ||
|
||
from ..param_tests import ( | ||
PARAM_FLOW, | ||
PARAM_VALUE, | ||
PARAMETERS_CHUNK, | ||
) | ||
|
||
# expected result for level [city, epci, dpt, reg] | ||
N_LEVEL = [18, 167, 85, 584] | ||
N_LEVEL_NATIONAL = 1509 | ||
|
||
TIMESPAN = IndicatorTimeSpan(start=datetime(2024, 12, 24), period=IndicatorPeriod.DAY) | ||
PARAMETERS_FLOW = [prm + (lvl,) for prm, lvl in zip(PARAM_FLOW, N_LEVEL, strict=True)] | ||
PARAMETERS_VALUE = [prm + (lvl,) for prm, lvl in zip(PARAM_VALUE, N_LEVEL, strict=True)] | ||
|
||
|
||
@pytest.mark.parametrize("level,query,expected", PARAMETERS_VALUE) | ||
def test_task_get_values_for_target(db_connection, level, query, expected): | ||
"""Test the `get_values_for_target` task.""" | ||
result = db_connection.execute(text(query)) | ||
indexes = list(result.scalars().all()) | ||
values = u6.get_values_for_targets.fn(db_connection, level, TIMESPAN, indexes) | ||
assert len(set(values["level_id"])) == len(indexes) | ||
assert int(values["value"].sum()) == expected | ||
|
||
|
||
def test_task_get_values_for_target_unexpected_level(db_connection): | ||
"""Test the `get_values_for_target` task (unknown level).""" | ||
with pytest.raises(NotImplementedError, match="Unsupported level"): | ||
u6.get_values_for_targets.fn(db_connection, Level.NATIONAL, TIMESPAN, []) | ||
|
||
|
||
@pytest.mark.parametrize("level,query,targets,expected", PARAMETERS_FLOW) | ||
def test_flow_u6_for_level(db_connection, level, query, targets, expected): | ||
"""Test the `u6_for_level` flow.""" | ||
indicators = u6.u6_for_level(level, TIMESPAN, chunk_size=1000) | ||
# assert len(indicators) == db_connection.execute(text(query)).scalars().one() | ||
assert ( | ||
int(indicators.loc[indicators["target"].isin(targets), "value"].sum()) | ||
== expected | ||
) | ||
|
||
|
||
@pytest.mark.parametrize("chunk_size", PARAMETERS_CHUNK) | ||
def test_flow_u6_for_level_with_various_chunk_sizes(chunk_size): | ||
"""Test the `u6_for_level` flow with various chunk sizes.""" | ||
level, query, targets, expected = PARAMETERS_FLOW[2] | ||
indicators = u6.u6_for_level(level, TIMESPAN, chunk_size=chunk_size) | ||
# assert len(indicators) == N_DPTS | ||
assert ( | ||
int(indicators.loc[indicators["target"].isin(targets), "value"].sum()) | ||
== expected | ||
) | ||
|
||
|
||
def test_flow_u6_national(db_connection): | ||
"""Test the `u6_national` flow.""" | ||
indicators = u6.u6_national(TIMESPAN) | ||
assert int(indicators["value"].sum()) == N_LEVEL_NATIONAL | ||
|
||
|
||
def test_flow_u6_calculate(db_connection): | ||
"""Test the `calculate` flow.""" | ||
all_levels = [ | ||
Level.NATIONAL, | ||
Level.REGION, | ||
Level.DEPARTMENT, | ||
Level.CITY, | ||
Level.EPCI, | ||
] | ||
indicators = u6.calculate( | ||
TIMESPAN, all_levels, create_artifact=True, format_pd=True | ||
) | ||
assert list(indicators["level"].unique()) == all_levels | ||
|
||
|
||
# query used to get N_LEVEL | ||
N_LEVEL_NAT = """ | ||
WITH | ||
sessionf AS ( | ||
SELECT | ||
point_de_charge_id, | ||
sum(SESSION.end - SESSION.start) AS duree_pdc | ||
FROM | ||
SESSION | ||
WHERE | ||
START >= date '2024-12-24' | ||
AND START < date '2024-12-25' | ||
GROUP BY | ||
point_de_charge_id | ||
) | ||
SELECT | ||
extract( | ||
'epoch' | ||
FROM | ||
sum(duree_pdc) | ||
) / 3600.0 AS duree | ||
FROM | ||
sessionf | ||
INNER JOIN PointDeCharge ON sessionf.point_de_charge_id = PointDeCharge.id | ||
LEFT JOIN station ON station_id = station.id | ||
LEFT JOIN localisation ON localisation_id = localisation.id | ||
LEFT JOIN city ON city.code = code_insee_commune | ||
LEFT JOIN department ON city.department_id = department.id | ||
LEFT JOIN region ON department.region_id = region.id | ||
""" | ||
N_LEVEL_3 = """ | ||
WITH | ||
sessionf AS ( | ||
SELECT | ||
point_de_charge_id, | ||
sum(SESSION.end - SESSION.start) AS duree_pdc | ||
FROM | ||
SESSION | ||
WHERE | ||
START >= date '2024-12-24' | ||
AND START < date '2024-12-25' | ||
GROUP BY | ||
point_de_charge_id | ||
) | ||
SELECT | ||
extract( | ||
'epoch' | ||
FROM | ||
sum(duree_pdc) | ||
) / 3600.0 AS duree | ||
FROM | ||
sessionf | ||
INNER JOIN PointDeCharge ON sessionf.point_de_charge_id = PointDeCharge.id | ||
LEFT JOIN station ON station_id = station.id | ||
LEFT JOIN localisation ON localisation_id = localisation.id | ||
LEFT JOIN city ON city.code = code_insee_commune | ||
LEFT JOIN department ON city.department_id = department.id | ||
LEFT JOIN region ON department.region_id = region.id | ||
WHERE | ||
region.code IN ('11', '84', '75') | ||
""" |