-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
philippe thomy
committed
Feb 1, 2025
1 parent
f18192d
commit 572702f
Showing
5 changed files
with
283 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,161 @@ | ||
"""QualiCharge prefect indicators: usage. | ||
C1: Number of sessions by operator. | ||
""" | ||
|
||
from string import Template | ||
from typing import List | ||
from uuid import UUID | ||
|
||
import numpy as np | ||
import pandas as pd # type: ignore | ||
from prefect import flow, runtime, task | ||
from prefect.futures import wait | ||
from prefect.task_runners import ThreadPoolTaskRunner | ||
from sqlalchemy.engine import Connection | ||
|
||
from ..conf import settings | ||
from ..models import Indicator, IndicatorTimeSpan, Level | ||
from ..utils import ( | ||
export_indic, | ||
get_database_engine, | ||
get_num_for_level_query_params, | ||
get_targets_for_level, | ||
get_timespan_filter_query_params, | ||
) | ||
|
||
SESSIONS_BY_OPERATOR_QUERY_TEMPLATE = """ | ||
SELECT | ||
count(*) AS value, | ||
nom_operateur AS category, | ||
$level_id AS level_id | ||
FROM | ||
Session | ||
INNER JOIN statique ON point_de_charge_id = pdc_id | ||
LEFT JOIN City ON City.code = code_insee_commune | ||
$join_extras | ||
WHERE | ||
$timespan | ||
AND $level_id IN ($indexes) | ||
GROUP BY | ||
category, | ||
$level_id | ||
""" | ||
|
||
QUERY_NATIONAL_TEMPLATE = """ | ||
SELECT | ||
count(*) AS value, | ||
nom_operateur AS category | ||
FROM | ||
SESSION | ||
INNER JOIN statique ON point_de_charge_id = pdc_id | ||
WHERE | ||
$timespan | ||
GROUP BY | ||
category | ||
""" | ||
|
||
|
||
@task(task_run_name="values-for-target-{level:02d}") | ||
def get_values_for_targets( | ||
connection: Connection, | ||
level: Level, | ||
timespan: IndicatorTimeSpan, | ||
indexes: List[UUID], | ||
) -> pd.DataFrame: | ||
"""Fetch sessions given input level, timestamp and target index.""" | ||
query_template = Template(SESSIONS_BY_OPERATOR_QUERY_TEMPLATE) | ||
query_params = {"indexes": ",".join(f"'{i}'" for i in map(str, indexes))} | ||
query_params |= get_num_for_level_query_params(level) | ||
query_params |= get_timespan_filter_query_params(timespan, session=True) | ||
return pd.read_sql_query(query_template.substitute(query_params), con=connection) | ||
|
||
|
||
@flow( | ||
task_runner=ThreadPoolTaskRunner(max_workers=settings.THREAD_POOL_MAX_WORKERS), | ||
flow_run_name="c1-{timespan.period.value}-{level:02d}-{timespan.start:%y-%m-%d}", | ||
) | ||
def c1_for_level( | ||
level: Level, | ||
timespan: IndicatorTimeSpan, | ||
chunk_size=settings.DEFAULT_CHUNK_SIZE, | ||
) -> pd.DataFrame: | ||
"""Calculate c1 for a level and a timestamp.""" | ||
if level == Level.NATIONAL: | ||
return c1_national(timespan) | ||
engine = get_database_engine() | ||
with engine.connect() as connection: | ||
targets = get_targets_for_level(connection, level) | ||
ids = targets["id"] | ||
chunks = ( | ||
np.array_split(ids, int(len(ids) / chunk_size)) | ||
if len(ids) > chunk_size | ||
else [ids.to_numpy()] | ||
) | ||
futures = [ | ||
get_values_for_targets.submit(connection, level, timespan, chunk) # type: ignore[call-overload] | ||
for chunk in chunks | ||
] | ||
wait(futures) | ||
|
||
# Concatenate results and serialize indicators | ||
results = pd.concat([future.result() for future in futures], ignore_index=True) | ||
merged = targets.merge(results, how="left", left_on="id", right_on="level_id") | ||
|
||
# Build result DataFrame | ||
indicators = { | ||
"target": merged["code"], | ||
"value": merged["value"].fillna(0), | ||
"code": "c1", | ||
"level": level, | ||
"period": timespan.period, | ||
"timestamp": timespan.start.isoformat(), | ||
"category": merged["category"].astype("str"), | ||
"extras": None, | ||
} | ||
return pd.DataFrame(indicators) | ||
|
||
|
||
@flow( | ||
task_runner=ThreadPoolTaskRunner(max_workers=settings.THREAD_POOL_MAX_WORKERS), | ||
flow_run_name="c1-{timespan.period.value}-00-{timespan.start:%y-%m-%d}", | ||
) | ||
def c1_national(timespan: IndicatorTimeSpan) -> pd.DataFrame: | ||
"""Calculate c1 at the national level.""" | ||
engine = get_database_engine() | ||
query_template = Template(QUERY_NATIONAL_TEMPLATE) | ||
query_params = get_timespan_filter_query_params(timespan, session=True) | ||
with engine.connect() as connection: | ||
res = pd.read_sql_query(query_template.substitute(query_params), con=connection) | ||
indicators = { | ||
"target": None, | ||
"value": res["value"].fillna(0), | ||
"code": "c1", | ||
"level": Level.NATIONAL, | ||
"period": timespan.period, | ||
"timestamp": timespan.start.isoformat(), | ||
"category": res["category"].astype("str"), | ||
"extras": None, | ||
} | ||
return pd.DataFrame(indicators) | ||
|
||
|
||
@flow( | ||
task_runner=ThreadPoolTaskRunner(max_workers=settings.THREAD_POOL_MAX_WORKERS), | ||
flow_run_name="meta-c1-{timespan.period.value}", | ||
) | ||
def calculate( | ||
timespan: IndicatorTimeSpan, | ||
levels: List[Level], | ||
create_artifact: bool = False, | ||
chunk_size: int = 1000, | ||
format_pd: bool = False, | ||
) -> List[Indicator]: | ||
"""Run all c1 subflows.""" | ||
subflows_results = [ | ||
c1_for_level(level, timespan, chunk_size=chunk_size) for level in levels | ||
] | ||
indicators = pd.concat(subflows_results, ignore_index=True) | ||
description = f"c1 report at {timespan.start} (period: {timespan.period.value})" | ||
flow_name = runtime.flow_run.name | ||
return export_indic(indicators, create_artifact, flow_name, description, format_pd) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,106 @@ | ||
"""QualiCharge prefect indicators tests: usage. | ||
C1: Number of sessions by operator. | ||
""" | ||
|
||
from datetime import datetime | ||
|
||
import pytest # type: ignore | ||
from sqlalchemy import text | ||
|
||
from indicators.models import IndicatorPeriod, IndicatorTimeSpan, Level # type: ignore | ||
from indicators.usage import c1 # type: ignore | ||
|
||
from ..param_tests import ( | ||
PARAM_FLOW, | ||
PARAM_VALUE, | ||
PARAMETERS_CHUNK, | ||
) | ||
|
||
# expected result for level [city, epci, dpt, reg] | ||
N_LEVEL = [32, 307, 172, 1055] | ||
N_LEVEL_NATIONAL = 2718 | ||
|
||
TIMESPAN = IndicatorTimeSpan(start=datetime(2024, 12, 24), period=IndicatorPeriod.DAY) | ||
PARAMETERS_FLOW = [prm + (lvl,) for prm, lvl in zip(PARAM_FLOW, N_LEVEL, strict=True)] | ||
PARAMETERS_VALUE = [prm + (lvl,) for prm, lvl in zip(PARAM_VALUE, N_LEVEL, strict=True)] | ||
|
||
|
||
@pytest.mark.parametrize("level,query,expected", PARAMETERS_VALUE) | ||
def test_task_get_values_for_target(db_connection, level, query, expected): | ||
"""Test the `get_values_for_target` task.""" | ||
result = db_connection.execute(text(query)) | ||
indexes = list(result.scalars().all()) | ||
ses_by_hour = c1.get_values_for_targets.fn(db_connection, level, TIMESPAN, indexes) | ||
assert len(set(ses_by_hour["level_id"])) == len(indexes) | ||
assert ses_by_hour["value"].sum() == expected | ||
|
||
|
||
def test_task_get_values_for_target_unexpected_level(db_connection): | ||
"""Test the `get_values_for_target` task (unknown level).""" | ||
with pytest.raises(NotImplementedError, match="Unsupported level"): | ||
c1.get_values_for_targets.fn(db_connection, Level.NATIONAL, TIMESPAN, []) | ||
|
||
|
||
@pytest.mark.parametrize("level,query,targets,expected", PARAMETERS_FLOW) | ||
def test_flow_c1_for_level(db_connection, level, query, targets, expected): | ||
"""Test the `c1_for_level` flow.""" | ||
indicators = c1.c1_for_level(level, TIMESPAN, chunk_size=1000) | ||
# assert len(indicators) == db_connection.execute(text(query)).scalars().one() | ||
assert indicators.loc[indicators["target"].isin(targets), "value"].sum() == expected | ||
|
||
|
||
@pytest.mark.parametrize("chunk_size", PARAMETERS_CHUNK) | ||
def test_flow_c1_for_level_with_various_chunk_sizes(chunk_size): | ||
"""Test the `c1_for_level` flow with various chunk sizes.""" | ||
level, query, targets, expected = PARAMETERS_FLOW[2] | ||
indicators = c1.c1_for_level(level, TIMESPAN, chunk_size=chunk_size) | ||
assert indicators.loc[indicators["target"].isin(targets), "value"].sum() == expected | ||
|
||
|
||
def test_flow_c1_national(db_connection): | ||
"""Test the `c1_national` flow.""" | ||
indicators = c1.c1_national(TIMESPAN) | ||
assert indicators["value"].sum() == N_LEVEL_NATIONAL | ||
|
||
|
||
def test_flow_c1_calculate(db_connection): | ||
"""Test the `calculate` flow.""" | ||
all_levels = [ | ||
Level.NATIONAL, | ||
Level.REGION, | ||
Level.DEPARTMENT, | ||
Level.CITY, | ||
Level.EPCI, | ||
] | ||
indicators = c1.calculate( | ||
TIMESPAN, all_levels, create_artifact=True, format_pd=True | ||
) | ||
assert list(indicators["level"].unique()) == all_levels | ||
|
||
|
||
# query used to get N_LEVEL | ||
N_LEVEL_NAT = """ | ||
SELECT | ||
count(*) AS value, | ||
FROM | ||
SESSION | ||
INNER JOIN statique ON point_de_charge_id = pdc_id | ||
WHERE | ||
START >= timestamp '2024-12-24' | ||
AND START < timestamp '2024-12-25' | ||
""" | ||
N_LEVEL_3 = """ | ||
SELECT | ||
count(*) AS value | ||
FROM | ||
Session | ||
INNER JOIN statique ON point_de_charge_id = pdc_id | ||
LEFT JOIN City ON City.code = code_insee_commune | ||
INNER JOIN Department ON City.department_id = Department.id | ||
INNER JOIN Region ON Department.region_id = Region.id | ||
WHERE | ||
START >= timestamp '2024-12-24' | ||
AND START < timestamp '2024-12-25' | ||
AND region.code IN ('11', '84', '75') | ||
""" |