Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 24 additions & 9 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -492,7 +492,8 @@ def multi_pathogen_combined_prep(location, target, scenario, other_scen, round_t
other_pathogen_data = {}
for patho in other_pathogen:
rnd_n = pathogen_set[patho]["round"]
other_pathogen_data.update({patho: {"round_number": re.findall("\d+", rnd_n)[0]}})
other_pathogen_data.update({patho: {"round_number": re.findall(r"\d+", rnd_n)[0]}})

# Data
df = query_proj_data(location, target, "sample", round_number)
df = df[df["age_group"] == "0-130"]
Expand All @@ -506,6 +507,7 @@ def multi_pathogen_combined_prep(location, target, scenario, other_scen, round_t
pathogen_data = {pathogen: {"data": df, "scenario_int": scenario_int, "time_series": flu_ts}}
time_series_date = list()
time_series_date.append(set(pathogen_data[pathogen]["time_series"]))

for patho in other_pathogen:
if other_scen[patho] is not None and len(other_scen[patho]) > 0:
df_other = query_proj_data(location, target, "sample", other_pathogen_data[patho]["round_number"],
Expand All @@ -517,14 +519,14 @@ def multi_pathogen_combined_prep(location, target, scenario, other_scen, round_t
ts = None
other_scen_int = list()
for i in other_scen[patho]:
other_scen_int.append(int(
pathogen_set[patho]["scenario"]["dict"][i]))
other_scen_int.append(int(pathogen_set[patho]["scenario"]["dict"][i]))
else:
df_other = pd.DataFrame()
ts = None
other_scen_int = list()
other_pathogen_data[patho].update({"data": df_other, "scenario_int": other_scen_int,
"time_series": ts})

time_series_date = set.intersection(*time_series_date)
time_series_date = list(filter(None, list(time_series_date)))
pathogen_info = pathogen_data | other_pathogen_data
Expand All @@ -536,12 +538,14 @@ def multi_pathogen_combined_prep(location, target, scenario, other_scen, round_t
df_patho = prep_pathogen_data(df_patho, pathogen_info[patho]["scenario_int"], patho.lower(), k=k)
pathogen_information.update({patho: {"dataframe": df_patho}})
df_all = prep_multipat_plot_comb(pathogen_information, calc_mean=True)

# Observed data
obs_other_patho = list()
for patho in other_pathogen:
if other_scen[patho] is not None:
obs_other_patho.append(pathogen_set[patho]["display_name"])
obs_data = multi_pathogen_obs_prep(round_tab, target, location, obs_other_patho, time_series_date)

# Title & Subtitle
title_pathogen = list()
title_other_pathogen = list()
Expand Down Expand Up @@ -631,7 +635,12 @@ def draw_scenario_plot(scenario, location, target, ui, age_group, ens_check, rou


@cache.memoize(timeout=TIMEOUT)
def draw_spaghetti_plot(scenario, location, target, age_group, n_sample, med_plot, round_tab):
def draw_spaghetti_plot(scenario, location, target, age_group, n_sample, med_plot, round_tab, band_depth_limit):
'''
:param band_depth_limit: if this parameter is set to a value between 0 and 1, then the plot will create
a shaded "envelope" around the trajectories with band depths greater than or equal to band_depth_limit
NOTE: this parameter should be set by the top-level function spaghetti_plot()
'''
prep_plot = spaghetti_plot_prep(scenario, location, target, age_group, n_sample, med_plot, round_tab)
if (prep_plot["df"] is None) or (len(prep_plot["df"]) == 0):
fig = fig_error_message("No projection to display for the target: " + prep_plot["y_title"] + ", location: " +
Expand All @@ -644,13 +653,13 @@ def draw_spaghetti_plot(scenario, location, target, age_group, n_sample, med_plo
color_dict=prep_plot["color_dict"], opacity=prep_plot["opacity"],
add_median=True, title=prep_plot["title"], x_title="Epiweek",
subplot_titles=prep_plot["subplot_titles"], y_title=prep_plot["y_title"],
legend_dict=constant_dict["model_name"])
legend_dict=constant_dict["model_name"], band_depth_limit=band_depth_limit)
else:
fig = make_spaghetti_plot(prep_plot["df"], subplot=True, subplot_col="scenario_id",
color_dict=prep_plot["color_dict"], opacity=prep_plot["opacity"],
title=prep_plot["title"], subplot_titles=prep_plot["subplot_titles"],
y_title=prep_plot["y_title"], x_title="Epiweek",
legend_dict=constant_dict["model_name"])
legend_dict=constant_dict["model_name"], band_depth_limit=band_depth_limit)
fig.add_annotation(
x=0, y=1, xref="paper", yref="paper", text="ⓘ", font=dict(size=32), arrowcolor="white",
arrowhead=False,
Expand Down Expand Up @@ -759,7 +768,9 @@ def draw_multi_pathogen_comb_plot(location, target, scenario, other_scen, round_
for patho in other_scen.keys():
if other_scen[patho] is not None and len(other_scen[patho]) >= 1:
other_pathogen.append(viz_setting[round_tab]["multi-pathogen_plot"]["pathogen"][patho]["display_name"])

prep_plot = multi_pathogen_combined_prep(location, target, scenario, other_scen, round_tab)

df_all = prep_plot["df_all"]
if prep_plot["df_gs_data"] is not None:
truth_data = prep_plot["df_gs_data"]
Expand Down Expand Up @@ -928,7 +939,7 @@ def render_plot_tab_content(plot_tab, round_tab):
patho_round = pathogen_information["round_display"]
else:
patho_round = pathogen_information["round"]
patho_round = int(re.sub("\D", "", patho_round))
patho_round = int(re.sub(r"\D", "", patho_round))
patho_website = pathogen_information["website"]
patho_dic = {"scenario": patho_scen_dict, "default_sel": def_scen, "round_int": patho_round,
"name": pathogen_information["display_name"], "website": patho_website}
Expand Down Expand Up @@ -1021,9 +1032,13 @@ def scenario_plot(location, target, scenario, ui, age_group, round_tab, ens_chec
Input("sample-slider", "value"),
Input("median-checkbox", "value"),
Input("tabs-round", "value"))
def spaghetti_plot(location, target, scenario, age_group, n_sample, med_plot, round_tab):
def spaghetti_plot(location, target, scenario, age_group, n_sample, med_plot, round_tab, band_depth_limit=None):
Copy link
Collaborator

@jacobrklein20 jacobrklein20 Dec 13, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Proposed documentation to add:
"""
:param location: str used to construct path to sample file to load in visualization/data-visualization
:param target: str, 'inc hosp' or 'cum hosp'. Used to construct path to sample file to load in visualization/data-visualization
:param scenario: List[int], used to filter sample file
:param age_group: str, used to filter sample file
:param n_sample: int, number of samples to display
:param med_plot: bool, if True, median trajectory displayed
:param round_tab: str, name of Round (as seen in hub-config/viz_settings.json) to be used
:param band_depth_limit: float| None. If not None, will show an envelope in resulting spaghetti plot around trajectories with a band depth above this value. Band depth is a measure of the representativeness of one trajectory among an ensemble. For more details, see https://ieeexplore.ieee.org/document/6875964 - Curve Boxplot: Generalization of Boxplot for Ensembles of Curves by Mirzargar et al.
:return: go.Figure object result
"""

'''
:param band_depth_limit: if this parameter is set to a value between 0 and 1, then the plot will create
a shaded "envelope" around the trajectories with band depths greater than or equal to band_depth_limit
'''
tic = time.perf_counter()
fig = draw_spaghetti_plot(scenario, location, target, age_group, n_sample, med_plot, round_tab)
fig = draw_spaghetti_plot(scenario, location, target, age_group, n_sample, med_plot, round_tab, band_depth_limit)
toc = time.perf_counter()
print(f"Draw Spaghetti plot in {toc - tic:0.4f} seconds")
return fig
Expand Down