Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

AbstractDynodeRunner code reorg and naming revamp #316

Merged
merged 16 commits into from
Jan 23, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions examples/config/config_inferer_covid.json
Original file line number Diff line number Diff line change
Expand Up @@ -123,8 +123,8 @@
"SEASONALITY_SECOND_WAVE": 0.5,
"SEASONALITY_SHIFT": 0,
"INFERENCE_PRNGKEY": 8675309,
"INFERENCE_NUM_WARMUP": 30,
"INFERENCE_NUM_SAMPLES": 30,
"INFERENCE_NUM_WARMUP": 40,
"INFERENCE_NUM_SAMPLES": 20,
"INFERENCE_NUM_CHAINS": 4,
"INFERENCE_PROGRESS_BAR": true,
"MODEL_RAND_SEED": 8675309
Expand Down
26 changes: 19 additions & 7 deletions examples/example_end_to_end_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,8 +115,11 @@ def process_state(self, state: str, **kwargs):
# those distributions in the Config are now posteriors
inferer.infer(jnp.array(synthetic_observed_hospitalizations))
print("saving a suite of inference visualizations ")
self.save_inference_timelines(
inferer, "local_inference_timeseries.csv"
# save some particle 0 and 5 from chains 0 and 1 for example
self.generate_inference_timeseries(
inferer,
particles=[(0, 0), (1, 0), (0, 5), (1, 5)],
timeseries_filename="local_inference_timeseries.csv",
)
self.save_inference_posteriors(
inferer, "local_example_inferer_posteriors.json"
Expand All @@ -128,17 +131,26 @@ def process_state(self, state: str, **kwargs):
"to increase the INFERENCE_NUM_SAMPLES and INFERENCE_NUM_WARMUP "
"parameters in the config_inferer_covid.json to see this improve. \n"
)
print(
"static values used to generate synthetic hosp data: \n"
"INFECTIOUS_PERIOD : %s \n"
"INTRODUCTION_TIMES_BA2/5 : %s \n"
kokbent marked this conversation as resolved.
Show resolved Hide resolved
"IHRS : %s"
% (
static_params.config.INFECTIOUS_PERIOD,
static_params.config.INTRODUCTION_TIMES[0],
str(ihr),
)
)
else:
# step 5: interpret the solution object in a variety of ways
save_path = "output/example_end_to_end_run.png"
self.save_static_run_timelines(
df = self.save_static_run_timeseries(
static_params, solution, "local_run_timeseries.csv"
)
df = self._generate_model_component_timelines(
static_params, solution
)
df["chain_particle"] = "na_na"
# attach a `state` column so plot cols have titles
df["state"] = "USA"
# for normalization of metrics per 100k
usa_pop = {"USA": initializer.config.POP_SIZE}
fig = vis_utils.plot_model_overview_subplot_matplotlib(df, usa_pop)
print("Please see %s for your plot!" % save_path)
Expand Down
2 changes: 1 addition & 1 deletion src/dynode/abstract_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -433,7 +433,7 @@ def seasonality(
-------
<ArrayLike | Float>
Seasonality coefficient signaling an increase (>1) or decrease (<1)
in transmission due to the impact of seasonality.\
in transmission due to the impact of seasonality.
"""
# cosine curves are defined by a cycle of 365 days begining at jan 1st
# start by shifting the curve some number of days such that we line up with our INIT_DATE
Expand Down
Loading
Loading