Skip to content

Commit

Permalink
AbstractDynodeRunner code reorg and naming revamp (#316)
Browse files Browse the repository at this point in the history
* checkpoint, moving timeline->timeseries naming

* checkpoint bugfix

* checkpoint pickling issues

* moving some AbstractDynodeRunner code out to utils finishing timeline-> timeseries rename

* merging and fixing comments and formatting

* moving single use functions back from utils to DynodeRunner

* hotfix to flatten_list_parameters not working with jax array, adding tests

* fixing mypy

* rewording module description

* comment feedback

* adding period to the end of parameters sentences

* save_inference_timeseries -> generate_inference_timeseries, generate_model_compnent_timeseries now private
  • Loading branch information
arik-shurygin authored Jan 23, 2025
1 parent 6bf6a79 commit 149fbb1
Show file tree
Hide file tree
Showing 10 changed files with 665 additions and 592 deletions.
4 changes: 2 additions & 2 deletions examples/config/config_inferer_covid.json
Original file line number Diff line number Diff line change
Expand Up @@ -123,8 +123,8 @@
"SEASONALITY_SECOND_WAVE": 0.5,
"SEASONALITY_SHIFT": 0,
"INFERENCE_PRNGKEY": 8675309,
"INFERENCE_NUM_WARMUP": 30,
"INFERENCE_NUM_SAMPLES": 30,
"INFERENCE_NUM_WARMUP": 40,
"INFERENCE_NUM_SAMPLES": 20,
"INFERENCE_NUM_CHAINS": 4,
"INFERENCE_PROGRESS_BAR": true,
"MODEL_RAND_SEED": 8675309
Expand Down
26 changes: 19 additions & 7 deletions examples/example_end_to_end_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,8 +115,11 @@ def process_state(self, state: str, **kwargs):
# those distributions in the Config are now posteriors
inferer.infer(jnp.array(synthetic_observed_hospitalizations))
print("saving a suite of inference visualizations ")
self.save_inference_timelines(
inferer, "local_inference_timeseries.csv"
# save some particle 0 and 5 from chains 0 and 1 for example
self.generate_inference_timeseries(
inferer,
particles=[(0, 0), (1, 0), (0, 5), (1, 5)],
timeseries_filename="local_inference_timeseries.csv",
)
self.save_inference_posteriors(
inferer, "local_example_inferer_posteriors.json"
Expand All @@ -128,17 +131,26 @@ def process_state(self, state: str, **kwargs):
"to increase the INFERENCE_NUM_SAMPLES and INFERENCE_NUM_WARMUP "
"parameters in the config_inferer_covid.json to see this improve. \n"
)
print(
"static values used to generate synthetic hosp data: \n"
"INFECTIOUS_PERIOD : %s \n"
"INTRODUCTION_TIMES_BA2/5 : %s \n"
"IHRS : %s"
% (
static_params.config.INFECTIOUS_PERIOD,
static_params.config.INTRODUCTION_TIMES[0],
str(ihr),
)
)
else:
# step 5: interpret the solution object in a variety of ways
save_path = "output/example_end_to_end_run.png"
self.save_static_run_timelines(
df = self.save_static_run_timeseries(
static_params, solution, "local_run_timeseries.csv"
)
df = self._generate_model_component_timelines(
static_params, solution
)
df["chain_particle"] = "na_na"
# attach a `state` column so plot cols have titles
df["state"] = "USA"
# for normalization of metrics per 100k
usa_pop = {"USA": initializer.config.POP_SIZE}
fig = vis_utils.plot_model_overview_subplot_matplotlib(df, usa_pop)
print("Please see %s for your plot!" % save_path)
Expand Down
2 changes: 1 addition & 1 deletion src/dynode/abstract_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -433,7 +433,7 @@ def seasonality(
-------
<ArrayLike | Float>
Seasonality coefficient signaling an increase (>1) or decrease (<1)
in transmission due to the impact of seasonality.\
in transmission due to the impact of seasonality.
"""
# cosine curves are defined by a cycle of 365 days begining at jan 1st
# start by shifting the curve some number of days such that we line up with our INIT_DATE
Expand Down
Loading

0 comments on commit 149fbb1

Please sign in to comment.