From a4e5bbb08deb54993a4a1144b7898fd30b5265ad Mon Sep 17 00:00:00 2001 From: "Dylan H. Morris" Date: Mon, 6 Jan 2025 17:04:28 -0500 Subject: [PATCH 01/32] Split up main call into sample helper functions --- pyrenew_hew/pyrenew_hew_model.py | 74 ++++++++++++++++++++++++++------ 1 file changed, 60 insertions(+), 14 deletions(-) diff --git a/pyrenew_hew/pyrenew_hew_model.py b/pyrenew_hew/pyrenew_hew_model.py index 88658739..8d372ffa 100644 --- a/pyrenew_hew/pyrenew_hew_model.py +++ b/pyrenew_hew/pyrenew_hew_model.py @@ -5,6 +5,7 @@ import numpyro import numpyro.distributions as dist import pyrenew.transformation as transformation +from jax.typing import ArrayLike from pyrenew.arrayutils import repeat_until_n, tile_until_n from pyrenew.convolve import compute_delay_ascertained_incidence from pyrenew.deterministic import DeterministicVariable @@ -39,7 +40,9 @@ def __init__( ed_wday_effect_rv, inf_to_ed_rv, phi_rv, - right_truncation_pmf_rv, # when unnamed deterministic variables are allowed, we could default this to 1. + right_truncation_pmf_rv, + # when unnamed deterministic variables are allowed, + # we could default this to 1. n_initialization_points, ): # numpydoc ignore=GL08 self.infection_initialization_process = InfectionInitializationProcess( @@ -95,7 +98,8 @@ def sample( and data_observed_disease_ed_visits is None ): raise ValueError( - "Either n_observed_disease_ed_visits_datapoints or data_observed_disease_ed_visits " + "Either n_observed_disease_ed_visits_datapoints or " + "data_observed_disease_ed_visits " "must be passed." ) elif ( @@ -103,7 +107,8 @@ def sample( and data_observed_disease_ed_visits is not None ): raise ValueError( - "Cannot pass both n_observed_disease_ed_visits_datapoints and data_observed_disease_ed_visits." + "Cannot pass both n_observed_disease_ed_visits_datapoints " + "and data_observed_disease_ed_visits." ) elif n_observed_disease_ed_visits_datapoints is None: n_observed_disease_ed_visits_datapoints = len( @@ -163,6 +168,55 @@ def sample( numpyro.deterministic("rt", inf_with_feedback_proc_sample.rt) numpyro.deterministic("latent_infections", latent_infections) + sampled_ed_visits = self.sample_ed_visits( + latent_infections=latent_infections, + data_observed_disease_ed_visits=data_observed_disease_ed_visits, + n_observed_disease_ed_visits_datapoints=( + n_observed_disease_ed_visits_datapoints + ), + n_weeks_post_init=n_weeks_post_init, + right_truncation_offset=right_truncation_offset, + ) + self.observe_hospital_admissions() + self.observe_wastewater() + + return sampled_ed_visits + + def sample_hospital_admissions( + latent_infections: ArrayLike, + n_observed_hospital_admissions_datapoints: int, + ) -> ArrayLike: + """ + Observe and/or predict incident hospital admissions. + """ + if n_observed_hospital_admissions_datapoints is not None: + hospital_admissions_obs_rv = PoissonObservation( + "observed_hospital_admissions" + ) + data_observed_disease_hospital_admissions = ( + hospital_admissions_obs_rv( + mu=jnp.ones(n_observed_hospital_admissions_datapoints) + 50 + ) + ) + return data_observed_disease_hospital_admissions + + def sample_wastewater(self): + """ + Placeholder for when W component implemented. + """ + pass + + def sample_ed_visits( + self, + latent_infections, + data_observed_disease_ed_visits: ArrayLike, + n_observed_disease_ed_visits_datapoints: int, + n_weeks_post_init: int, + right_truncation_offset: ArrayLike = None, + ) -> ArrayLike: + """ + Observe and/or predict ED visit values + """ p_ed_mean = self.p_ed_mean_rv() p_ed_w_sd = self.p_ed_w_sd_rv() autoreg_p_ed = self.autoreg_p_ed_rv() @@ -188,7 +242,9 @@ def sample( transformation.SigmoidTransform()(p_ed_ar + p_ed_mean), repeats=7, )[:n_observed_disease_ed_visits_datapoints] - # this is only applied after the ed visits are generated, not to all the latent infections. This is why we cannot apply the iedr in compute_delay_ascertained_incidence + # this is only applied after the ed visits are generated, not to all + # the latent infections. This is why we cannot apply the iedr in + # compute_delay_ascertained_incidence # see https://github.com/CDCgov/ww-inference-model/issues/43 numpyro.deterministic("iedr", iedr) @@ -239,16 +295,6 @@ def sample( obs=data_observed_disease_ed_visits, ) - if n_observed_hospital_admissions_datapoints is not None: - hospital_admissions_obs_rv = PoissonObservation( - "observed_hospital_admissions" - ) - data_observed_disease_hospital_admissions = ( - hospital_admissions_obs_rv( - mu=jnp.ones(n_observed_hospital_admissions_datapoints) + 50 - ) - ) - return observed_ed_visits From ff2d629fd131095162b13af9d03b0adbfd632632 Mon Sep 17 00:00:00 2001 From: "Dylan H. Morris" Date: Mon, 6 Jan 2025 17:05:55 -0500 Subject: [PATCH 02/32] Correct names in function calls --- pyrenew_hew/pyrenew_hew_model.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyrenew_hew/pyrenew_hew_model.py b/pyrenew_hew/pyrenew_hew_model.py index 8d372ffa..d1675b15 100644 --- a/pyrenew_hew/pyrenew_hew_model.py +++ b/pyrenew_hew/pyrenew_hew_model.py @@ -177,8 +177,8 @@ def sample( n_weeks_post_init=n_weeks_post_init, right_truncation_offset=right_truncation_offset, ) - self.observe_hospital_admissions() - self.observe_wastewater() + self.sample_hospital_admissions() + self.sample_wastewater() return sampled_ed_visits From ac306ff294158c4b10c129564d920b1209ddba65 Mon Sep 17 00:00:00 2001 From: "Dylan H. Morris" Date: Mon, 6 Jan 2025 17:07:27 -0500 Subject: [PATCH 03/32] Fix missing arg --- pyrenew_hew/pyrenew_hew_model.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/pyrenew_hew/pyrenew_hew_model.py b/pyrenew_hew/pyrenew_hew_model.py index d1675b15..32954cc0 100644 --- a/pyrenew_hew/pyrenew_hew_model.py +++ b/pyrenew_hew/pyrenew_hew_model.py @@ -177,7 +177,12 @@ def sample( n_weeks_post_init=n_weeks_post_init, right_truncation_offset=right_truncation_offset, ) - self.sample_hospital_admissions() + self.sample_hospital_admissions( + latent_infections=latent_infections, + n_observed_hospital_admissions_datapoints=( + n_observed_hospital_admissions_datapoints + ), + ) self.sample_wastewater() return sampled_ed_visits From 37b3357ed4998877045af7d49e7c471b6023cfad Mon Sep 17 00:00:00 2001 From: "Dylan H. Morris" Date: Mon, 6 Jan 2025 17:08:36 -0500 Subject: [PATCH 04/32] fix missing self in fn defn --- pyrenew_hew/pyrenew_hew_model.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pyrenew_hew/pyrenew_hew_model.py b/pyrenew_hew/pyrenew_hew_model.py index 32954cc0..ded2f0cb 100644 --- a/pyrenew_hew/pyrenew_hew_model.py +++ b/pyrenew_hew/pyrenew_hew_model.py @@ -188,6 +188,7 @@ def sample( return sampled_ed_visits def sample_hospital_admissions( + self, latent_infections: ArrayLike, n_observed_hospital_admissions_datapoints: int, ) -> ArrayLike: From a52719f7b0e004bee1d37ec2ecc36e8f315a91f5 Mon Sep 17 00:00:00 2001 From: "Dylan H. Morris" Date: Mon, 6 Jan 2025 17:18:54 -0500 Subject: [PATCH 05/32] Clean up admissions placeholder --- pyrenew_hew/pyrenew_hew_model.py | 32 +++++++++++++++++--------------- 1 file changed, 17 insertions(+), 15 deletions(-) diff --git a/pyrenew_hew/pyrenew_hew_model.py b/pyrenew_hew/pyrenew_hew_model.py index ded2f0cb..9b429dde 100644 --- a/pyrenew_hew/pyrenew_hew_model.py +++ b/pyrenew_hew/pyrenew_hew_model.py @@ -92,7 +92,7 @@ def sample( data_observed_disease_ed_visits=None, data_observed_disease_hospital_admissions=None, right_truncation_offset=None, - ): # numpydoc ignore=GL08 + ) -> ArrayLike: # numpydoc ignore=GL08 if ( n_observed_disease_ed_visits_datapoints is None and data_observed_disease_ed_visits is None @@ -170,7 +170,7 @@ def sample( sampled_ed_visits = self.sample_ed_visits( latent_infections=latent_infections, - data_observed_disease_ed_visits=data_observed_disease_ed_visits, + data_observed_disease_ed_visits=(data_observed_disease_ed_visits), n_observed_disease_ed_visits_datapoints=( n_observed_disease_ed_visits_datapoints ), @@ -179,8 +179,9 @@ def sample( ) self.sample_hospital_admissions( latent_infections=latent_infections, - n_observed_hospital_admissions_datapoints=( - n_observed_hospital_admissions_datapoints + n_admissions_sampled=(n_observed_hospital_admissions_datapoints), + data_observed_disease_hospital_admissions=( + data_observed_disease_hospital_admissions ), ) self.sample_wastewater() @@ -190,21 +191,22 @@ def sample( def sample_hospital_admissions( self, latent_infections: ArrayLike, - n_observed_hospital_admissions_datapoints: int, + n_admissions_sampled: int, + data_observed_disease_hospital_admissions: ArrayLike | None = None, ) -> ArrayLike: """ Observe and/or predict incident hospital admissions. """ - if n_observed_hospital_admissions_datapoints is not None: - hospital_admissions_obs_rv = PoissonObservation( - "observed_hospital_admissions" - ) - data_observed_disease_hospital_admissions = ( - hospital_admissions_obs_rv( - mu=jnp.ones(n_observed_hospital_admissions_datapoints) + 50 - ) - ) - return data_observed_disease_hospital_admissions + hospital_admissions_obs_rv = PoissonObservation( + "observed_hospital_admissions" + ) + # placeholder mean + predicted_admissions = jnp.ones(n_admissions_sampled) + sampled_admissions = hospital_admissions_obs_rv( + mu=predicted_admissions, + obs=data_observed_disease_hospital_admissions, + ) + return sampled_admissions def sample_wastewater(self): """ From ca882fc7480a60fea0f8d7393d641615bb7ef4a7 Mon Sep 17 00:00:00 2001 From: "Dylan H. Morris" Date: Mon, 6 Jan 2025 17:25:59 -0500 Subject: [PATCH 06/32] Add some conditionals to the observation process --- pyrenew_hew/pyrenew_hew_model.py | 60 +++++++++++++++++++++++--------- 1 file changed, 43 insertions(+), 17 deletions(-) diff --git a/pyrenew_hew/pyrenew_hew_model.py b/pyrenew_hew/pyrenew_hew_model.py index 9b429dde..fc6f00c7 100644 --- a/pyrenew_hew/pyrenew_hew_model.py +++ b/pyrenew_hew/pyrenew_hew_model.py @@ -122,6 +122,8 @@ def sample( n_observed_hospital_admissions_datapoints = len( data_observed_disease_hospital_admissions ) + elif n_observed_hospital_admissions_datapoints is None: + n_observed_hospital_admissions_datapoints = 0 n_weeks_post_init = n_observed_disease_ed_visits_datapoints // 7 + 1 i0 = self.infection_initialization_process() @@ -168,25 +170,49 @@ def sample( numpyro.deterministic("rt", inf_with_feedback_proc_sample.rt) numpyro.deterministic("latent_infections", latent_infections) - sampled_ed_visits = self.sample_ed_visits( - latent_infections=latent_infections, - data_observed_disease_ed_visits=(data_observed_disease_ed_visits), - n_observed_disease_ed_visits_datapoints=( - n_observed_disease_ed_visits_datapoints - ), - n_weeks_post_init=n_weeks_post_init, - right_truncation_offset=right_truncation_offset, - ) - self.sample_hospital_admissions( - latent_infections=latent_infections, - n_admissions_sampled=(n_observed_hospital_admissions_datapoints), - data_observed_disease_hospital_admissions=( - data_observed_disease_hospital_admissions - ), + # observation proceses + + sample_ed_visits = True + sample_admissions = n_observed_hospital_admissions_datapoints > 0 + sample_wastewater = False + + sampled_ed_visits, sampled_admissions, sampled_wastewater = ( + None, + None, + None, ) - self.sample_wastewater() - return sampled_ed_visits + if sample_ed_visits: + sampled_ed_visits = self.sample_ed_visits( + latent_infections=latent_infections, + data_observed_disease_ed_visits=( + data_observed_disease_ed_visits + ), + n_observed_disease_ed_visits_datapoints=( + n_observed_disease_ed_visits_datapoints + ), + n_weeks_post_init=n_weeks_post_init, + right_truncation_offset=right_truncation_offset, + ) + + if sample_admissions: + self.sample_hospital_admissions( + latent_infections=latent_infections, + n_admissions_sampled=( + n_observed_hospital_admissions_datapoints + ), + data_observed_disease_hospital_admissions=( + data_observed_disease_hospital_admissions + ), + ) + if sample_wastewater: + self.sample_wastewater() + + return { + "ed_visits": sampled_ed_visits, + "admissions": sampled_admissions, + "wasewater": sampled_wastewater, + } def sample_hospital_admissions( self, From a266f2f7c0a1833db64a5b5978311fc1bb33112c Mon Sep 17 00:00:00 2001 From: "Dylan H. Morris" Date: Mon, 6 Jan 2025 17:30:24 -0500 Subject: [PATCH 07/32] Add assignments of other sampled quantities --- pyrenew_hew/pyrenew_hew_model.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyrenew_hew/pyrenew_hew_model.py b/pyrenew_hew/pyrenew_hew_model.py index fc6f00c7..5affccfa 100644 --- a/pyrenew_hew/pyrenew_hew_model.py +++ b/pyrenew_hew/pyrenew_hew_model.py @@ -196,7 +196,7 @@ def sample( ) if sample_admissions: - self.sample_hospital_admissions( + sampled_admissions = self.sample_hospital_admissions( latent_infections=latent_infections, n_admissions_sampled=( n_observed_hospital_admissions_datapoints @@ -206,7 +206,7 @@ def sample( ), ) if sample_wastewater: - self.sample_wastewater() + sampled_wastewater = self.sample_wastewater() return { "ed_visits": sampled_ed_visits, From 559d29f7772659a68a5d8ebce22fc2038bc145f8 Mon Sep 17 00:00:00 2001 From: "Dylan H. Morris" Date: Tue, 7 Jan 2025 11:30:13 -0500 Subject: [PATCH 08/32] Modularize model in prep for real admissions data --- pipelines/build_pyrenew_model.py | 46 ++- pipelines/priors/prod_priors.py | 9 +- pipelines/tests/test_output/priors.py | 9 +- pyrenew_hew/pyrenew_hew_model.py | 426 ++++++++++++++++---------- 4 files changed, 322 insertions(+), 168 deletions(-) diff --git a/pipelines/build_pyrenew_model.py b/pipelines/build_pyrenew_model.py index 115c031c..f319d2e9 100644 --- a/pipelines/build_pyrenew_model.py +++ b/pipelines/build_pyrenew_model.py @@ -4,7 +4,13 @@ import jax.numpy as jnp from pyrenew.deterministic import DeterministicVariable -from pyrenew_hew.pyrenew_hew_model import pyrenew_hew_model +from pyrenew_hew.pyrenew_hew_model import ( + EDVisitObservationProcess, + HospAdmitObservationProcess, + LatentInfectionProcess, + PyrenewHEWModel, + WastewaterObservationProcess, +) def build_model_from_dir(model_dir): @@ -21,6 +27,11 @@ def build_model_from_dir(model_dir): "inf_to_ed", jnp.array(model_data["inf_to_ed_pmf"]) ) # check if off by 1 or reversed + # use same as inf to ed, per NNH guidelines + inf_to_hosp_admit_rv = DeterministicVariable( + "inf_to_hosp_admit", jnp.array(model_data["inf_to_ed_pmf"]) + ) # check if off by 1 or reversed + generation_interval_pmf_rv = DeterministicVariable( "generation_interval_pmf", jnp.array(model_data["generation_interval_pmf"]), @@ -37,9 +48,9 @@ def build_model_from_dir(model_dir): data_observed_disease_hospital_admissions = jnp.array( model_data["data_observed_disease_hospital_admissions"] ) - state_pop = jnp.array(model_data["state_pop"]) + population_size = jnp.array(model_data["state_pop"]) - right_truncation_pmf_rv = DeterministicVariable( + ed_right_truncation_pmf_rv = DeterministicVariable( "right_truncation_pmf", jnp.array(model_data["right_truncation_pmf"]) ) @@ -55,8 +66,7 @@ def build_model_from_dir(model_dir): right_truncation_offset = model_data["right_truncation_offset"] - my_model = pyrenew_hew_model( - state_pop=state_pop, + my_latent_infection_model = LatentInfectionProcess( i0_first_obs_n_rv=priors["i0_first_obs_n_rv"], initialization_rate_rv=priors["initialization_rate_rv"], log_r_mu_intercept_rv=priors["log_r_mu_intercept_rv"], @@ -65,14 +75,34 @@ def build_model_from_dir(model_dir): generation_interval_pmf_rv=generation_interval_pmf_rv, infection_feedback_strength_rv=priors["inf_feedback_strength_rv"], infection_feedback_pmf_rv=infection_feedback_pmf_rv, + n_initialization_points=uot, + ) + + my_ed_visit_obs_model = EDVisitObservationProcess( p_ed_mean_rv=priors["p_ed_visit_mean_rv"], p_ed_w_sd_rv=priors["p_ed_visit_w_sd_rv"], autoreg_p_ed_rv=priors["autoreg_p_ed_visit_rv"], ed_wday_effect_rv=priors["ed_visit_wday_effect_rv"], inf_to_ed_rv=inf_to_ed_rv, - phi_rv=priors["phi_rv"], - right_truncation_pmf_rv=right_truncation_pmf_rv, - n_initialization_points=uot, + ed_neg_bin_concentration_rv=(priors["ed_neg_bin_concentration_rv"]), + ed_right_truncation_pmf_rv=ed_right_truncation_pmf_rv, + ) + + my_hosp_admit_obs_model = HospAdmitObservationProcess( + inf_to_hosp_admit_rv=inf_to_hosp_admit_rv, + hosp_admit_neg_bin_concentration_rv=( + priors["hosp_admit_neg_bin_concentration_rv"] + ), + ) + + my_wastewater_obs_model = WastewaterObservationProcess() + + my_model = PyrenewHEWModel( + population_size=population_size, + latent_infection_process_rv=my_latent_infection_model, + ed_visit_obs_process_rv=my_ed_visit_obs_model, + hosp_admit_obs_process_rv=my_hosp_admit_obs_model, + wastewater_obs_process_rv=my_wastewater_obs_model, ) return ( diff --git a/pipelines/priors/prod_priors.py b/pipelines/priors/prod_priors.py index 05c4a9f4..8625c20a 100644 --- a/pipelines/priors/prod_priors.py +++ b/pipelines/priors/prod_priors.py @@ -65,4 +65,11 @@ ) # Based on looking at some historical posteriors. -phi_rv = DistributionalVariable("phi", dist.LogNormal(4, 1)) +ed_neg_bin_concentration_rv = DistributionalVariable( + "ed_visit_neg_bin_concentration", dist.LogNormal(4, 1) +) + +# more diffuse than ED visit, same mean +hosp_admit_neg_bin_concentration_rv = DistributionalVariable( + "hosp_admit_neg_bin_concentration", dist.LogNormal(4, 2) +) diff --git a/pipelines/tests/test_output/priors.py b/pipelines/tests/test_output/priors.py index 4f9d61ab..44b0e286 100644 --- a/pipelines/tests/test_output/priors.py +++ b/pipelines/tests/test_output/priors.py @@ -65,4 +65,11 @@ ) # Based on looking at some historical posteriors. -phi_rv = DistributionalVariable("phi", dist.LogNormal(6, 1)) +ed_neg_bin_concentration_rv = DistributionalVariable( + "ed_visit_neg_bin_concentration", dist.LogNormal(4, 1) +) + +# more diffuse than ED visit, same mean +hosp_admit_neg_bin_concentration_rv = DistributionalVariable( + "hosp_admit_neg_bin_concentration", dist.LogNormal(4, 2) +) diff --git a/pyrenew_hew/pyrenew_hew_model.py b/pyrenew_hew/pyrenew_hew_model.py index 5affccfa..b77f607e 100644 --- a/pyrenew_hew/pyrenew_hew_model.py +++ b/pyrenew_hew/pyrenew_hew_model.py @@ -14,7 +14,7 @@ InfectionsWithFeedback, InitializeInfectionsExponentialGrowth, ) -from pyrenew.metaclass import Model +from pyrenew.metaclass import Model, RandomVariable from pyrenew.observation import NegativeBinomialObservation, PoissonObservation from pyrenew.process import ARProcess, DifferencedProcess from pyrenew.randomvariable import DistributionalVariable, TransformedVariable @@ -22,29 +22,19 @@ from pyrenew_hew.utils import convert_to_logmean_log_sd -class pyrenew_hew_model(Model): # numpydoc ignore=GL08 +class LatentInfectionProcess(RandomVariable): def __init__( self, - state_pop, - i0_first_obs_n_rv, - initialization_rate_rv, - log_r_mu_intercept_rv, - autoreg_rt_rv, # ar process - eta_sd_rv, # sd of random walk for ar process - generation_interval_pmf_rv, - infection_feedback_strength_rv, - infection_feedback_pmf_rv, - p_ed_mean_rv, - p_ed_w_sd_rv, - autoreg_p_ed_rv, - ed_wday_effect_rv, - inf_to_ed_rv, - phi_rv, - right_truncation_pmf_rv, - # when unnamed deterministic variables are allowed, - # we could default this to 1. - n_initialization_points, - ): # numpydoc ignore=GL08 + i0_first_obs_n_rv: RandomVariable, + initialization_rate_rv: RandomVariable, + log_r_mu_intercept_rv: RandomVariable, + autoreg_rt_rv: RandomVariable, # ar coefficient of AR(1) process on R'(t) + eta_sd_rv: RandomVariable, # sd of random walk for ar process on R'(t) + generation_interval_pmf_rv: RandomVariable, + infection_feedback_strength_rv: RandomVariable, + infection_feedback_pmf_rv: RandomVariable, + n_initialization_points: int, + ) -> None: self.infection_initialization_process = InfectionInitializationProcess( "I0_initialization", i0_first_obs_n_rv, @@ -58,74 +48,24 @@ def __init__( infection_feedback_pmf=infection_feedback_pmf_rv, ) - self.p_ed_ar_proc = ARProcess() self.ar_diff = DifferencedProcess( fundamental_process=ARProcess(), differencing_order=1, ) - self.right_truncation_cdf_rv = TransformedVariable( - "right_truncation_cdf", right_truncation_pmf_rv, jnp.cumsum - ) self.autoreg_rt_rv = autoreg_rt_rv self.eta_sd_rv = eta_sd_rv self.log_r_mu_intercept_rv = log_r_mu_intercept_rv self.generation_interval_pmf_rv = generation_interval_pmf_rv self.infection_feedback_pmf_rv = infection_feedback_pmf_rv - self.p_ed_mean_rv = p_ed_mean_rv - self.p_ed_w_sd_rv = p_ed_w_sd_rv - self.autoreg_p_ed_rv = autoreg_p_ed_rv - self.ed_wday_effect_rv = ed_wday_effect_rv - self.inf_to_ed_rv = inf_to_ed_rv - self.phi_rv = phi_rv - self.state_pop = state_pop - self.n_initialization_points = n_initialization_points - return None - def validate(self): # numpydoc ignore=GL08 - return None - - def sample( - self, - n_observed_disease_ed_visits_datapoints=None, - n_observed_hospital_admissions_datapoints=None, - data_observed_disease_ed_visits=None, - data_observed_disease_hospital_admissions=None, - right_truncation_offset=None, - ) -> ArrayLike: # numpydoc ignore=GL08 - if ( - n_observed_disease_ed_visits_datapoints is None - and data_observed_disease_ed_visits is None - ): - raise ValueError( - "Either n_observed_disease_ed_visits_datapoints or " - "data_observed_disease_ed_visits " - "must be passed." - ) - elif ( - n_observed_disease_ed_visits_datapoints is not None - and data_observed_disease_ed_visits is not None - ): - raise ValueError( - "Cannot pass both n_observed_disease_ed_visits_datapoints " - "and data_observed_disease_ed_visits." - ) - elif n_observed_disease_ed_visits_datapoints is None: - n_observed_disease_ed_visits_datapoints = len( - data_observed_disease_ed_visits - ) - - if ( - n_observed_hospital_admissions_datapoints is None - and data_observed_disease_hospital_admissions is not None - ): - n_observed_hospital_admissions_datapoints = len( - data_observed_disease_hospital_admissions - ) - elif n_observed_hospital_admissions_datapoints is None: - n_observed_hospital_admissions_datapoints = 0 + def validate(self): + pass - n_weeks_post_init = n_observed_disease_ed_visits_datapoints // 7 + 1 + def sample(self, n_days_post_init: int, n_weeks_post_init: int): + """ + Sample latent infections. + """ i0 = self.infection_initialization_process() eta_sd = self.eta_sd_rv() @@ -147,7 +87,7 @@ def sample( rtu = repeat_until_n( data=jnp.exp(log_rtu_weekly), - n_timepoints=n_observed_disease_ed_visits_datapoints, + n_timepoints=n_days_post_init, offset=0, period_size=7, ) @@ -170,79 +110,38 @@ def sample( numpyro.deterministic("rt", inf_with_feedback_proc_sample.rt) numpyro.deterministic("latent_infections", latent_infections) - # observation proceses - - sample_ed_visits = True - sample_admissions = n_observed_hospital_admissions_datapoints > 0 - sample_wastewater = False - - sampled_ed_visits, sampled_admissions, sampled_wastewater = ( - None, - None, - None, - ) - - if sample_ed_visits: - sampled_ed_visits = self.sample_ed_visits( - latent_infections=latent_infections, - data_observed_disease_ed_visits=( - data_observed_disease_ed_visits - ), - n_observed_disease_ed_visits_datapoints=( - n_observed_disease_ed_visits_datapoints - ), - n_weeks_post_init=n_weeks_post_init, - right_truncation_offset=right_truncation_offset, - ) + return latent_infections - if sample_admissions: - sampled_admissions = self.sample_hospital_admissions( - latent_infections=latent_infections, - n_admissions_sampled=( - n_observed_hospital_admissions_datapoints - ), - data_observed_disease_hospital_admissions=( - data_observed_disease_hospital_admissions - ), - ) - if sample_wastewater: - sampled_wastewater = self.sample_wastewater() - - return { - "ed_visits": sampled_ed_visits, - "admissions": sampled_admissions, - "wasewater": sampled_wastewater, - } - def sample_hospital_admissions( +class EDVisitObservationProcess(RandomVariable): + def __init__( self, - latent_infections: ArrayLike, - n_admissions_sampled: int, - data_observed_disease_hospital_admissions: ArrayLike | None = None, - ) -> ArrayLike: - """ - Observe and/or predict incident hospital admissions. - """ - hospital_admissions_obs_rv = PoissonObservation( - "observed_hospital_admissions" - ) - # placeholder mean - predicted_admissions = jnp.ones(n_admissions_sampled) - sampled_admissions = hospital_admissions_obs_rv( - mu=predicted_admissions, - obs=data_observed_disease_hospital_admissions, + p_ed_mean_rv: RandomVariable, + p_ed_w_sd_rv: RandomVariable, + autoreg_p_ed_rv: RandomVariable, + ed_wday_effect_rv: RandomVariable, + inf_to_ed_rv: RandomVariable, + ed_neg_bin_concentration_rv: RandomVariable, + ed_right_truncation_pmf_rv: RandomVariable, + ) -> None: + self.p_ed_ar_proc = ARProcess() + self.p_ed_mean_rv = p_ed_mean_rv + self.p_ed_w_sd_rv = p_ed_w_sd_rv + self.autoreg_p_ed_rv = autoreg_p_ed_rv + self.ed_wday_effect_rv = ed_wday_effect_rv + self.inf_to_ed_rv = inf_to_ed_rv + self.ed_right_truncation_cdf_rv = TransformedVariable( + "ed_right_truncation_cdf", ed_right_truncation_pmf_rv, jnp.cumsum ) - return sampled_admissions + self.ed_neg_bin_concentration_rv = ed_neg_bin_concentration_rv - def sample_wastewater(self): - """ - Placeholder for when W component implemented. - """ + def validate(self): pass - def sample_ed_visits( + def sample( self, - latent_infections, + latent_infections: ArrayLike, + population_size: int, data_observed_disease_ed_visits: ArrayLike, n_observed_disease_ed_visits_datapoints: int, n_weeks_post_init: int, @@ -297,12 +196,15 @@ def sample_ed_visits( )[-n_observed_disease_ed_visits_datapoints:] latent_ed_visits_final = ( - potential_latent_ed_visits * iedr * ed_wday_effect * self.state_pop + potential_latent_ed_visits + * iedr + * ed_wday_effect + * population_size ) if right_truncation_offset is not None: prop_already_reported_tail = jnp.flip( - self.right_truncation_cdf_rv()[right_truncation_offset:] + self.ed_right_truncation_cdf_rv()[right_truncation_offset:] ) n_points_to_prepend = ( n_observed_disease_ed_visits_datapoints @@ -321,7 +223,8 @@ def sample_ed_visits( latent_ed_visits_now = latent_ed_visits_final ed_visit_obs_rv = NegativeBinomialObservation( - "observed_ed_visits", concentration_rv=self.phi_rv + "observed_ed_visits", + concentration_rv=self.ed_neg_bin_concentration_rv, ) observed_ed_visits = ed_visit_obs_rv( @@ -332,6 +235,176 @@ def sample_ed_visits( return observed_ed_visits +class HospAdmitObservationProcess(RandomVariable): + def __init__( + self, + inf_to_hosp_admit_rv: RandomVariable, + hosp_admit_neg_bin_concentration_rv: RandomVariable, + ): + self.inf_to_hosp_admit_rv = inf_to_hosp_admit_rv + self.hosp_admit_neg_bin_concentration_rv = ( + hosp_admit_neg_bin_concentration_rv + ) + + def validate(self): + pass + + def sample( + self, + latent_infections: ArrayLike, + population_size: int, + n_observed_hospital_admissions_datapoints: int, + data_observed_disease_hospital_admissions: ArrayLike | None = None, + ) -> ArrayLike: + """ + Observe and/or predict incident hospital admissions. + """ + inf_to_hosp_admit = self.inf_to_hosp_admit_rv() + + latent_hospital_admissions = compute_delay_ascertained_incidence( + p_observed_given_incident=1, + latent_incidence=latent_infections, + delay_incidence_to_observation_pmf=inf_to_hosp_admit, + )[-n_observed_hospital_admissions_datapoints:] + + hospital_admissions_obs_rv = NegativeBinomialObservation( + "observed_hospital_admissions", + concentration_rv=self.hosp_admit_neg_bin_concentration_rv, + ) + predicted_admissions = latent_hospital_admissions + sampled_admissions = hospital_admissions_obs_rv( + mu=predicted_admissions, + obs=data_observed_disease_hospital_admissions, + ) + return sampled_admissions + + +class WastewaterObservationProcess(RandomVariable): + """ + Placeholder for wasteater obs process + """ + + def __init__(self) -> None: + pass + + def sample(self): + pass + + def validate(self): + pass + + +class PyrenewHEWModel(Model): # numpydoc ignore=GL08 + def __init__( + self, + population_size: int, + latent_infection_process_rv: LatentInfectionProcess, + ed_visit_obs_process_rv: EDVisitObservationProcess, + hosp_admit_obs_process_rv: HospAdmitObservationProcess, + wastewater_obs_process_rv: WastewaterObservationProcess, + ) -> None: # numpydoc ignore=GL08 + self.population_size = population_size + self.latent_infection_process_rv = latent_infection_process_rv + self.ed_visit_obs_process_rv = ed_visit_obs_process_rv + self.hosp_admit_obs_process_rv = hosp_admit_obs_process_rv + return None + + def validate(self): # numpydoc ignore=GL08 + return None + + def sample( + self, + n_observed_disease_ed_visits_datapoints=None, + n_observed_hospital_admissions_datapoints=None, + data_observed_disease_ed_visits=None, + data_observed_disease_hospital_admissions=None, + right_truncation_offset=None, + ) -> ArrayLike: # numpydoc ignore=GL08 + if ( + n_observed_disease_ed_visits_datapoints is None + and data_observed_disease_ed_visits is None + ): + raise ValueError( + "Either n_observed_disease_ed_visits_datapoints or " + "data_observed_disease_ed_visits " + "must be passed." + ) + elif ( + n_observed_disease_ed_visits_datapoints is not None + and data_observed_disease_ed_visits is not None + ): + raise ValueError( + "Cannot pass both n_observed_disease_ed_visits_datapoints " + "and data_observed_disease_ed_visits." + ) + elif n_observed_disease_ed_visits_datapoints is None: + n_observed_disease_ed_visits_datapoints = len( + data_observed_disease_ed_visits + ) + + if ( + n_observed_hospital_admissions_datapoints is None + and data_observed_disease_hospital_admissions is not None + ): + n_observed_hospital_admissions_datapoints = len( + data_observed_disease_hospital_admissions + ) + elif n_observed_hospital_admissions_datapoints is None: + n_observed_hospital_admissions_datapoints = 0 + + n_weeks_post_init = n_observed_disease_ed_visits_datapoints // 7 + 1 + n_days_post_init = n_observed_disease_ed_visits_datapoints + + latent_infections = self.latent_infection_process_rv( + n_days_post_init=n_days_post_init, + n_weeks_post_init=n_weeks_post_init, + ) + + sample_ed_visits = True + sample_admissions = n_observed_hospital_admissions_datapoints > 0 + sample_wastewater = False + + sampled_ed_visits, sampled_admissions, sampled_wastewater = ( + None, + None, + None, + ) + + if sample_ed_visits: + sampled_ed_visits = self.ed_visit_obs_process_rv( + latent_infections=latent_infections, + population_size=self.population_size, + data_observed_disease_ed_visits=( + data_observed_disease_ed_visits + ), + n_observed_disease_ed_visits_datapoints=( + n_observed_disease_ed_visits_datapoints + ), + n_weeks_post_init=n_weeks_post_init, + right_truncation_offset=right_truncation_offset, + ) + + if sample_admissions: + sampled_admissions = self.hosp_admit_obs_process_rv( + latent_infections=latent_infections, + population_size=self.population_size, + n_observed_hospital_admissions_datapoints=( + n_observed_hospital_admissions_datapoints + ), + data_observed_disease_hospital_admissions=( + data_observed_disease_hospital_admissions + ), + ) + if sample_wastewater: + sampled_wastewater = self.wastewater_obs_process_rv() + + return { + "ed_visits": sampled_ed_visits, + "admissions": sampled_admissions, + "wasewater": sampled_wastewater, + } + + def create_pyrenew_hew_model_from_stan_data(stan_data_file): with open( stan_data_file, @@ -444,10 +517,27 @@ def create_pyrenew_hew_model_from_stan_data(stan_data_file): inv_sqrt_phi_prior_mean = stan_data["inv_sqrt_phi_prior_mean"] inv_sqrt_phi_prior_sd = stan_data["inv_sqrt_phi_prior_sd"] - phi_rv = TransformedVariable( - "phi", + ed_neg_bin_concentration_rv = TransformedVariable( + "ed_visit_neg_bin_concentration", + DistributionalVariable( + "inv_sqrt_ed_visit_neg_bin_conc", + dist.TruncatedNormal( + loc=inv_sqrt_phi_prior_mean, + scale=inv_sqrt_phi_prior_sd, + low=1 / jnp.sqrt(5000), + ), + ), + transforms=transformation.PowerTransform(-2), + ) + + inf_to_hosp_admit_rv = DeterministicVariable( + "inf_to_hosp_admit", jnp.array(stan_data["inf_to_hosp"]) + ) + + hosp_admit_neg_bin_concentration_rv = TransformedVariable( + "hosp_admit_neg_bin_concentration", DistributionalVariable( - "inv_sqrt_phi", + "inv_sqrt_hosp_admit_neg_bin_conc", dist.TruncatedNormal( loc=inv_sqrt_phi_prior_mean, scale=inv_sqrt_phi_prior_sd, @@ -459,14 +549,14 @@ def create_pyrenew_hew_model_from_stan_data(stan_data_file): uot = stan_data["uot"] uot = len(jnp.array(stan_data["inf_to_hosp"])) - state_pop = stan_data["state_pop"] + population_size = stan_data["state_pop"] data_observed_disease_ed_visits = jnp.array(stan_data["hosp"]) - right_truncation_pmf_rv = DeterministicVariable( - "right_truncation_pmf", jnp.array(1) + ed_right_truncation_pmf_rv = DeterministicVariable( + "ed_visit_right_truncation_pmf", jnp.array(1) ) - my_model = pyrenew_hew_model( - state_pop=state_pop, + + my_latent_infection_model = LatentInfectionProcess( i0_first_obs_n_rv=i0_first_obs_n_rv, initialization_rate_rv=initialization_rate_rv, log_r_mu_intercept_rv=log_r_mu_intercept_rv, @@ -475,14 +565,34 @@ def create_pyrenew_hew_model_from_stan_data(stan_data_file): generation_interval_pmf_rv=generation_interval_pmf_rv, infection_feedback_pmf_rv=infection_feedback_pmf_rv, infection_feedback_strength_rv=inf_feedback_strength_rv, + n_initialization_points=uot, + ) + + my_ed_visit_obs_model = EDVisitObservationProcess( p_ed_mean_rv=p_ed_mean_rv, p_ed_w_sd_rv=p_ed_w_sd_rv, autoreg_p_ed_rv=autoreg_p_ed_rv, ed_wday_effect_rv=ed_wday_effect_rv, inf_to_ed_rv=inf_to_ed_rv, - phi_rv=phi_rv, - right_truncation_pmf_rv=right_truncation_pmf_rv, - n_initialization_points=uot, + ed_neg_bin_concentration_rv=ed_neg_bin_concentration_rv, + ed_right_truncation_pmf_rv=ed_right_truncation_pmf_rv, + ) + + my_hosp_admit_obs_model = HospAdmitObservationProcess( + inf_to_hosp_admit_rv=inf_to_hosp_admit_rv, + hosp_admit_neg_bin_concentration_rv=( + hosp_admit_neg_bin_concentration_rv + ), + ) + + my_wastewater_obs_model = WastewaterObservationProcess() + + my_model = PyrenewHEWModel( + population_size=population_size, + latent_infection_process_rv=my_latent_infection_model, + ed_visit_obs_process_rv=my_ed_visit_obs_model, + hosp_admit_obs_process_rv=my_hosp_admit_obs_model, + wastewater_obs_process_rv=my_wastewater_obs_model, ) return my_model, data_observed_disease_ed_visits From 32ba7d08cca0024ae779de76f34d916072d2f554 Mon Sep 17 00:00:00 2001 From: "Dylan H. Morris" Date: Tue, 7 Jan 2025 12:04:29 -0500 Subject: [PATCH 09/32] Checkpoint commit --- pyrenew_hew/pyrenew_hew_model.py | 22 ++++++++++++++++++---- 1 file changed, 18 insertions(+), 4 deletions(-) diff --git a/pyrenew_hew/pyrenew_hew_model.py b/pyrenew_hew/pyrenew_hew_model.py index b77f607e..ea30355b 100644 --- a/pyrenew_hew/pyrenew_hew_model.py +++ b/pyrenew_hew/pyrenew_hew_model.py @@ -7,7 +7,10 @@ import pyrenew.transformation as transformation from jax.typing import ArrayLike from pyrenew.arrayutils import repeat_until_n, tile_until_n -from pyrenew.convolve import compute_delay_ascertained_incidence +from pyrenew.convolve import ( + compute_delay_ascertained_incidence, + daily_to_weekly, +) from pyrenew.deterministic import DeterministicVariable from pyrenew.latent import ( InfectionInitializationProcess, @@ -15,9 +18,10 @@ InitializeInfectionsExponentialGrowth, ) from pyrenew.metaclass import Model, RandomVariable -from pyrenew.observation import NegativeBinomialObservation, PoissonObservation +from pyrenew.observation import NegativeBinomialObservation from pyrenew.process import ARProcess, DifferencedProcess from pyrenew.randomvariable import DistributionalVariable, TransformedVariable +from pyrenew.utils import daily_t from pyrenew_hew.utils import convert_to_logmean_log_sd @@ -252,6 +256,7 @@ def validate(self): def sample( self, latent_infections: ArrayLike, + first_latent_infection_dow: int, population_size: int, n_observed_hospital_admissions_datapoints: int, data_observed_disease_hospital_admissions: ArrayLike | None = None, @@ -265,15 +270,24 @@ def sample( p_observed_given_incident=1, latent_incidence=latent_infections, delay_incidence_to_observation_pmf=inf_to_hosp_admit, + ) + + assert latent_hospital_admissions.shape == latent_infections.shape + + first_latent_admission_dow = first_latent_infection_dow + + predicted_weekly_admissions = daily_to_weekly( + latent_hospital_admissions, + input_data_first_dow=first_latent_admission_dow, + week_start_dow=6, # MMWR epiweek, starts Sunday )[-n_observed_hospital_admissions_datapoints:] hospital_admissions_obs_rv = NegativeBinomialObservation( "observed_hospital_admissions", concentration_rv=self.hosp_admit_neg_bin_concentration_rv, ) - predicted_admissions = latent_hospital_admissions sampled_admissions = hospital_admissions_obs_rv( - mu=predicted_admissions, + mu=predicted_weekly_admissions, obs=data_observed_disease_hospital_admissions, ) return sampled_admissions From c9e8cacae591bb2d1bd87be7656097bcaa0f0232 Mon Sep 17 00:00:00 2001 From: "Dylan H. Morris" Date: Tue, 7 Jan 2025 13:46:43 -0500 Subject: [PATCH 10/32] Make models day-of-week aware (needed for epiweekly obs) --- pipelines/build_pyrenew_model.py | 6 ++++++ pipelines/fit_pyrenew_model.py | 2 ++ pipelines/generate_predictive.py | 2 ++ pyrenew_hew/pyrenew_hew_model.py | 32 +++++++++++++++++++++----------- 4 files changed, 31 insertions(+), 11 deletions(-) diff --git a/pipelines/build_pyrenew_model.py b/pipelines/build_pyrenew_model.py index f319d2e9..76f87ebf 100644 --- a/pipelines/build_pyrenew_model.py +++ b/pipelines/build_pyrenew_model.py @@ -1,3 +1,4 @@ +import datetime import json import runpy @@ -62,6 +63,10 @@ def build_model_from_dir(model_dir): - 1 ) + first_observation_date = datetime.datetime.strptime( + model_data["nssp_training_dates"][0], "%Y-%m-%d" + ) + priors = runpy.run_path(str(prior_path)) right_truncation_offset = model_data["right_truncation_offset"] @@ -110,4 +115,5 @@ def build_model_from_dir(model_dir): data_observed_disease_ed_visits, data_observed_disease_hospital_admissions, right_truncation_offset, + first_observation_date, ) diff --git a/pipelines/fit_pyrenew_model.py b/pipelines/fit_pyrenew_model.py index 43ce5442..ecb1fbf4 100644 --- a/pipelines/fit_pyrenew_model.py +++ b/pipelines/fit_pyrenew_model.py @@ -31,6 +31,7 @@ def fit_and_save_model( data_observed_disease_ed_visits, data_observed_disease_hospital_admissions, right_truncation_offset, + first_observation_date, ) = build_model_from_dir(model_run_dir) my_model.run( num_warmup=n_warmup, @@ -38,6 +39,7 @@ def fit_and_save_model( rng_key=rng_key, data_observed_disease_ed_visits=(data_observed_disease_ed_visits), right_truncation_offset=right_truncation_offset, + first_observation_date=first_observation_date, mcmc_args=dict(num_chains=n_chains, progress_bar=True), nuts_args=dict(find_heuristic_step_size=True), ) diff --git a/pipelines/generate_predictive.py b/pipelines/generate_predictive.py index e483af8c..615c2891 100644 --- a/pipelines/generate_predictive.py +++ b/pipelines/generate_predictive.py @@ -20,6 +20,7 @@ def generate_and_save_predictions( data_observed_disease_ed_visits, data_observed_disease_hospital_admissions, right_truncation_offset, + first_observation_date, ) = build_model_from_dir(model_run_dir) my_model._init_model(1, 1) @@ -42,6 +43,7 @@ def generate_and_save_predictions( data_observed_disease_hospital_admissions ) + n_forecast_points // 7, + first_observation_date=first_observation_date, ) idata = az.from_numpyro( diff --git a/pyrenew_hew/pyrenew_hew_model.py b/pyrenew_hew/pyrenew_hew_model.py index ea30355b..3fc14ec0 100644 --- a/pyrenew_hew/pyrenew_hew_model.py +++ b/pyrenew_hew/pyrenew_hew_model.py @@ -1,4 +1,5 @@ # numpydoc ignore=GL08 +import datetime import json import jax.numpy as jnp @@ -21,7 +22,6 @@ from pyrenew.observation import NegativeBinomialObservation from pyrenew.process import ARProcess, DifferencedProcess from pyrenew.randomvariable import DistributionalVariable, TransformedVariable -from pyrenew.utils import daily_t from pyrenew_hew.utils import convert_to_logmean_log_sd @@ -32,8 +32,8 @@ def __init__( i0_first_obs_n_rv: RandomVariable, initialization_rate_rv: RandomVariable, log_r_mu_intercept_rv: RandomVariable, - autoreg_rt_rv: RandomVariable, # ar coefficient of AR(1) process on R'(t) - eta_sd_rv: RandomVariable, # sd of random walk for ar process on R'(t) + autoreg_rt_rv: RandomVariable, # ar coeff for AR(1) on R'(t) + eta_sd_rv: RandomVariable, # sd of random walk for AR(1) on R'(t) generation_interval_pmf_rv: RandomVariable, infection_feedback_strength_rv: RandomVariable, infection_feedback_pmf_rv: RandomVariable, @@ -149,7 +149,7 @@ def sample( data_observed_disease_ed_visits: ArrayLike, n_observed_disease_ed_visits_datapoints: int, n_weeks_post_init: int, - right_truncation_offset: ArrayLike = None, + right_truncation_offset: int = None, ) -> ArrayLike: """ Observe and/or predict ED visit values @@ -272,9 +272,13 @@ def sample( delay_incidence_to_observation_pmf=inf_to_hosp_admit, ) - assert latent_hospital_admissions.shape == latent_infections.shape + longest_possible_delay = inf_to_hosp_admit.shape[0] - first_latent_admission_dow = first_latent_infection_dow + # we should add functionality to automate this, + # along with tests + first_latent_admission_dow = ( + first_latent_infection_dow + longest_possible_delay + ) % 7 predicted_weekly_admissions = daily_to_weekly( latent_hospital_admissions, @@ -328,11 +332,12 @@ def validate(self): # numpydoc ignore=GL08 def sample( self, - n_observed_disease_ed_visits_datapoints=None, - n_observed_hospital_admissions_datapoints=None, - data_observed_disease_ed_visits=None, - data_observed_disease_hospital_admissions=None, - right_truncation_offset=None, + n_observed_disease_ed_visits_datapoints: int = None, + n_observed_hospital_admissions_datapoints: int = None, + data_observed_disease_ed_visits: ArrayLike = None, + data_observed_disease_hospital_admissions: ArrayLike = None, + right_truncation_offset: int = None, + first_observation_date: datetime.datetime.date = None, ) -> ArrayLike: # numpydoc ignore=GL08 if ( n_observed_disease_ed_visits_datapoints is None @@ -373,6 +378,10 @@ def sample( n_days_post_init=n_days_post_init, n_weeks_post_init=n_weeks_post_init, ) + n_init_days = self.latent_infection_process_rv.infection_initialization_process.infection_init_method.n_timepoints + first_latent_infection_dow = ( + first_observation_date - datetime.timedelta(days=n_init_days) + ).weekday() sample_ed_visits = True sample_admissions = n_observed_hospital_admissions_datapoints > 0 @@ -401,6 +410,7 @@ def sample( if sample_admissions: sampled_admissions = self.hosp_admit_obs_process_rv( latent_infections=latent_infections, + first_latent_infection_dow=first_latent_infection_dow, population_size=self.population_size, n_observed_hospital_admissions_datapoints=( n_observed_hospital_admissions_datapoints From 8f175f3930b8bc8ec3a44428936e1e6daa173781 Mon Sep 17 00:00:00 2001 From: "Dylan H. Morris" Date: Wed, 8 Jan 2025 08:15:13 -0500 Subject: [PATCH 11/32] Add helper data class --- pipelines/build_pyrenew_model.py | 33 ++- pipelines/fit_pyrenew_model.py | 15 +- pipelines/generate_predictive.py | 25 +-- pipelines/priors/prod_priors.py | 16 ++ pyrenew_hew/pyrenew_hew_data.py | 173 +++++++++++++++ pyrenew_hew/pyrenew_hew_model.py | 351 ++++++------------------------- 6 files changed, 287 insertions(+), 326 deletions(-) create mode 100644 pyrenew_hew/pyrenew_hew_data.py diff --git a/pipelines/build_pyrenew_model.py b/pipelines/build_pyrenew_model.py index 76f87ebf..7e05f32c 100644 --- a/pipelines/build_pyrenew_model.py +++ b/pipelines/build_pyrenew_model.py @@ -1,10 +1,12 @@ import datetime import json import runpy +from pathlib import Path import jax.numpy as jnp from pyrenew.deterministic import DeterministicVariable +from pyrenew_hew.pyrenew_hew_data import PyrenewHEWData from pyrenew_hew.pyrenew_hew_model import ( EDVisitObservationProcess, HospAdmitObservationProcess, @@ -14,9 +16,11 @@ ) -def build_model_from_dir(model_dir): - data_path = model_dir / "data" / "data_for_model_fit.json" - prior_path = model_dir / "priors.py" +def build_model_from_dir( + model_dir: Path, +) -> tuple[PyrenewHEWModel, PyrenewHEWData]: + data_path = Path(model_dir) / "data" / "data_for_model_fit.json" + prior_path = Path(model_dir) / "priors.py" with open( data_path, @@ -63,9 +67,12 @@ def build_model_from_dir(model_dir): - 1 ) - first_observation_date = datetime.datetime.strptime( + first_ed_visits_date = datetime.datetime.strptime( model_data["nssp_training_dates"][0], "%Y-%m-%d" ) + first_hospital_admissions_date = datetime.datetime.strptime( + model_data["nhsn_training_dates"][0], "%Y-%m-%d" + ) priors = runpy.run_path(str(prior_path)) @@ -98,6 +105,8 @@ def build_model_from_dir(model_dir): hosp_admit_neg_bin_concentration_rv=( priors["hosp_admit_neg_bin_concentration_rv"] ), + ihr_rel_iedr_rv=priors["ihr_rel_iedr_rv"], + ihr_rv=priors["ihr_rv"], ) my_wastewater_obs_model = WastewaterObservationProcess() @@ -110,10 +119,14 @@ def build_model_from_dir(model_dir): wastewater_obs_process_rv=my_wastewater_obs_model, ) - return ( - my_model, - data_observed_disease_ed_visits, - data_observed_disease_hospital_admissions, - right_truncation_offset, - first_observation_date, + my_data = PyrenewHEWData( + data_observed_disease_ed_visits=data_observed_disease_ed_visits, + data_observed_disease_hospital_admissions=( + data_observed_disease_hospital_admissions + ), + right_truncation_offset=right_truncation_offset, + first_ed_visits_date=first_ed_visits_date, + first_hospital_admissions_date=first_hospital_admissions_date, ) + + return (my_model, my_data) diff --git a/pipelines/fit_pyrenew_model.py b/pipelines/fit_pyrenew_model.py index ecb1fbf4..432b2fd9 100644 --- a/pipelines/fit_pyrenew_model.py +++ b/pipelines/fit_pyrenew_model.py @@ -26,20 +26,15 @@ def fit_and_save_model( "rng_key must be an integer with which " "to seed :func:`jax.random.key`" ) - ( - my_model, - data_observed_disease_ed_visits, - data_observed_disease_hospital_admissions, - right_truncation_offset, - first_observation_date, - ) = build_model_from_dir(model_run_dir) + (my_model, my_data) = build_model_from_dir(model_run_dir) my_model.run( + data=my_data, + sample_ed_visits=True, + sample_hospital_admissions=False, + sample_wastewater=False, num_warmup=n_warmup, num_samples=n_samples, rng_key=rng_key, - data_observed_disease_ed_visits=(data_observed_disease_ed_visits), - right_truncation_offset=right_truncation_offset, - first_observation_date=first_observation_date, mcmc_args=dict(num_chains=n_chains, progress_bar=True), nuts_args=dict(find_heuristic_step_size=True), ) diff --git a/pipelines/generate_predictive.py b/pipelines/generate_predictive.py index 615c2891..e205c49b 100644 --- a/pipelines/generate_predictive.py +++ b/pipelines/generate_predictive.py @@ -15,13 +15,7 @@ def generate_and_save_predictions( model_dir = Path(model_run_dir, model_name) if not model_dir.exists(): raise FileNotFoundError(f"The directory {model_dir} does not exist.") - ( - my_model, - data_observed_disease_ed_visits, - data_observed_disease_hospital_admissions, - right_truncation_offset, - first_observation_date, - ) = build_model_from_dir(model_run_dir) + (my_model, my_data) = build_model_from_dir(model_run_dir) my_model._init_model(1, 1) fresh_sampler = my_model.mcmc.sampler @@ -33,22 +27,17 @@ def generate_and_save_predictions( my_model.mcmc = pickle.load(file) my_model.mcmc.sampler = fresh_sampler + forecast_data = my_data.to_forecast_data(n_forecast_points) posterior_predictive = my_model.posterior_predictive( - n_observed_disease_ed_visits_datapoints=len( - data_observed_disease_ed_visits - ) - + n_forecast_points, - n_observed_hospital_admissions_datapoints=len( - data_observed_disease_hospital_admissions - ) - + n_forecast_points // 7, - first_observation_date=first_observation_date, + data=forecast_data, + sample_ed_visits=True, + sample_hospital_admissions=True, + sample_wastewater=True, ) idata = az.from_numpyro( - my_model.mcmc, - posterior_predictive=posterior_predictive, + my_model.mcmc, posterior_predictive=posterior_predictive ) idata.to_dataframe().to_csv(model_dir / "inference_data.csv", index=False) diff --git a/pipelines/priors/prod_priors.py b/pipelines/priors/prod_priors.py index 8625c20a..a50c3c6e 100644 --- a/pipelines/priors/prod_priors.py +++ b/pipelines/priors/prod_priors.py @@ -45,6 +45,18 @@ ), ) # logit scale +ihr_rv = TransformedVariable( + "ihr", + DistributionalVariable( + "logit_ihr", + dist.Normal( + transformation.SigmoidTransform().inv(0.005), + 0.3, + ), + ), + transforms=transformation.SigmoidTransform(), +) + p_ed_visit_w_sd_rv = DistributionalVariable( "p_ed_visit_w_sd_sd", dist.TruncatedNormal(0, 0.01, low=0) @@ -64,6 +76,10 @@ transformation.AffineTransform(loc=0, scale=7), ) +ihr_rel_iedr_rv = DistributionalVariable( + "ihr_rel_iedr", dist.LogNormal(0, jnp.log(jnp.sqrt(2))) +) + # Based on looking at some historical posteriors. ed_neg_bin_concentration_rv = DistributionalVariable( "ed_visit_neg_bin_concentration", dist.LogNormal(4, 1) diff --git a/pyrenew_hew/pyrenew_hew_data.py b/pyrenew_hew/pyrenew_hew_data.py new file mode 100644 index 00000000..b074cb15 --- /dev/null +++ b/pyrenew_hew/pyrenew_hew_data.py @@ -0,0 +1,173 @@ +import datetime +from typing import Self + +from jax.typing import ArrayLike + + +class PyrenewHEWData: + """ + Class for holding input data + to a PyrenewHEW model. + """ + + def __init__( + self, + n_ed_visits_datapoints: int = None, + n_hospital_admissions_datapoints: int = None, + n_wastewater_datapoints: int = None, + data_observed_disease_ed_visits: ArrayLike = None, + data_observed_disease_hospital_admissions: ArrayLike = None, + data_observed_disease_wastewater: ArrayLike = None, + right_truncation_offset: int = None, + first_ed_visits_date: datetime.datetime.date = None, + first_hospital_admissions_date: datetime.datetime.date = None, + first_wastewater_date: datetime.datetime.date = None, + ) -> None: + self.n_ed_visits_datapoints_ = n_ed_visits_datapoints + self.n_hospital_admissions_datapoints_ = ( + n_hospital_admissions_datapoints + ) + self.n_wastewater_datapoints_ = n_wastewater_datapoints + + self.data_observed_disease_ed_visits = data_observed_disease_ed_visits + self.data_observed_disease_hospital_admissions = ( + data_observed_disease_hospital_admissions + ) + self.data_observed_disease_wastewater = ( + data_observed_disease_wastewater + ) + + self.right_truncation_offset = right_truncation_offset + + self.first_ed_visits_date = first_ed_visits_date + self.first_hospital_admissions_date = first_hospital_admissions_date + self.first_wastewater_date = first_wastewater_date + + @property + def n_ed_visits_datapoints(self): + return self.get_n_datapoints( + n_datapoints=self.n_ed_visits_datapoints_, + data_array=self.data_observed_disease_ed_visits, + ) + + @property + def n_hospital_admissions_datapoints(self): + return self.get_n_datapoints( + n_datapoints=self.n_hospital_admissions_datapoints_, + data_array=self.data_observed_disease_hospital_admissions, + ) + + @property + def n_wastewater_datapoints(self): + return self.get_n_datapoints( + n_datapoints=self.n_wastewater_datapoints_, + data_array=self.data_observed_disease_wastewater, + ) + + @property + def last_ed_visits_date(self): + return self.get_end_date( + self.first_ed_visits_date, + self.n_ed_visits_datapoints, + timestep_days=1, + ) + + @property + def last_hospital_admissions_date(self): + return self.get_end_date( + self.first_hospital_admissions_date, + self.n_hospital_admissions_datapoints, + timestep_days=7, + ) + + @property + def last_wastewater_date(self): + return self.get_end_date( + self.first_wastewater_date, + self.n_wastewater_datapoints, + timestep_days=1, + ) + + @property + def first_data_dates(self): + return [ + self.first_ed_visits_date, + self.first_hospital_admissions_date, + self.first_wastewater_date, + ] + + @property + def last_data_dates(self): + return [ + self.last_ed_visits_date, + self.last_hospital_admissions_date, + self.last_wastewater_date, + ] + + @property + def first_data_date_overall(self): + return min([x for x in self.first_data_dates if x is not None]) + + @property + def last_data_date_overall(self): + return max([x for x in self.last_data_dates if x is not None]) + + @property + def n_days_post_init(self): + return ( + self.last_data_date_overall - self.first_data_date_overall + ).days + + def get_end_date( + self, + first_date: datetime.datetime.date, + n_datapoints: int, + timestep_days: int = 1, + ) -> datetime.datetime.date: + """ + Get end date from a first date and a number of datapoints, + with handling of None values and non-daily timeseries + """ + if first_date is None: + if n_datapoints != 0: + raise ValueError( + "Must provide an initial date if " + "n_datapoints is non-zero. " + f"Got n_datapoints = {n_datapoints} " + "but first_date was `None`" + ) + result = None + else: + result = first_date + datetime.timedelta( + days=n_datapoints * timestep_days + ) + return result + + def get_n_datapoints( + self, n_datapoints: int = None, data_array: ArrayLike = None + ) -> int: + if n_datapoints is None and data_array is None: + return 0 + elif data_array is not None and n_datapoints is not None: + raise ValueError( + "Must provide at most one out of a " + "number of datapoints to simulate and " + "an array of observed data." + ) + elif data_array is not None: + return data_array.shape[0] + else: + return n_datapoints + + def to_forecast_data(self, n_forecast_points: int) -> Self: + n_days = self.n_days_post_init + n_forecast_points + n_weeks = n_days // 7 + return PyrenewHEWData( + n_ed_visits_datapoints=n_days, + n_hospital_admissions_datapoints=n_weeks, + n_wastewater_datapoints=n_days, + first_ed_visits_date=self.first_data_date_overall, + first_hospital_admissions_date=(self.first_data_date_overall), + first_wastewater_date=self.first_data_date_overall, + right_truncation_offset=0, + ) diff --git a/pyrenew_hew/pyrenew_hew_model.py b/pyrenew_hew/pyrenew_hew_model.py index 3fc14ec0..99d1703f 100644 --- a/pyrenew_hew/pyrenew_hew_model.py +++ b/pyrenew_hew/pyrenew_hew_model.py @@ -1,6 +1,5 @@ # numpydoc ignore=GL08 import datetime -import json import jax.numpy as jnp import numpyro @@ -12,7 +11,6 @@ compute_delay_ascertained_incidence, daily_to_weekly, ) -from pyrenew.deterministic import DeterministicVariable from pyrenew.latent import ( InfectionInitializationProcess, InfectionsWithFeedback, @@ -23,7 +21,7 @@ from pyrenew.process import ARProcess, DifferencedProcess from pyrenew.randomvariable import DistributionalVariable, TransformedVariable -from pyrenew_hew.utils import convert_to_logmean_log_sd +from pyrenew_hew.pyrenew_hew_data import PyrenewHEWData class LatentInfectionProcess(RandomVariable): @@ -66,9 +64,15 @@ def __init__( def validate(self): pass - def sample(self, n_days_post_init: int, n_weeks_post_init: int): + def sample(self, n_days_post_init: int): """ Sample latent infections. + + Parameters + ---------- + n_days_post_init + Number of days of infections to sample, not including + the initialization period. """ i0 = self.infection_initialization_process() @@ -80,8 +84,10 @@ def sample(self, n_days_post_init: int, n_weeks_post_init: int): dist.Normal(0, eta_sd / jnp.sqrt(1 - jnp.pow(autoreg_rt, 2))), )() + n_weeks_rt = n_days_post_init // 7 + 1 + log_rtu_weekly = self.ar_diff( - n=n_weeks_post_init, + n=n_weeks_rt, init_vals=jnp.array(log_r_mu_intercept), autoreg=jnp.array(autoreg_rt), noise_sd=jnp.array(eta_sd), @@ -146,17 +152,17 @@ def sample( self, latent_infections: ArrayLike, population_size: int, - data_observed_disease_ed_visits: ArrayLike, - n_observed_disease_ed_visits_datapoints: int, - n_weeks_post_init: int, + data_observed: ArrayLike, + n_datapoints: int, right_truncation_offset: int = None, - ) -> ArrayLike: + ) -> tuple[ArrayLike]: """ Observe and/or predict ED visit values """ p_ed_mean = self.p_ed_mean_rv() p_ed_w_sd = self.p_ed_w_sd_rv() autoreg_p_ed = self.autoreg_p_ed_rv() + n_weeks_p_ed_ar = n_datapoints // 7 + 1 p_ed_ar_init_rv = DistributionalVariable( "p_ed_ar_init", @@ -169,7 +175,7 @@ def sample( p_ed_ar = self.p_ed_ar_proc( noise_name="p_ed", - n=n_weeks_post_init, + n=n_weeks_p_ed_ar, autoreg=autoreg_p_ed, init_vals=p_ed_ar_init, noise_sd=p_ed_w_sd, @@ -178,7 +184,7 @@ def sample( iedr = jnp.repeat( transformation.SigmoidTransform()(p_ed_ar + p_ed_mean), repeats=7, - )[:n_observed_disease_ed_visits_datapoints] + )[:n_datapoints] # this is only applied after the ed visits are generated, not to all # the latent infections. This is why we cannot apply the iedr in # compute_delay_ascertained_incidence @@ -187,9 +193,7 @@ def sample( numpyro.deterministic("iedr", iedr) ed_wday_effect_raw = self.ed_wday_effect_rv() - ed_wday_effect = tile_until_n( - ed_wday_effect_raw, n_observed_disease_ed_visits_datapoints - ) + ed_wday_effect = tile_until_n(ed_wday_effect_raw, n_datapoints) inf_to_ed = self.inf_to_ed_rv() @@ -197,7 +201,7 @@ def sample( p_observed_given_incident=1, latent_incidence=latent_infections, delay_incidence_to_observation_pmf=inf_to_ed, - )[-n_observed_disease_ed_visits_datapoints:] + )[-n_datapoints:] latent_ed_visits_final = ( potential_latent_ed_visits @@ -211,8 +215,7 @@ def sample( self.ed_right_truncation_cdf_rv()[right_truncation_offset:] ) n_points_to_prepend = ( - n_observed_disease_ed_visits_datapoints - - prop_already_reported_tail.shape[0] + n_datapoints - prop_already_reported_tail.shape[0] ) prop_already_reported = jnp.pad( prop_already_reported_tail, @@ -231,12 +234,12 @@ def sample( concentration_rv=self.ed_neg_bin_concentration_rv, ) - observed_ed_visits = ed_visit_obs_rv( + sampled_ed_visits = ed_visit_obs_rv( mu=latent_ed_visits_now, - obs=data_observed_disease_ed_visits, + obs=data_observed, ) - return observed_ed_visits + return sampled_ed_visits, iedr class HospAdmitObservationProcess(RandomVariable): @@ -244,11 +247,15 @@ def __init__( self, inf_to_hosp_admit_rv: RandomVariable, hosp_admit_neg_bin_concentration_rv: RandomVariable, + ihr_rv: RandomVariable = None, + ihr_rel_iedr_rv: RandomVariable = None, ): self.inf_to_hosp_admit_rv = inf_to_hosp_admit_rv self.hosp_admit_neg_bin_concentration_rv = ( hosp_admit_neg_bin_concentration_rv ) + self.ihr_rv = ihr_rv + self.ihr_rel_iedr_rv = ihr_rel_iedr_rv def validate(self): pass @@ -258,17 +265,26 @@ def sample( latent_infections: ArrayLike, first_latent_infection_dow: int, population_size: int, - n_observed_hospital_admissions_datapoints: int, - data_observed_disease_hospital_admissions: ArrayLike | None = None, + n_datapoints: int, + data_observed: ArrayLike = None, + iedr: ArrayLike = None, ) -> ArrayLike: """ Observe and/or predict incident hospital admissions. """ inf_to_hosp_admit = self.inf_to_hosp_admit_rv() + if iedr is not None: + ihr_rel_iedr = self.ihr_rel_iedr_rv() + ihr = iedr[0] * ihr_rel_iedr + numpyro.deterministic("ihr", ihr) + else: + ihr = self.ihr_rv() + + latent_admissions = population_size * ihr * latent_infections latent_hospital_admissions = compute_delay_ascertained_incidence( p_observed_given_incident=1, - latent_incidence=latent_infections, + latent_incidence=latent_admissions, delay_incidence_to_observation_pmf=inf_to_hosp_admit, ) @@ -284,22 +300,22 @@ def sample( latent_hospital_admissions, input_data_first_dow=first_latent_admission_dow, week_start_dow=6, # MMWR epiweek, starts Sunday - )[-n_observed_hospital_admissions_datapoints:] + ) hospital_admissions_obs_rv = NegativeBinomialObservation( "observed_hospital_admissions", concentration_rv=self.hosp_admit_neg_bin_concentration_rv, ) sampled_admissions = hospital_admissions_obs_rv( - mu=predicted_weekly_admissions, - obs=data_observed_disease_hospital_admissions, + mu=predicted_weekly_admissions[-n_datapoints:], + obs=data_observed, ) return sampled_admissions class WastewaterObservationProcess(RandomVariable): """ - Placeholder for wasteater obs process + Placeholder for wastewater obs process """ def __init__(self) -> None: @@ -325,298 +341,57 @@ def __init__( self.latent_infection_process_rv = latent_infection_process_rv self.ed_visit_obs_process_rv = ed_visit_obs_process_rv self.hosp_admit_obs_process_rv = hosp_admit_obs_process_rv - return None + self.wastewater_obs_process_rv = wastewater_obs_process_rv - def validate(self): # numpydoc ignore=GL08 - return None + def validate(self) -> None: # numpydoc ignore=GL08 + pass def sample( self, - n_observed_disease_ed_visits_datapoints: int = None, - n_observed_hospital_admissions_datapoints: int = None, - data_observed_disease_ed_visits: ArrayLike = None, - data_observed_disease_hospital_admissions: ArrayLike = None, - right_truncation_offset: int = None, - first_observation_date: datetime.datetime.date = None, + data: PyrenewHEWData = None, + sample_ed_visits: bool = False, + sample_hospital_admissions: bool = False, + sample_wastewater: bool = False, ) -> ArrayLike: # numpydoc ignore=GL08 - if ( - n_observed_disease_ed_visits_datapoints is None - and data_observed_disease_ed_visits is None - ): - raise ValueError( - "Either n_observed_disease_ed_visits_datapoints or " - "data_observed_disease_ed_visits " - "must be passed." - ) - elif ( - n_observed_disease_ed_visits_datapoints is not None - and data_observed_disease_ed_visits is not None - ): - raise ValueError( - "Cannot pass both n_observed_disease_ed_visits_datapoints " - "and data_observed_disease_ed_visits." - ) - elif n_observed_disease_ed_visits_datapoints is None: - n_observed_disease_ed_visits_datapoints = len( - data_observed_disease_ed_visits - ) - - if ( - n_observed_hospital_admissions_datapoints is None - and data_observed_disease_hospital_admissions is not None - ): - n_observed_hospital_admissions_datapoints = len( - data_observed_disease_hospital_admissions - ) - elif n_observed_hospital_admissions_datapoints is None: - n_observed_hospital_admissions_datapoints = 0 - - n_weeks_post_init = n_observed_disease_ed_visits_datapoints // 7 + 1 - n_days_post_init = n_observed_disease_ed_visits_datapoints - latent_infections = self.latent_infection_process_rv( - n_days_post_init=n_days_post_init, - n_weeks_post_init=n_weeks_post_init, + n_days_post_init=data.n_days_post_init, ) n_init_days = self.latent_infection_process_rv.infection_initialization_process.infection_init_method.n_timepoints first_latent_infection_dow = ( - first_observation_date - datetime.timedelta(days=n_init_days) + data.first_data_date_overall - datetime.timedelta(days=n_init_days) ).weekday() - sample_ed_visits = True - sample_admissions = n_observed_hospital_admissions_datapoints > 0 - sample_wastewater = False - sampled_ed_visits, sampled_admissions, sampled_wastewater = ( None, None, None, ) + iedr = None + if sample_ed_visits: - sampled_ed_visits = self.ed_visit_obs_process_rv( + sampled_ed_visits, iedr = self.ed_visit_obs_process_rv( latent_infections=latent_infections, population_size=self.population_size, - data_observed_disease_ed_visits=( - data_observed_disease_ed_visits - ), - n_observed_disease_ed_visits_datapoints=( - n_observed_disease_ed_visits_datapoints - ), - n_weeks_post_init=n_weeks_post_init, - right_truncation_offset=right_truncation_offset, + data_observed=data.data_observed_disease_ed_visits, + n_datapoints=data.n_ed_visits_datapoints, + right_truncation_offset=data.right_truncation_offset, ) - if sample_admissions: + if sample_hospital_admissions: sampled_admissions = self.hosp_admit_obs_process_rv( latent_infections=latent_infections, first_latent_infection_dow=first_latent_infection_dow, population_size=self.population_size, - n_observed_hospital_admissions_datapoints=( - n_observed_hospital_admissions_datapoints - ), - data_observed_disease_hospital_admissions=( - data_observed_disease_hospital_admissions - ), + n_datapoints=data.n_hospital_admissions_datapoints, + data_observed=(data.data_observed_disease_hospital_admissions), + iedr=iedr, ) if sample_wastewater: sampled_wastewater = self.wastewater_obs_process_rv() return { "ed_visits": sampled_ed_visits, - "admissions": sampled_admissions, + "hospital_admissions": sampled_admissions, "wasewater": sampled_wastewater, } - - -def create_pyrenew_hew_model_from_stan_data(stan_data_file): - with open( - stan_data_file, - "r", - ) as file: - stan_data = json.load(file) - - i_first_obs_over_n_prior_a = stan_data["i_first_obs_over_n_prior_a"] - i_first_obs_over_n_prior_b = stan_data["i_first_obs_over_n_prior_b"] - i0_first_obs_n_rv = DistributionalVariable( - "i0_first_obs_n_rv", - dist.Beta(i_first_obs_over_n_prior_a, i_first_obs_over_n_prior_b), - ) - - mean_initial_exp_growth_rate_prior_mean = stan_data[ - "mean_initial_exp_growth_rate_prior_mean" - ] - mean_initial_exp_growth_rate_prior_sd = stan_data[ - "mean_initial_exp_growth_rate_prior_sd" - ] - initialization_rate_rv = DistributionalVariable( - "rate", - dist.TruncatedNormal( - loc=mean_initial_exp_growth_rate_prior_mean, - scale=mean_initial_exp_growth_rate_prior_sd, - low=-1, - high=1, - ), - ) - # could reasonably switch to non-Truncated - - r_prior_mean = stan_data["r_prior_mean"] - r_prior_sd = stan_data["r_prior_sd"] - r_logmean, r_logsd = convert_to_logmean_log_sd(r_prior_mean, r_prior_sd) - log_r_mu_intercept_rv = DistributionalVariable( - "log_r_mu_intercept_rv", dist.Normal(r_logmean, r_logsd) - ) - - eta_sd_sd = stan_data["eta_sd_sd"] - eta_sd_rv = DistributionalVariable( - "eta_sd", dist.TruncatedNormal(0, eta_sd_sd, low=0) - ) - - autoreg_rt_a = stan_data["autoreg_rt_a"] - autoreg_rt_b = stan_data["autoreg_rt_b"] - autoreg_rt_rv = DistributionalVariable( - "autoreg_rt", dist.Beta(autoreg_rt_a, autoreg_rt_b) - ) - - generation_interval_pmf_rv = DeterministicVariable( - "generation_interval_pmf", jnp.array(stan_data["generation_interval"]) - ) - - infection_feedback_pmf_rv = DeterministicVariable( - "infection_feedback_pmf", - jnp.array(stan_data["infection_feedback_pmf"]), - ) - - inf_feedback_prior_logmean = stan_data["inf_feedback_prior_logmean"] - inf_feedback_prior_logsd = stan_data["inf_feedback_prior_logsd"] - inf_feedback_strength_rv = TransformedVariable( - "inf_feedback", - DistributionalVariable( - "inf_feedback_raw", - dist.LogNormal( - inf_feedback_prior_logmean, inf_feedback_prior_logsd - ), - ), - transforms=transformation.AffineTransform(loc=0, scale=-1), - ) - # Could be reparameterized? - - p_hosp_prior_mean = stan_data["p_hosp_prior_mean"] - p_hosp_sd_logit = stan_data["p_hosp_sd_logit"] - - p_ed_mean_rv = DistributionalVariable( - "p_ed_mean", - dist.Normal( - transformation.SigmoidTransform().inv(p_hosp_prior_mean), - p_hosp_sd_logit, - ), - ) # logit scale - - p_ed_w_sd_sd = stan_data["p_hosp_w_sd_sd"] - p_ed_w_sd_rv = DistributionalVariable( - "p_ed_w_sd_sd", dist.TruncatedNormal(0, p_ed_w_sd_sd, low=0) - ) - - autoreg_p_ed_a = stan_data["autoreg_p_hosp_a"] - autoreg_p_ed_b = stan_data["autoreg_p_hosp_b"] - autoreg_p_ed_rv = DistributionalVariable( - "autoreg_p_ed", dist.Beta(autoreg_p_ed_a, autoreg_p_ed_b) - ) - - ed_wday_effect_rv = TransformedVariable( - "ed_wday_effect", - DistributionalVariable( - "ed_wday_effect_raw", - dist.Dirichlet( - jnp.array(stan_data["hosp_wday_effect_prior_alpha"]) - ), - ), - transformation.AffineTransform(loc=0, scale=7), - ) - - inf_to_ed_rv = DeterministicVariable( - "inf_to_ed", jnp.array(stan_data["inf_to_hosp"]) - ) - - inv_sqrt_phi_prior_mean = stan_data["inv_sqrt_phi_prior_mean"] - inv_sqrt_phi_prior_sd = stan_data["inv_sqrt_phi_prior_sd"] - - ed_neg_bin_concentration_rv = TransformedVariable( - "ed_visit_neg_bin_concentration", - DistributionalVariable( - "inv_sqrt_ed_visit_neg_bin_conc", - dist.TruncatedNormal( - loc=inv_sqrt_phi_prior_mean, - scale=inv_sqrt_phi_prior_sd, - low=1 / jnp.sqrt(5000), - ), - ), - transforms=transformation.PowerTransform(-2), - ) - - inf_to_hosp_admit_rv = DeterministicVariable( - "inf_to_hosp_admit", jnp.array(stan_data["inf_to_hosp"]) - ) - - hosp_admit_neg_bin_concentration_rv = TransformedVariable( - "hosp_admit_neg_bin_concentration", - DistributionalVariable( - "inv_sqrt_hosp_admit_neg_bin_conc", - dist.TruncatedNormal( - loc=inv_sqrt_phi_prior_mean, - scale=inv_sqrt_phi_prior_sd, - low=1 / jnp.sqrt(5000), - ), - ), - transforms=transformation.PowerTransform(-2), - ) - - uot = stan_data["uot"] - uot = len(jnp.array(stan_data["inf_to_hosp"])) - population_size = stan_data["state_pop"] - - data_observed_disease_ed_visits = jnp.array(stan_data["hosp"]) - ed_right_truncation_pmf_rv = DeterministicVariable( - "ed_visit_right_truncation_pmf", jnp.array(1) - ) - - my_latent_infection_model = LatentInfectionProcess( - i0_first_obs_n_rv=i0_first_obs_n_rv, - initialization_rate_rv=initialization_rate_rv, - log_r_mu_intercept_rv=log_r_mu_intercept_rv, - autoreg_rt_rv=autoreg_rt_rv, - eta_sd_rv=eta_sd_rv, # sd of random walk for ar process, - generation_interval_pmf_rv=generation_interval_pmf_rv, - infection_feedback_pmf_rv=infection_feedback_pmf_rv, - infection_feedback_strength_rv=inf_feedback_strength_rv, - n_initialization_points=uot, - ) - - my_ed_visit_obs_model = EDVisitObservationProcess( - p_ed_mean_rv=p_ed_mean_rv, - p_ed_w_sd_rv=p_ed_w_sd_rv, - autoreg_p_ed_rv=autoreg_p_ed_rv, - ed_wday_effect_rv=ed_wday_effect_rv, - inf_to_ed_rv=inf_to_ed_rv, - ed_neg_bin_concentration_rv=ed_neg_bin_concentration_rv, - ed_right_truncation_pmf_rv=ed_right_truncation_pmf_rv, - ) - - my_hosp_admit_obs_model = HospAdmitObservationProcess( - inf_to_hosp_admit_rv=inf_to_hosp_admit_rv, - hosp_admit_neg_bin_concentration_rv=( - hosp_admit_neg_bin_concentration_rv - ), - ) - - my_wastewater_obs_model = WastewaterObservationProcess() - - my_model = PyrenewHEWModel( - population_size=population_size, - latent_infection_process_rv=my_latent_infection_model, - ed_visit_obs_process_rv=my_ed_visit_obs_model, - hosp_admit_obs_process_rv=my_hosp_admit_obs_model, - wastewater_obs_process_rv=my_wastewater_obs_model, - ) - - return my_model, data_observed_disease_ed_visits From 7517a3357a4cf12db6b027e05229c5f8fd2b9fbe Mon Sep 17 00:00:00 2001 From: "Dylan H. Morris" Date: Wed, 8 Jan 2025 08:30:09 -0500 Subject: [PATCH 12/32] Remove old demo --- .../data/fit_hosp_only/stan_data.json | 386 ----------- demos/hosp_only_ww_model/model_comp.qmd | 137 ---- .../hosp_only_ww_model/pyrenew_hew_model.qmd | 122 ---- demos/hosp_only_ww_model/wwinference.Rmd | 637 ------------------ 4 files changed, 1282 deletions(-) delete mode 100644 demos/hosp_only_ww_model/data/fit_hosp_only/stan_data.json delete mode 100644 demos/hosp_only_ww_model/model_comp.qmd delete mode 100644 demos/hosp_only_ww_model/pyrenew_hew_model.qmd delete mode 100644 demos/hosp_only_ww_model/wwinference.Rmd diff --git a/demos/hosp_only_ww_model/data/fit_hosp_only/stan_data.json b/demos/hosp_only_ww_model/data/fit_hosp_only/stan_data.json deleted file mode 100644 index 4f745d60..00000000 --- a/demos/hosp_only_ww_model/data/fit_hosp_only/stan_data.json +++ /dev/null @@ -1,386 +0,0 @@ -{ - "gt_max": 15, - "hosp_delay_max": 55, - "inf_to_hosp": [0.0, 0.00469384736487552, 0.0145200073436112, 0.0278627741704387, 0.0423656492135518, 0.0558071445014868, 0.0665713169684116, 0.0737925805176124, 0.0772854627892072, 0.0773666390616176, 0.0746515449009948, 0.0698761436052596, 0.0637663813017696, 0.0569581929821651, 0.0499600186601535, 0.0431457477049282, 0.0367662806214046, 0.0309702535668238, 0.0258273785539499, 0.0213504646948306, 0.0175141661880584, 0.0142698211023571, 0.0115565159519833, 0.00930888979824423, 0.00746229206759215, 0.00595605679409682, 0.00473519993107751, 0.00375117728281842, 0.00296198928038098, 0.00233187862772459, 0.00183079868293457, 0.00143377454057296, 0.00107076258525208, 0.000773006742366448, 0.000539573690886396, 0.000364177599116743, 0.000237727628685578, 0.000150157714457011, 9.18283319498657e-05, 5.44079947589853e-05, 3.12548818921464e-05, 1.74202619730274e-05, 9.42698047424712e-06, 4.95614149002087e-06, 2.53275674485913e-06, 1.25854819834554e-06, 6.08116579596933e-07, 2.85572858589747e-07, 1.30129404249734e-07, 5.73280599448305e-08, 2.4219376577964e-08, 9.6316861194457e-09, 3.43804936850951e-09, 9.34806280366887e-10, 0.0], - "mwpd": 227000.0, - "ot": 90, - "n_subpops": 1, - "n_ww_sites": 0.0, - "n_ww_lab_sites": 0, - "owt": 0, - "oht": 90, - "n_censored": 0, - "n_uncensored": 0, - "uot": 50, - "ht": 35, - "n_weeks": 18, - "ind_m": [ - [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0] - ], - "tot_weeks": 25, - "p_hosp_m": [ - [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0] - ], - "generation_interval": [0.161701189933765, 0.320525743089203, 0.242198071982593, 0.134825252524032, 0.0689141939998525, 0.0346219683116734, 0.017497710736154, 0.00908172017279556, 0.00483656086299504, 0.00260732346885217, 0.00143298046642562, 0.00082002579123121, 0.0004729600977183, 0.000284420637980485, 0.000179877924728358], - "ts": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15], - "state_pop": 3000000.0, - "subpop_size": [3000000.0], - "norm_pop": 3000000.0, - "ww_sampled_times": [], - "hosp_times": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90], - "ww_sampled_subpops": [], - "lab_site_to_subpop_map": [], - "ww_sampled_lab_sites": [], - "ww_log_lod": [], - "ww_censored": [], - "ww_uncensored": [], - "hosp": [25, 17, 25, 24, 26, 25, 27, 37, 32, 19, 18, 33, 22, 31, 34, 47, 45, 28, 38, 43, 33, 47, 38, 48, 43, 39, 52, 40, 49, 63, 47, 64, 58, 48, 92, 54, 53, 81, 56, 77, 84, 74, 62, 67, 74, 75, 89, 100, 65, 83, 96, 74, 59, 57, 60, 74, 70, 69, 60, 50, 75, 60, 53, 54, 50, 56, 48, 55, 41, 37, 50, 50, 39, 30, 31, 23, 35, 34, 33, 16, 23, 16, 21, 28, 29, 26, 30, 30, 27, 23], - "day_of_week": [5, 6, 7, 1, 2, 3, 4, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7, 1, 2, 3], - "log_conc": [], - "compute_likelihood": 1, - "include_ww": 0, - "include_hosp": 1, - "if_l": 15, - "infection_feedback_pmf": [0.161701189933765, 0.320525743089203, 0.242198071982593, 0.134825252524032, 0.0689141939998525, 0.0346219683116734, 0.017497710736154, 0.00908172017279556, 0.00483656086299504, 0.00260732346885217, 0.00143298046642562, 0.00082002579123121, 0.0004729600977183, 0.000284420637980485, 0.000179877924728358], - "viral_shedding_pars": [5.0, 1.0, 5.1, 0.5, 17.0, 3.0], - "autoreg_rt_a": 2, - "autoreg_rt_b": 40, - "autoreg_rt_subpop_a": 1, - "autoreg_rt_subpop_b": 4, - "autoreg_p_hosp_a": 1, - "autoreg_p_hosp_b": 100, - "inv_sqrt_phi_prior_mean": 0.1, - "inv_sqrt_phi_prior_sd": 0.1414214, - "r_prior_mean": 1, - "r_prior_sd": 1, - "log10_g_prior_mean": 12, - "log10_g_prior_sd": 2, - "i_first_obs_over_n_prior_a": 1.00402380952381, - "i_first_obs_over_n_prior_b": 5.99597619047619, - "hosp_wday_effect_prior_alpha": [5, 5, 5, 5, 5, 5, 5], - "mean_initial_exp_growth_rate_prior_mean": 0, - "mean_initial_exp_growth_rate_prior_sd": 0.01, - "sigma_initial_exp_growth_rate_prior_mode": 0, - "sigma_initial_exp_growth_rate_prior_sd": 0.05, - "mode_sigma_ww_site_prior_mode": 1, - "mode_sigma_ww_site_prior_sd": 1, - "sd_log_sigma_ww_site_prior_mode": 0, - "sd_log_sigma_ww_site_prior_sd": 0.693, - "eta_sd_sd": 0.01, - "eta_sd_mean": 0.0278, - "sigma_i_first_obs_prior_mode": 0, - "sigma_i_first_obs_prior_sd": 0.5, - "p_hosp_prior_mean": 0.01, - "p_hosp_sd_logit": 0.3, - "p_hosp_w_sd_sd": 0.01, - "ww_site_mod_sd_sd": 0.25, - "inf_feedback_prior_logmean": 4.498, - "inf_feedback_prior_logsd": 0.636, - "sigma_rt_prior": 0.1, - "log_phi_g_prior_mean": -2.302585, - "log_phi_g_prior_sd": 5, - "offset_ref_log_r_t_prior_mean": 0, - "offset_ref_log_r_t_prior_sd": 0.2, - "offset_ref_logit_i_first_obs_prior_mean": 0, - "offset_ref_logit_i_first_obs_prior_sd": 0.25, - "offset_ref_initial_exp_growth_rate_prior_mean": 0, - "offset_ref_initial_exp_growth_rate_prior_sd": 0.025 -} diff --git a/demos/hosp_only_ww_model/model_comp.qmd b/demos/hosp_only_ww_model/model_comp.qmd deleted file mode 100644 index 4485ff5f..00000000 --- a/demos/hosp_only_ww_model/model_comp.qmd +++ /dev/null @@ -1,137 +0,0 @@ ---- -title: "PyRenew and wwinference Fit and Forecast Comparison" -format: gfm -editor: visual ---- - -This document shows graphical comparisons for key variables in the PyRenew model fit to example data (notebooks/pyrenew_hew_model.qmd) and Stan model fit to example data (notebooks/wwinference.Rmd). In order to render this document, those notebooks must be rendered first. - -```{r} -#| output: false -library(tidyverse) -library(tidybayes) -library(fs) -library(cmdstanr) -library(posterior) -library(jsonlite) -library(scales) -library(here) -library(forecasttools) -ci_width <- c(0.5, 0.8, 0.95) -fit_dir <- here(path("demos/hosp_only_ww_model/data/fit_hosp_only")) -``` - -## Load Data - -```{r} -hosp_data <- tibble(.value = path(fit_dir, "stan_data", ext = "json") |> - jsonlite::read_json() |> - pluck("hosp") |> - unlist()) |> - mutate(time = row_number()) - -stan_files <- - dir_ls(fit_dir, - glob = "*wwinference*" - ) |> - enframe(name = NULL, value = "file_path") |> - mutate(file_details = path_ext_remove(path_file(file_path))) |> - separate_wider_delim(file_details, - delim = "-", - names = c("model", "date", "chain", "hash") - ) |> - mutate(date = ymd_hm(date)) |> - filter(date == max(date)) |> - pull(file_path) - - -stan_tidy_draws <- read_cmdstan_csv(stan_files)$post_warmup_draws |> - tidy_draws() - -pyrenew_tidy_draws <- - path(fit_dir, "inference_data", ext = "csv") |> - read_csv() |> - forecasttools::inferencedata_to_tidy_draws() -``` - -## Calculate Credible Intervals for Plotting - -```{r} -combined_ci_for_plotting <- - bind_rows( - deframe(pyrenew_tidy_draws)$posterior_predictive |> - gather_draws(observed_ed_visits[time], rt[time], iedr[time]) |> - median_qi(.width = ci_width) |> - mutate(model = "pyrenew"), - stan_tidy_draws |> - gather_draws(pred_hosp[time], rt[time], p_hosp[time]) |> - mutate(.variable = case_when( - .variable == "pred_hosp" ~ "observed_ed_visits", - .variable == "p_hosp" ~ "iedr", - TRUE ~ .variable - )) |> - median_qi(.width = ci_width) |> - mutate(model = "stan") - ) -``` - -## Hospital Admission Comparison - -```{r} -combined_ci_for_plotting |> - filter(.variable == "observed_ed_visits") |> - ggplot(aes(time, .value)) + - facet_wrap(~model) + - geom_lineribbon(aes(ymin = .lower, ymax = .upper), color = "#08519c") + - scale_fill_brewer( - name = "Credible Interval Width", - labels = ~ percent(as.numeric(.)) - ) + - geom_point(data = hosp_data) + - cowplot::theme_cowplot() + - ggtitle("Vignette Data Model Comparison") + - scale_y_continuous("Hospital Admissions") + - scale_x_continuous("Time") + - theme(legend.position = "bottom") -``` - -## Rt Comparions - -```{r} -combined_ci_for_plotting |> - filter(.variable == "rt") |> - ggplot(aes(time, .value)) + - facet_wrap(~model) + - geom_lineribbon(aes(ymin = .lower, ymax = .upper), color = "#08519c") + - scale_fill_brewer( - name = "Credible Interval Width", - labels = ~ percent(as.numeric(.)) - ) + - cowplot::theme_cowplot() + - ggtitle("Vignette Data Model Comparison") + - scale_y_log10("Rt", breaks = scales::log_breaks(n = 6)) + - scale_x_continuous("Time") + - theme(legend.position = "bottom") + - geom_hline(yintercept = 1, linetype = "dashed") -``` - -## IEDR Comparison - -```{r} -combined_ci_for_plotting |> - filter(.variable == "iedr") |> - ggplot(aes(time, .value)) + - facet_wrap(~model) + - geom_lineribbon(aes(ymin = .lower, ymax = .upper), color = "#08519c") + - scale_fill_brewer( - name = "Credible Interval Width", - labels = ~ percent(as.numeric(.)) - ) + - cowplot::theme_cowplot() + - ggtitle("Vignette Data Model Comparison") + - scale_y_log10("IEDR (p_hosp)", breaks = scales::log_breaks(n = 6)) + - scale_x_continuous("Time") + - theme(legend.position = "bottom") -``` - -IEDR lengths are different (Stan model generates an unnecessarily long version, see https://github.com/CDCgov/ww-inference-model/issues/43#issuecomment-2330269879) diff --git a/demos/hosp_only_ww_model/pyrenew_hew_model.qmd b/demos/hosp_only_ww_model/pyrenew_hew_model.qmd deleted file mode 100644 index 3e993fe0..00000000 --- a/demos/hosp_only_ww_model/pyrenew_hew_model.qmd +++ /dev/null @@ -1,122 +0,0 @@ ---- -title: "Replicating Hospital Only Model from ww-inference-model" -format: gfm -engine: jupyter ---- - -```{python} -# | label: setup -import jax -import numpyro -import arviz as az -import pyrenew_hew.plotting as plotting -from pyrenew_hew.pyrenew_hew_model import ( - create_pyrenew_hew_model_from_stan_data, -) - -numpyro.set_host_device_count(4) -``` - -## Background - -This tutorial provides a demonstration of our reimplementation of "Model 2" from the [`ww-inference-model` project](https://github.com/CDCgov/ww-inference-model). -The model is described [here](https://github.com/CDCgov/ww-inference-model/blob/main/model_definition.md). -Stan code for the model is [here](https://github.com/CDCgov/ww-inference-model/blob/main/inst/stan/wwinference.stan). - -The model we provide is designed to be fully-compatible with the stan_data generated in the that project. -We provide the stan data used in the `wwinference` [vignette](https://github.com/CDCgov/ww-inference-model/blob/main/vignettes/wwinference.Rmd) in the [`ww-inference-model` project](https://github.com/CDCgov/ww-inference-model). -The data is available in `notebooks/data/fit_hosp_only/stan_data.json`. -This data was generated by running `notebooks/wwinference.Rmd`, which replicates the original vignette and saves the relevant data. -This script also saves the posterior samples from the model for comparison to our own model. - -## Load Data and Create Priors - -We begin by loading the Stan data, converting it the correct inputs for our model, and definitng the model. - -```{python} -# | label: create model -my_pyrenew_hew_model, data_observed_disease_ed_visits = ( - create_pyrenew_hew_model_from_stan_data( - "data/fit_hosp_only/stan_data.json" - ) -) -``` - -# Simulate from the model - -We check that we can simulate from the prior predictive -```{python} -# | label: prior predictive -n_forecast_days = 35 - -prior_predictive = my_pyrenew_hew_model.prior_predictive( - n_observed_disease_ed_visits_datapoints=len(data_observed_disease_ed_visits) + n_forecast_days, - numpyro_predictive_args={"num_samples": 200}, -) -``` - -# Fit the model - -Now we can fit the model to the observed data: -```{python} -# | label: fit the model -my_pyrenew_hew_model.run( - num_warmup=500, - num_samples=500, - rng_key=jax.random.key(200), - data_observed_disease_ed_visits=data_observed_disease_ed_visits, - mcmc_args=dict(num_chains=4, progress_bar=False), - nuts_args=dict(find_heuristic_step_size=True), -) -``` - -Create the posterior predictive and forecast: - -```{python} -# | label: posterior predictive -posterior_predictive = my_pyrenew_hew_model.posterior_predictive( - n_observed_disease_ed_visits_datapoints=len(data_observed_disease_ed_visits) + n_forecast_days -) -``` - -## Prepare for plotting - -```{python} -# | label: prepare for plotting -import arviz as az - -idata = az.from_numpyro( - my_pyrenew_hew_model.mcmc, - posterior_predictive=posterior_predictive, - prior=prior_predictive, -) -``` - -## Plot Predictive Distributions - -```{python} -# | label: plot prior preditive -plotting.plot_predictive(idata, prior=True) -``` - -```{python} -# | label: plot posterior preditive -plotting.plot_predictive(idata) -``` - -## Plot all posteriors - -```{python} -# | label: plot all posteriors -for key in list(idata.posterior.keys()): - try: - plotting.plot_posterior(idata, key) - except Exception as e: - print(f"An error occurred while plotting {key}: {e}") -``` - -## Save for Post-Processing - -```{python} -idata.to_dataframe().to_csv("data/fit_hosp_only/inference_data.csv", index=False) -``` diff --git a/demos/hosp_only_ww_model/wwinference.Rmd b/demos/hosp_only_ww_model/wwinference.Rmd deleted file mode 100644 index 93b6c66e..00000000 --- a/demos/hosp_only_ww_model/wwinference.Rmd +++ /dev/null @@ -1,637 +0,0 @@ ---- -title: "Getting started with wwinference" -description: "A quick start example demonstrating the use of wwinference to jointly fit wastewater and hospital admissions data" -author: "Kaitlyn Johnson" -date: "2024-06-27" -output: - bookdown::html_vignette2: - fig_caption: yes - code_folding: show -pkgdown: - as_is: true -vignette: > - %\VignetteIndexEntry{Getting started with wwinference} - %\VignetteEngine{knitr::rmarkdown} - %\VignetteEncoding{UTF-8} ---- - -```{r setup, echo=FALSE} -knitr::opts_chunk$set(dev = "svg") -options(mc.cores = 4) # This tells cmdstan to run the 4 chains in parallel -``` - -# Quick start - -In this quick start, we demonstrate using `wwinference` to specify and fit a -minimal model using daily COVID-19 hospital admissions from a "global" population and -viral concentrations in wastewater from a few "local" wastewater treatment plants, -which come from subsets of the larger population. -In this context, when we say "global", we are referring to a larger -population e.g. a state, and when we say "local" we are referring to a smaller -subset of that population, e.g. a municipality within that state. -This is intended to be used as a reference for those -interested in fitting the `wwinference` model to their own data. - -# Packages - -In this quick start, we also use `dplyr` `tidybayes` and `ggplot2` packages. -These are installed as dependencies when `wwinference` is installed. - -```{r load-pkgs, warning=FALSE, message=FALSE} -library(wwinference) -library(dplyr) -library(ggplot2) -library(tidybayes) -``` - -# Data - -The model expects two types of data: daily counts of hospital admissions data -from the larger "global" population, and wastewater concentration -data from wastewater treatment plants whose catchment areas are contained within -the larger "global" population. For this quick start, we will use -simulated data, modeled after a hypothetical US state with 4 wastewater -treatment plants (also referred to as sites) reporting data on log scale viral -concentrations of SARS-COV-2, processed in 3 different labs, covering about 25% -of the state's population. This simulated data contains daily counts of the -total hospital admissions in a hypothetical US state from September 1, 2023 to -November 29, 2023. It contains wastewater log genome concentration data -from September 1, 2023 to December 1, 2023, with varying sampling frequencies. -We will be using this data to produce a forecast of COVID-19 hospital admissions -as of December 6, 2023. These data are provided as part of the package data. - -These data are already in a format that can be used for the `wwinference` package. -For the hospital admissions data, it contains: - -- a date (column `date`): the date of the observation, in this case, the date -the hospital admissions occurred -- a count (column `daily_hosp_admits`): the number of hospital admissions -observed on that day -- a population size (column `state_pop`): the population size covered -by the hospital admissions data, in this case, the size of the theoretical state. - -Additionally, we provide the `hosp_data_eval` dataset which contains the -simulated hospital admissions 28 days ahead of the forecast date, which can be -used to evaluate the model. - -For the wastewater data, the expcted format is a table of observations with the - -following columns. The wastewater data should not contain `NA` values for days with -missing observations, instead these should be excluded: -- a date (column `date`): the date the sample was collected -- a site indicator (column `site`): the unique identifier for the wastewater treatment plant -that the sample was collected from -- a lab indicator (column `lab`): the unique identifier for the lab where the sample was processed -- a concentration (column `log_genome_copies_ml`): the measured -log genome copies per mL for the given sample. This column should not -contain `NA` values, even if the observation for that sample is below the limit of -detection. -- a limit of detection (column `log_lod`): the natural log of the limit -of detection of the assay used to process the sample. Units should be the same -units as the concentration column. -- a site population size (column `site_pop`): the population size covered by the -wastewater catchment area of that site - - - -```{r load-data} -hosp_data <- wwinference::hosp_data -hosp_data_eval <- wwinference::hosp_data_eval -ww_data <- wwinference::ww_data - -head(ww_data) -head(hosp_data) -``` - - -# Pre-processing - -The user will need to provide data that is in a similar format to the package -data, as described above. This represents the bare minimum required data for a -single location and a single forecast date. We will need to do some -pre-processing to add some additional variables that the model will need to be -able apply features such as outlier exclusion and censoring of values below the -limit of detection. - - -## Parameters - -Get the example parameters from the package, which we will use here. -Note that some of these are COVID specific, others are more general to the -model, as indicated in the .toml file. - -```{r get-params} -params <- get_params( - system.file("extdata", "example_params.toml", - package = "wwinference" - ) -) -``` - -## Wastewater data pre-processing - -The `preprocess_ww_data()` function adds the following variables to the original -dataset. First, it assigns a unique identifier -the unique combinations of labs and sites, since this is the unit we will -use for estimating the observation error in the reported measurements. -Second it adds a column `below_lod` which is an indicator of whether the -reported concentration is above or below the limit of detection (LOD). If the -observation is below the LOD, the model will treat this observation as censored. -Third, it adds a column `flag_as_ww_outlier` that indicates whether the -measurement is identified as an outlier by our algorithm and the default -thresholds. While the default choice will be to exclude the measurements flagged -as outliers, the user can still choose to include these if they'd like later on. -The user must specify the name of the column containing the -concentration measurements (presumed to be in genome copies per mL) and the -name of the column containing the limit of detection for each measurement. The -function assumes that the original data contains the columns `date`, `site`, -and `lab`, and will return a dataframe with the column names needed to -pass to the downstream model fitting functions. - -```{r preprocess-ww-data} -ww_data_preprocessed <- preprocess_ww_data( - ww_data, - conc_col_name = "log_genome_copies_per_ml", - lod_col_name = "log_lod" -) -``` -Note that this function assumes that there are no missing values in the -concentration column. The package expects observations below the LOD will -be replaced with a numeric value below the LOD. If there are NAs in your dataset -when observations are below the LOD, we suggest replacing them with a value -below the LOD in upstream pre-processing. - -## Hospital admissions data pre-processing - -The `preprocess_count_data()` function standardizes the column names of the -resulting datafame. The user must specify the name of the column containing -the daily hospital admissions counts and the population size that the hospital -admissions are coming from (from in this case, a hypothetical US state). The -function assumes that the original data contains the column `date`, and will -return a dataframe with the column names needed to pass to the downstream model -fitting functions. - -```{r preprocess-hosp-data} -hosp_data_preprocessed <- preprocess_count_data( - hosp_data, - count_col_name = "daily_hosp_admits", - pop_size_col_name = "state_pop" -) -``` - -We'll make some plots of the data just to make sure it looks like what we'd expect: - -```{r wastewater-time-series-fig, out.width='100%'} -ggplot(ww_data_preprocessed) + - geom_point( - aes( - x = date, y = log_genome_copies_per_ml, - color = as.factor(lab_site_name) - ), - show.legend = FALSE, - size = 0.5 - ) + - geom_point( - data = ww_data_preprocessed |> filter( - log_genome_copies_per_ml <= log_lod - ), - aes(x = date, y = log_genome_copies_per_ml, color = "red"), - show.legend = FALSE, size = 0.5 - ) + - scale_x_date( - date_breaks = "2 weeks", - labels = scales::date_format("%Y-%m-%d") - ) + - geom_hline(aes(yintercept = log_lod), linetype = "dashed") + - facet_wrap(~lab_site_name, scales = "free") + - xlab("") + - ylab("Genome copies/mL") + - ggtitle("Lab-site level wastewater concentration") + - theme_bw() + - theme( - axis.text.x = element_text( - size = 5, vjust = 1, - hjust = 1, angle = 45 - ), - axis.title.x = element_text(size = 12), - axis.text.y = element_text(size = 5), - strip.text = element_text(size = 5), - axis.title.y = element_text(size = 12), - plot.title = element_text( - size = 10, - vjust = 0.5, hjust = 0.5 - ) - ) - - -ggplot(hosp_data_preprocessed) + - # Plot the hospital admissions data that we will evaluate against in white - geom_point( - data = hosp_data_eval, aes( - x = date, - y = daily_hosp_admits_for_eval - ), - shape = 21, color = "black", fill = "white" - ) + - # Plot the data we will calibrate to - geom_point(aes(x = date, y = count)) + - scale_x_date( - date_breaks = "2 weeks", - labels = scales::date_format("%Y-%m-%d") - ) + - xlab("") + - ylab("Daily hospital admissions") + - ggtitle("State level hospital admissions") + - theme_bw() + - theme( - axis.text.x = element_text( - size = 8, vjust = 1, - hjust = 1, angle = 45 - ), - axis.title.x = element_text(size = 12), - axis.title.y = element_text(size = 12), - plot.title = element_text( - size = 10, - vjust = 0.5, hjust = 0.5 - ) - ) -``` - -The closed circles indicate the data the model will be calibrated to, while -the open circles indicate data we later observe after the forecast date. - -## Data exclusion - -As an optional additional pre-processing step, the user can decide to exclude -certain data points in the model fit procedure. For example, -we recommend excluding the flagged wastewater concentration outliers. To do so -we will use the `indicate_ww_exclusions()` function, which will add the -flagged outliers to the exclude column where indicated. - -```{r indicate-ww-exclusions} -ww_data_to_fit <- indicate_ww_exclusions( - ww_data_preprocessed, - outlier_col_name = "flag_as_ww_outlier", - remove_outliers = TRUE -) -``` - -# Model specification: - -We will need to set some metadata to facilitate model specification. -This includes: -- forecast date (the date we are making a forecast) -- number of days to calibrate the model for -- number of days to forecast beyond the forecast date -- specification of the generation interval, in this case for COVID-19 -- specification of the delay from infection to the count data, in this case - from infection to COVID-19 hospital admission - -## Calibration time and forecast time - -The calibration time represents the number of days to calibrate the count data -to. This must be less than or equal to the number of rows in `hosp_data`. The -forecast horizon represents the number of days from the forecast date to -generate forecasted hospital admissions for. Typically, the hospital admissions -data will not be complete up until the forecast date, and we will refer to the -time between the last hospital admissions data point and the forecast date as -the nowcast time. The model will "forecast" this period, in addition to the -specified forecast horizon. - -```{r set-forecast-params} -forecast_date <- "2023-12-06" -calibration_time <- 90 -forecast_horizon <- 28 -``` - -## Delay distributions - -We will pass in probability mass functions (PMFs) that are specific to -COVID, and to the delay from infections to hospital admissions, the count -data we are using to fit the model. If using a different pathogen or a -different count dataset, these PMFs need to be replaced. We provide them as -package data here. The model expects that these are discrete daily PMFs. - -Additionally, the model requires specifying a delay distribution for the -infection feedback term, which essentially describes the delay at which -high incident infections results in negative feedback on future infections -(due to susceptibility, behavior changes, policies to reduce transmission, -etc.). We by default set this as the generation interval, but this can be -modified with any discrete daily PMF. - -```{r set-delay-distributions} -generation_interval <- wwinference::default_covid_gi -inf_to_hosp <- wwinference::default_covid_inf_to_hosp - -# Assign infection feedback equal to the generation interval -infection_feedback_pmf <- generation_interval -``` - -We will pass these to the `get_model_spec()` function of the `wwinference()` model, -along with the other specified parameters above. - -# Precompiling the model - -As `wwinference` uses `cmdstan` to fit its models, it is necessary to first -compile the model. This can be done using the `compile_model()` function. - -```{r compile-model} -model <- wwinference::compile_model() -``` - -# Fitting the model - -We're now ready to fit the model using the “No-U-Turn Sampler Markov chain -Monte Carlo” method. This is a type of Hamiltonian Monte Carlo (HMC) algorithm -and is the core fitting method used by `cmdstan`. The user can adjust the MCMC -settings (see the documentation for `get_mcmc_options()`), -however this vignette will use -the default parameter settings which includes running 4 parallel chains with -750 warm up iterations, 500 sampling iterations for each chain, a target average -acceptance probability of 0.95 and a maximum tree depth of 12. The default is -not to set a the seed for the random number generator for the MCMC model runs -(which would produce stochastic results each time the model is run), but for -reproducibility we will set the seed of the Stan PRNG to `123` in this vignette. - -When applying the model to real data, experimenting with these MCMC settings may make it possible -to achieve improved model convergence and/or faster model fitting times. See the [Stan User's Guide](https://mc-stan.org/docs/cmdstan-guide/diagnose_utility.html#building-the-diagnose-command) for an introduction to No-U-Turn sampler convergence diagnostics and configuration parameters. - -We also pass our preprocessed datasets (`ww_data_to_fit` and -`hosp_data_preprocessed`), specify our model using `get_model_spec()`, -set the MCMC settings by passing a list of arguments to `fit_opts` that will be passed to the `cmdstanr::sample()` function, and pass in our -pre-compiled model(`model`) to `wwinference()` where they are combined and -used to fit the model. - -```{r fitting-model, warning=FALSE, message=FALSE} -ww_fit <- wwinference( - ww_data = ww_data_to_fit, - count_data = hosp_data_preprocessed, - forecast_date = forecast_date, - calibration_time = calibration_time, - forecast_horizon = forecast_horizon, - model_spec = get_model_spec( - generation_interval = generation_interval, - inf_to_count_delay = inf_to_hosp, - infection_feedback_pmf = infection_feedback_pmf, - params = params - ), - fit_opts = list(seed = 123), - compiled_model = model -) -``` - -# The `wwinference_fit` object - -The `wwinference()` function returns a `wwinference_fit` object which includes -the underlying and the underlying -[`CmdStanModel` object](https://mc-stan.org/cmdstanr/reference/CmdStanModel.html) - (`fit`), a list of the two sources of input -data (`raw_input_data`), the list of the arguments passed to stan -(`stan_data_list`), and the list of the MCMC options (`fit_opts`) passed to -stan. We show how to generate downstream elements from a `wwinference_fit` -object. - -`wwinference_fit` objects currently have the following methods available: - -```{r show-methods} -methods(class = "wwinference_fit") -``` - -The `print` and `summary` methods can provide some information about the model. In particular, the `summary` method is a wrapper for `cmdstanr::summary()`: - -```{r print-and-summary} -print(ww_fit) -summary(ww_fit) -``` - -## Extracting the posterior predictions - -Working with the posterior predictions alongside the input data can be useful -to check that your model is fitting the data well and that the -nowcasted/forecast quantities look reasonable. - -We can use the `get_draws()` function to generate dataframes that contain -the posterior draws of the estimated, nowcasted, and forecasted quantities, -joined to the relevant data. - -We can generate this directly on the output of `wwinference()` using: -```{r extracting-draws} -draws <- get_draws(ww_fit) - -print(draws) -``` - -Note that by default the `get_draws()` function will return a list of class `wwinference_fit_draws` -which contains separate dataframes of the posterior draws for predicted counts (`"predicted_counts"`), -wastewater concentrations (`"predicted_ww"`), global $\mathcal{R}(t)$ (`"global_rt"`) estimates, and -subpopulation-level $\mathcal{R}(t)$ estimates ("`subpop_rt"`). -To examine a particular variable (e.g. `"predicted_counts"` for posterior -predicted hospital admissions in this case), access the corresponding tibble using the `$` operator. - - -You can also specify which outputs to return using the `what` argument. -```{r example subset draws} -hosp_draws <- get_draws(ww_fit, what = "predicted_counts") -hosp_draws_df <- hosp_draws$predicted_counts -head(hosp_draws_df) -``` - - - -### Using explicit passed arguments rather than S3 methods - -Rather than using S3 methods supplied for `wwinference()`, the elements in the -`wwinference_fit` object can also be used directly to create this dataframe. -This is demonstrated below: - -```{r extracting-draws-explicit, eval = FALSE} -draws_explicit <- get_draws( - x = ww_fit$raw_input_data$input_ww_data, - count_data = ww_fit$raw_input_data$input_count_data, - date_time_spine = ww_fit$raw_input_data$date_time_spine, - site_subpop_spine = ww_fit$raw_input_data$site_subpop_spine, - lab_site_subpop_spine = ww_fit$raw_input_data$lab_site_subpop_spine, - stan_data_list = ww_fit$stan_data_list, - fit_obj = ww_fit$fit -) -``` - - -## Plotting the outputs - -We can create plots of the outputs using corresponding dataframes in the `draws` -object and the fitting wrapper functions. Note that by default, these plots -will not include outliers that were flagged for exclusion. Data points -that are below the LOD will be plotted in blue. - -```{r generating-figures, out.width='100%'} -plot_hosp_with_eval <- get_plot_forecasted_counts( - draws = draws$predicted_counts, - forecast_date = forecast_date, - count_data_eval = hosp_data_eval, - count_data_eval_col_name = "daily_hosp_admits_for_eval" -) -plot_hosp_with_eval - - -plot_ww <- get_plot_ww_conc(draws$predicted_ww, forecast_date) -plot_ww - -plot_state_rt <- get_plot_global_rt(draws$global_rt, forecast_date) -plot_state_rt - -plot_subpop_rt <- get_plot_subpop_rt(draws$subpop_rt, forecast_date) -plot_subpop_rt -``` - -To plot the forecasts without the retrospectively observed hospital admissions, -simply don't pass them to the plotting function. -```{r plot-only-count-forecasts, out.width='100%'} -plot_hosp <- get_plot_forecasted_counts( - draws = draws$predicted_counts, - forecast_date = forecast_date -) -plot_hosp -``` - -The previous three are equivalent to calling the `plot` method of `wwinference_fit_draws` using the `what` argument: - -```{r, out.width='100%'} -plot( - x = draws, - what = "predicted_counts", - count_data_eval = hosp_data_eval, - count_data_eval_col_name = "daily_hosp_admits_for_eval", - forecast_date = forecast_date -) -plot(draws, what = "predicted_ww", forecast_date = forecast_date) -plot(draws, what = "global_rt", forecast_date = forecast_date) -plot(draws, what = "subpop_rt", forecast_date = forecast_date) -``` - -## Diagnostics - -We strongly recommend running diagnostics as a post-processing step on the -model outputs. - -This can be done by passing the output of - -`wwinference()` into the `get_model_diagnostic_flags()`, `summary_diagnostics()` -and `parameter_diagnostics()` functions. - -`get_model_diagnostic_flags()` will print out a table of any flags, if any of -these are TRUE, it will print out a warning. -We have set default thresholds on the model diagnostics for production-level -runs, we recommend adjusting as needed (see below) - -To further troubleshoot, you can look at -the summary diagnostics using the `summary_diagnostics()` function -and the diagnostics of the individual parameters using -the `parameter_diagnostics()` function. - -For further information on troubleshooting the model diagnostics, -we recommend the (bayesplot tutorial)[https://mc-stan.org/bayesplot/articles/visual-mcmc-diagnostics.html]. - -You can access the CmdStan object directly using `ww_fit$fit$result` - -```{r diagnostics-using-S3-methods} -convergence_flag_df <- get_model_diagnostic_flags(ww_fit) -print(convergence_flag_df) -summary_diagnostics(ww_fit) -param_diagnostics <- parameter_diagnostics(ww_fit) -head(param_diagnostics) -``` - -This can also be done explicitly by parsing the elements of the -`wwinference_fit` object into the custom functions we built / directly calling -`CmdStan`'s built in functions. - -Start by passing the stan fit object(`ww_fit$fit$result`) into the -`get_model_diagnostic_flags()` and adjusting the thresholds if desired. - -Then, we recommend looking at the diagnostics summary provided by `CmdStan`, -which we had wrapped into the `parameter_diagnostics()` call above. Lastly, -we recommend looking at the individual model parameters provided by `CmdStan` -to identify which components of the model might be driving the convergence -issues. - -For further information on troubleshooting the model diagnostics, -we recommend the [bayesplot tutorial](https://mc-stan.org/bayesplot/articles/visual-mcmc-diagnostics.html). - -```{r diagnostics-explicit} -convergence_flag_df <- get_model_diagnostic_flags( - x = ww_fit$fit$result, - ebmfi_tolerance = 0.2, - divergences_tolerance = 0.01, - frac_high_rhat_tolerance = 0.05, - rhat_tolerance = 1.05, - max_tree_depth_tol = 0.01 -) -# Get the tables using the CmdStan functions via wrappers -summary(ww_fit) -parameter_diagnostics(ww_fit, quiet = TRUE) -head(convergence_flag_df) -``` - -## Fit to only hospital admissions data - -The package also has functionality to fit the model without wastewater data. -This can be useful when doing comparisons of the impact the wastewater data -has on the forecast, or as a part of a pipeline where one might choose to -rely on the admissions only model if there are covergence or known data issues -with the wastewater data. - -```{r fit-hosp-only, warning=FALSE, message=FALSE} -fit_hosp_only <- wwinference( - ww_data = ww_data_to_fit, - count_data = hosp_data_preprocessed, - forecast_date = forecast_date, - calibration_time = calibration_time, - forecast_horizon = forecast_horizon, - model_spec = get_model_spec( - generation_interval = generation_interval, - inf_to_count_delay = inf_to_hosp, - infection_feedback_pmf = infection_feedback_pmf, - include_ww = FALSE, - params = params - ), - fit_opts = list(seed = 123), - compiled_model = model -) -``` - -```{r plot-hosp-only, out.width='100%'} -draws_hosp_only <- get_draws(fit_hosp_only) -plot(draws_hosp_only, - what = "predicted_counts", - count_data_eval = hosp_data_eval, - count_data_eval_col_name = "daily_hosp_admits_for_eval", - forecast_date = forecast_date -) -plot(draws_hosp_only, what = "global_rt", forecast_date = forecast_date) -``` - -```{r copy results} -library(fs) - -fit_dir <- path("demos", "ww_model", "data", "fit") -fit_hosp_only_dir <- path( - "demos", "hosp_only_ww_model", "data", "fit_hosp_only" -) -dir_create(fit_dir) -dir_create(fit_hosp_only_dir) - -file_copy( - path = ww_fit$fit$result$data_file(), - new_path = path(fit_dir, "stan_data", ext = "json"), - overwrite = TRUE -) -file_copy( - path = fit_hosp_only$fit$result$data_file(), - new_path = path(fit_hosp_only_dir, "stan_data", ext = "json"), - overwrite = TRUE -) - -ww_fit$fit$result$save_output_files(fit_dir) -fit_hosp_only$fit$result$save_output_files(fit_hosp_only_dir) -``` From 2cfefdd14db48759f4a3cb4b7d6d59a8dc1d1197 Mon Sep 17 00:00:00 2001 From: "Dylan H. Morris" Date: Wed, 8 Jan 2025 08:35:37 -0500 Subject: [PATCH 13/32] Fix missing close parenthesis, run precommit --- pipelines/build_pyrenew_model.py | 6 ++---- pyrenew_hew/pyrenew_hew_model.py | 15 ++++++++------- 2 files changed, 10 insertions(+), 11 deletions(-) diff --git a/pipelines/build_pyrenew_model.py b/pipelines/build_pyrenew_model.py index 40d8a576..5e71836b 100644 --- a/pipelines/build_pyrenew_model.py +++ b/pipelines/build_pyrenew_model.py @@ -7,7 +7,6 @@ from pyrenew.deterministic import DeterministicVariable from pyrenew_hew.pyrenew_hew_data import PyrenewHEWData - from pyrenew_hew.pyrenew_hew_model import ( EDVisitObservationProcess, HospAdmitObservationProcess, @@ -108,7 +107,6 @@ def build_model_from_dir( ), ihr_rel_iedr_rv=priors["ihr_rel_iedr_rv"], ihr_rv=priors["ihr_rv"], - ) # placeholder @@ -119,7 +117,7 @@ def build_model_from_dir( latent_infection_process_rv=my_latent_infection_model, ed_visit_obs_process_rv=my_ed_visit_obs_model, hosp_admit_obs_process_rv=my_hosp_admit_obs_model, - wastewater_obs_process_rv=my_wastewater_obs_model + wastewater_obs_process_rv=my_wastewater_obs_model, ) my_data = PyrenewHEWData( @@ -127,7 +125,7 @@ def build_model_from_dir( data_observed_disease_hospital_admissions=( data_observed_disease_hospital_admissions ), - data_observed_disease_wastewater=None, # placeholder + data_observed_disease_wastewater=None, # placeholder right_truncation_offset=right_truncation_offset, first_ed_visits_date=first_ed_visits_date, first_hospital_admissions_date=first_hospital_admissions_date, diff --git a/pyrenew_hew/pyrenew_hew_model.py b/pyrenew_hew/pyrenew_hew_model.py index 6996887c..87633a62 100644 --- a/pyrenew_hew/pyrenew_hew_model.py +++ b/pyrenew_hew/pyrenew_hew_model.py @@ -248,11 +248,12 @@ def __init__( inf_to_hosp_admit_rv: RandomVariable, hosp_admit_neg_bin_concentration_rv: RandomVariable, ihr_rv: RandomVariable = None, - ihr_rel_iedr_rv: RandomVariable = None) -> None: + ihr_rel_iedr_rv: RandomVariable = None, + ) -> None: self.inf_to_hosp_admit_rv = inf_to_hosp_admit_rv self.hosp_admit_neg_bin_concentration_rv = ( - hosp_admit_neg_bin_concentration_rv - ), + (hosp_admit_neg_bin_concentration_rv), + ) self.ihr_rv = ihr_rv self.ihr_rel_iedr_rv = ihr_rel_iedr_rv @@ -305,10 +306,11 @@ def sample( "observed_hospital_admissions", concentration_rv=self.hosp_admit_neg_bin_concentration_rv, ) - + sampled_admissions = hospital_admissions_obs_rv( - mu=predicted_weekly_admissions[-n_datapoints:], - obs=data_observed, + mu=predicted_weekly_admissions[-n_datapoints:], obs=data_observed + ) + return sampled_admissions @@ -385,7 +387,6 @@ def sample( n_datapoints=data.n_hospital_admissions_datapoints, data_observed=(data.data_observed_disease_hospital_admissions), iedr=iedr, - ) if sample_wastewater: sampled_wastewater = self.wastewater_obs_process_rv() From 813bfc2c70562effdfc162e5b491864fd931fda4 Mon Sep 17 00:00:00 2001 From: "Dylan H. Morris" Date: Wed, 8 Jan 2025 09:19:57 -0500 Subject: [PATCH 14/32] Fix incorrectly added parenthesis --- pyrenew_hew/pyrenew_hew_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyrenew_hew/pyrenew_hew_model.py b/pyrenew_hew/pyrenew_hew_model.py index 87633a62..8373ce73 100644 --- a/pyrenew_hew/pyrenew_hew_model.py +++ b/pyrenew_hew/pyrenew_hew_model.py @@ -252,7 +252,7 @@ def __init__( ) -> None: self.inf_to_hosp_admit_rv = inf_to_hosp_admit_rv self.hosp_admit_neg_bin_concentration_rv = ( - (hosp_admit_neg_bin_concentration_rv), + hosp_admit_neg_bin_concentration_rv ) self.ihr_rv = ihr_rv self.ihr_rel_iedr_rv = ihr_rel_iedr_rv From 44716fe00d2b2d810f5e724bdf42993de4cafe80 Mon Sep 17 00:00:00 2001 From: "Dylan H. Morris" Date: Wed, 8 Jan 2025 09:22:44 -0500 Subject: [PATCH 15/32] Fix upgrade issue with pygit2 by holding at 1.16.0 for now --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index e561d840..b1b51463 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,7 +16,7 @@ ipykernel = "^6.29.5" polars = "^1.5.0" pypdf = "^5.1.0" pyarrow = "^18.0.0" -pygit2 = "^1.16.0" +pygit2 = "1.16.0" azuretools = {git = "https://github.com/cdcgov/cfa-azuretools"} forecasttools = {git = "https://github.com/CDCgov/forecasttools-py"} tomli-w = "^1.1.0" From 8584cbe0b68636709b23148567efee49273eb2fa Mon Sep 17 00:00:00 2001 From: "Dylan H. Morris" Date: Wed, 8 Jan 2025 13:29:22 -0500 Subject: [PATCH 16/32] Try fitting Pyrenew-HE --- pipelines/fit_pyrenew_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pipelines/fit_pyrenew_model.py b/pipelines/fit_pyrenew_model.py index 432b2fd9..33fb86b9 100644 --- a/pipelines/fit_pyrenew_model.py +++ b/pipelines/fit_pyrenew_model.py @@ -30,7 +30,7 @@ def fit_and_save_model( my_model.run( data=my_data, sample_ed_visits=True, - sample_hospital_admissions=False, + sample_hospital_admissions=True, sample_wastewater=False, num_warmup=n_warmup, num_samples=n_samples, From 376469a3a42c7e19130b47f1572d91b59dc453ef Mon Sep 17 00:00:00 2001 From: "Dylan H. Morris" Date: Mon, 13 Jan 2025 23:33:57 +0000 Subject: [PATCH 17/32] Use mmwr epiweek function --- pyrenew_hew/pyrenew_hew_model.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/pyrenew_hew/pyrenew_hew_model.py b/pyrenew_hew/pyrenew_hew_model.py index 8373ce73..ce904335 100644 --- a/pyrenew_hew/pyrenew_hew_model.py +++ b/pyrenew_hew/pyrenew_hew_model.py @@ -9,7 +9,7 @@ from pyrenew.arrayutils import repeat_until_n, tile_until_n from pyrenew.convolve import ( compute_delay_ascertained_incidence, - daily_to_weekly, + daily_to_mmwr_epiweekly, ) from pyrenew.latent import ( InfectionInitializationProcess, @@ -296,10 +296,9 @@ def sample( first_latent_infection_dow + longest_possible_delay ) % 7 - predicted_weekly_admissions = daily_to_weekly( + predicted_weekly_admissions = daily_to_mmwr_epiweekly( latent_hospital_admissions, input_data_first_dow=first_latent_admission_dow, - week_start_dow=6, # MMWR epiweek, starts Sunday ) hospital_admissions_obs_rv = NegativeBinomialObservation( From 43bda9d0edb507fdb03112bed44ea1a0eb9984d7 Mon Sep 17 00:00:00 2001 From: "Dylan H. Morris" Date: Mon, 13 Jan 2025 23:37:13 +0000 Subject: [PATCH 18/32] Fix type hint for dict --- pyrenew_hew/pyrenew_hew_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyrenew_hew/pyrenew_hew_model.py b/pyrenew_hew/pyrenew_hew_model.py index ce904335..82c336d0 100644 --- a/pyrenew_hew/pyrenew_hew_model.py +++ b/pyrenew_hew/pyrenew_hew_model.py @@ -352,7 +352,7 @@ def sample( sample_ed_visits: bool = False, sample_hospital_admissions: bool = False, sample_wastewater: bool = False, - ) -> ArrayLike: # numpydoc ignore=GL08 + ) -> dict[str, ArrayLike]: # numpydoc ignore=GL08 latent_infections = self.latent_infection_process_rv( n_days_post_init=data.n_days_post_init, ) From 7e246569ff60331b8a2f0f2a52ec2b5856a2def8 Mon Sep 17 00:00:00 2001 From: "Dylan H. Morris" Date: Mon, 13 Jan 2025 23:41:47 +0000 Subject: [PATCH 19/32] Comment priors --- pipelines/priors/prod_priors.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/pipelines/priors/prod_priors.py b/pipelines/priors/prod_priors.py index a50c3c6e..2d5c0138 100644 --- a/pipelines/priors/prod_priors.py +++ b/pipelines/priors/prod_priors.py @@ -37,6 +37,7 @@ ) # Could be reparameterized? +# low confidence logit-Normal p_ed_visit_mean_rv = DistributionalVariable( "p_ed_visit_mean", dist.Normal( @@ -45,6 +46,7 @@ ), ) # logit scale +# low confidence logit-Normal with same mode as IEDR ihr_rv = TransformedVariable( "ihr", DistributionalVariable( @@ -76,6 +78,8 @@ transformation.AffineTransform(loc=0, scale=7), ) +# low confidence with a mode at equivalence and +# plausiblity of 2x or 1/2 the rate ihr_rel_iedr_rv = DistributionalVariable( "ihr_rel_iedr", dist.LogNormal(0, jnp.log(jnp.sqrt(2))) ) From d239a476deb0ccc2bcd6a1a5ded1fb1481f6ead5 Mon Sep 17 00:00:00 2001 From: "Dylan H. Morris" Date: Tue, 14 Jan 2025 17:54:58 -0500 Subject: [PATCH 20/32] Unpin and update pygit2 --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index b1b51463..c705a0de 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,7 +16,7 @@ ipykernel = "^6.29.5" polars = "^1.5.0" pypdf = "^5.1.0" pyarrow = "^18.0.0" -pygit2 = "1.16.0" +pygit2 = "^1.17.0" azuretools = {git = "https://github.com/cdcgov/cfa-azuretools"} forecasttools = {git = "https://github.com/CDCgov/forecasttools-py"} tomli-w = "^1.1.0" From e0651d1a8860efa88fbdd6427917ca8bc3f491af Mon Sep 17 00:00:00 2001 From: "Dylan H. Morris" Date: Fri, 17 Jan 2025 14:46:49 -0500 Subject: [PATCH 21/32] Use dictionaries and filter(None) for dates --- pyrenew_hew/pyrenew_hew_data.py | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/pyrenew_hew/pyrenew_hew_data.py b/pyrenew_hew/pyrenew_hew_data.py index b074cb15..7512ebf0 100644 --- a/pyrenew_hew/pyrenew_hew_data.py +++ b/pyrenew_hew/pyrenew_hew_data.py @@ -90,27 +90,27 @@ def last_wastewater_date(self): @property def first_data_dates(self): - return [ - self.first_ed_visits_date, - self.first_hospital_admissions_date, - self.first_wastewater_date, - ] + return dict( + ed_visits=self.first_ed_visits_date, + hospital_admissions=self.first_hospital_admissions_date, + wastewater=self.first_wastewater_date, + ) @property def last_data_dates(self): - return [ - self.last_ed_visits_date, - self.last_hospital_admissions_date, - self.last_wastewater_date, - ] + return dict( + ed_visits=self.last_ed_visits_date, + hospital_admissions=self.last_hospital_admissions_date, + wastewater=self.last_wastewater_date, + ) @property def first_data_date_overall(self): - return min([x for x in self.first_data_dates if x is not None]) + return min(filter(None, self.first_data_dates.values())) @property def last_data_date_overall(self): - return max([x for x in self.last_data_dates if x is not None]) + return min(filter(None, self.last_data_dates.values())) @property def n_days_post_init(self): From 4c8819478264077e95704077c15e60403727414f Mon Sep 17 00:00:00 2001 From: "Dylan H. Morris" Date: Fri, 17 Jan 2025 14:48:39 -0500 Subject: [PATCH 22/32] Update pyrenew_hew/pyrenew_hew_model.py Co-authored-by: Damon Bayer --- pyrenew_hew/pyrenew_hew_model.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/pyrenew_hew/pyrenew_hew_model.py b/pyrenew_hew/pyrenew_hew_model.py index c4854bb3..7d912aa6 100644 --- a/pyrenew_hew/pyrenew_hew_model.py +++ b/pyrenew_hew/pyrenew_hew_model.py @@ -494,7 +494,9 @@ def sample( data.first_data_date_overall - datetime.timedelta(days=n_init_days) ).weekday() - sampled_ed_visits, sampled_admissions, sampled_wastewater = ( + sampled_ed_visits = None + sampled_admissions = None + sampled_wastewater = None None, None, None, From a0cd3a8d63a0337d17febebaf5ef5430230e5353 Mon Sep 17 00:00:00 2001 From: "Dylan H. Morris" Date: Fri, 17 Jan 2025 14:48:59 -0500 Subject: [PATCH 23/32] Update pyrenew_hew/pyrenew_hew_model.py Co-authored-by: Damon Bayer --- pyrenew_hew/pyrenew_hew_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyrenew_hew/pyrenew_hew_model.py b/pyrenew_hew/pyrenew_hew_model.py index 7d912aa6..96a9ee53 100644 --- a/pyrenew_hew/pyrenew_hew_model.py +++ b/pyrenew_hew/pyrenew_hew_model.py @@ -439,7 +439,7 @@ def sample( concentration_rv=self.hosp_admit_neg_bin_concentration_rv, ) - sampled_admissions = hospital_admissions_obs_rv( + observed_hospital_admissions = hospital_admissions_obs_rv( mu=predicted_weekly_admissions[-n_datapoints:], obs=data_observed ) From 1aa327af731d79ce6a984b28dc1a6ed37d4dfbf2 Mon Sep 17 00:00:00 2001 From: "Dylan H. Morris" Date: Fri, 17 Jan 2025 14:56:49 -0500 Subject: [PATCH 24/32] IEDR/IHR conditional logic --- pyrenew_hew/pyrenew_hew_model.py | 39 +++++++++++++++++++++----------- 1 file changed, 26 insertions(+), 13 deletions(-) diff --git a/pyrenew_hew/pyrenew_hew_model.py b/pyrenew_hew/pyrenew_hew_model.py index 96a9ee53..dcbdfbe8 100644 --- a/pyrenew_hew/pyrenew_hew_model.py +++ b/pyrenew_hew/pyrenew_hew_model.py @@ -407,18 +407,35 @@ def sample( """ inf_to_hosp_admit = self.inf_to_hosp_admit_rv() - if iedr is not None: - ihr_rel_iedr = self.ihr_rel_iedr_rv() - ihr = iedr[0] * ihr_rel_iedr + if self.ihr_rel_iedr_rv is not None and self.ihr_rv is not None: + raise ValueError( + "IHR must either be specified " + "in absolute terms by a non-None " + "`ihr_rv` or specified relative " + "to the IEDR by a non-None " + "`ihr_rel_iedr_rv`, but not both. " + "Got non-None RVs for both " + "quantities" + ) + elif self.ihr_rel_iedr_rv is not None: + if iedr is None: + raise ValueError( + "Must pass in an IEDR to " "compute IHR relative to IEDR." + ) + ihr = iedr[0] * self.ihr_rel_iedr_rv() numpyro.deterministic("ihr", ihr) - else: + elif self.ihr_rv is not None: ihr = self.ihr_rv() - - latent_admissions = population_size * ihr * latent_infections + else: + raise ValueError( + "Must provide either an ihr_rv " + "or an ihr_rel_iedr_rv. " + "Got neither (both were None)." + ) latent_hospital_admissions = compute_delay_ascertained_incidence( p_observed_given_incident=1, - latent_incidence=latent_admissions, - delay_incidence_to_observation_pmf=inf_to_hosp_admit, + latent_incidence=(population_size * ihr * latent_infections), + delay_incidence_to_observation_pmf=(inf_to_hosp_admit), ) longest_possible_delay = inf_to_hosp_admit.shape[0] @@ -443,7 +460,7 @@ def sample( mu=predicted_weekly_admissions[-n_datapoints:], obs=data_observed ) - return sampled_admissions + return observed_hospital_admissions class WastewaterObservationProcess(RandomVariable): @@ -497,10 +514,6 @@ def sample( sampled_ed_visits = None sampled_admissions = None sampled_wastewater = None - None, - None, - None, - ) iedr = None From c108adffff62b50fe777f03d10b8a9add250f5cf Mon Sep 17 00:00:00 2001 From: "Dylan H. Morris" Date: Fri, 17 Jan 2025 14:59:36 -0500 Subject: [PATCH 25/32] Standardize nomenclature --- pyrenew_hew/pyrenew_hew_model.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/pyrenew_hew/pyrenew_hew_model.py b/pyrenew_hew/pyrenew_hew_model.py index dcbdfbe8..0c5b349c 100644 --- a/pyrenew_hew/pyrenew_hew_model.py +++ b/pyrenew_hew/pyrenew_hew_model.py @@ -367,12 +367,12 @@ def sample( concentration_rv=self.ed_neg_bin_concentration_rv, ) - sampled_ed_visits = ed_visit_obs_rv( + observed_ed_visits = ed_visit_obs_rv( mu=latent_ed_visits_now, obs=data_observed, ) - return sampled_ed_visits, iedr + return observed_ed_visits, iedr class HospAdmitObservationProcess(RandomVariable): @@ -511,14 +511,14 @@ def sample( data.first_data_date_overall - datetime.timedelta(days=n_init_days) ).weekday() - sampled_ed_visits = None - sampled_admissions = None - sampled_wastewater = None + observed_ed_visits = None + observed_admissions = None + observed_wastewater = None iedr = None if sample_ed_visits: - sampled_ed_visits, iedr = self.ed_visit_obs_process_rv( + observed_ed_visits, iedr = self.ed_visit_obs_process_rv( latent_infections=latent_infections, population_size=self.population_size, data_observed=data.data_observed_disease_ed_visits, @@ -527,7 +527,7 @@ def sample( ) if sample_hospital_admissions: - sampled_admissions = self.hosp_admit_obs_process_rv( + observed_admissions = self.hosp_admit_obs_process_rv( latent_infections=latent_infections, first_latent_infection_dow=first_latent_infection_dow, population_size=self.population_size, @@ -536,10 +536,10 @@ def sample( iedr=iedr, ) if sample_wastewater: - sampled_wastewater = self.wastewater_obs_process_rv() + observed_wastewater = self.wastewater_obs_process_rv() return { - "ed_visits": sampled_ed_visits, - "hospital_admissions": sampled_admissions, - "wasewater": sampled_wastewater, + "ed_visits": observed_ed_visits, + "hospital_admissions": observed_admissions, + "wasewater": observed_wastewater, } From 26cc2d3cd6a35c490f47a6f9bc6e9974e62d9ec6 Mon Sep 17 00:00:00 2001 From: "Dylan H. Morris" Date: Fri, 17 Jan 2025 15:27:17 -0500 Subject: [PATCH 26/32] Fix missing import --- pyrenew_hew/pyrenew_hew_model.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/pyrenew_hew/pyrenew_hew_model.py b/pyrenew_hew/pyrenew_hew_model.py index 0c5b349c..10aed95d 100644 --- a/pyrenew_hew/pyrenew_hew_model.py +++ b/pyrenew_hew/pyrenew_hew_model.py @@ -20,7 +20,11 @@ from pyrenew.metaclass import Model, RandomVariable from pyrenew.observation import NegativeBinomialObservation from pyrenew.process import ARProcess, DifferencedProcess -from pyrenew.randomvariable import DistributionalVariable, TransformedVariable +from pyrenew.randomvariable import ( + DeterministicVariable, + DistributionalVariable, + TransformedVariable, +) from pyrenew_hew.pyrenew_hew_data import PyrenewHEWData From c3040ad656b3d2654c65b3642a08fbc0234db7ce Mon Sep 17 00:00:00 2001 From: "Dylan H. Morris" Date: Fri, 17 Jan 2025 15:35:07 -0500 Subject: [PATCH 27/32] deterministic module --- pyrenew_hew/pyrenew_hew_model.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/pyrenew_hew/pyrenew_hew_model.py b/pyrenew_hew/pyrenew_hew_model.py index 10aed95d..7673d0ec 100644 --- a/pyrenew_hew/pyrenew_hew_model.py +++ b/pyrenew_hew/pyrenew_hew_model.py @@ -7,11 +7,11 @@ import pyrenew.transformation as transformation from jax.typing import ArrayLike from numpyro.infer.reparam import LocScaleReparam -from pyrenew.arrayutils import repeat_until_n, tile_until_n from pyrenew.convolve import ( compute_delay_ascertained_incidence, daily_to_mmwr_epiweekly, ) +from pyrenew.deterministic import DeterministicVariable from pyrenew.latent import ( InfectionInitializationProcess, InfectionsWithFeedback, @@ -20,11 +20,7 @@ from pyrenew.metaclass import Model, RandomVariable from pyrenew.observation import NegativeBinomialObservation from pyrenew.process import ARProcess, DifferencedProcess -from pyrenew.randomvariable import ( - DeterministicVariable, - DistributionalVariable, - TransformedVariable, -) +from pyrenew.randomvariable import DistributionalVariable, TransformedVariable from pyrenew_hew.pyrenew_hew_data import PyrenewHEWData From 5eb688c771cabdf7f69d65431a1e205951f07092 Mon Sep 17 00:00:00 2001 From: "Dylan H. Morris" Date: Sat, 18 Jan 2025 11:13:51 -0500 Subject: [PATCH 28/32] Use fact that latent infections now has an n_initialization_points attribute --- pyrenew_hew/pyrenew_hew_model.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/pyrenew_hew/pyrenew_hew_model.py b/pyrenew_hew/pyrenew_hew_model.py index 7673d0ec..418a6e47 100644 --- a/pyrenew_hew/pyrenew_hew_model.py +++ b/pyrenew_hew/pyrenew_hew_model.py @@ -7,6 +7,7 @@ import pyrenew.transformation as transformation from jax.typing import ArrayLike from numpyro.infer.reparam import LocScaleReparam +from pyrenew.arrayutils import tile_until_n from pyrenew.convolve import ( compute_delay_ascertained_incidence, daily_to_mmwr_epiweekly, @@ -185,7 +186,7 @@ def sample(self, n_days_post_init: int): rtu_subpop_ar_proc = ARProcess() rtu_subpop_ar_weekly = rtu_subpop_ar_proc( noise_name="rtu_ar_proc", - n=n_weeks_post_init, + n=n_weeks_rt, init_vals=rtu_subpop_ar_init[jnp.newaxis], autoreg=autoreg_rt_subpop[jnp.newaxis], noise_sd=sigma_rt, @@ -506,9 +507,9 @@ def sample( latent_infections = self.latent_infection_process_rv( n_days_post_init=data.n_days_post_init, ) - n_init_days = self.latent_infection_process_rv.infection_initialization_process.infection_init_method.n_timepoints first_latent_infection_dow = ( - data.first_data_date_overall - datetime.timedelta(days=n_init_days) + data.first_data_date_overall + - datetime.timedelta(days=self.n_initialization_points) ).weekday() observed_ed_visits = None From b672c39b5eb8dff220d2b7d37125907dd5284c61 Mon Sep 17 00:00:00 2001 From: "Dylan H. Morris" Date: Sat, 18 Jan 2025 11:19:37 -0500 Subject: [PATCH 29/32] access attribute correctly --- pyrenew_hew/pyrenew_hew_model.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyrenew_hew/pyrenew_hew_model.py b/pyrenew_hew/pyrenew_hew_model.py index 418a6e47..e3b7683d 100644 --- a/pyrenew_hew/pyrenew_hew_model.py +++ b/pyrenew_hew/pyrenew_hew_model.py @@ -504,12 +504,12 @@ def sample( sample_hospital_admissions: bool = False, sample_wastewater: bool = False, ) -> dict[str, ArrayLike]: # numpydoc ignore=GL08 + n_init_days = self.latent_infection_process_rv.n_initialization_points latent_infections = self.latent_infection_process_rv( n_days_post_init=data.n_days_post_init, ) first_latent_infection_dow = ( - data.first_data_date_overall - - datetime.timedelta(days=self.n_initialization_points) + data.first_data_date_overall - datetime.timedelta(days=n_init_days) ).weekday() observed_ed_visits = None From 0d03fd088ec0d21895518c895139aab1fcbe74fe Mon Sep 17 00:00:00 2001 From: "Dylan H. Morris" Date: Sat, 18 Jan 2025 11:33:57 -0500 Subject: [PATCH 30/32] Tweak prior assignment to handle new exactly one ihr prior check --- pipelines/build_pyrenew_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pipelines/build_pyrenew_model.py b/pipelines/build_pyrenew_model.py index 5e71836b..80a94479 100644 --- a/pipelines/build_pyrenew_model.py +++ b/pipelines/build_pyrenew_model.py @@ -105,7 +105,7 @@ def build_model_from_dir( hosp_admit_neg_bin_concentration_rv=( priors["hosp_admit_neg_bin_concentration_rv"] ), - ihr_rel_iedr_rv=priors["ihr_rel_iedr_rv"], + ihr_rel_iedr_rv=None, # since for now we only use H or E, not HE ihr_rv=priors["ihr_rv"], ) From 001e29205b9c9581fa4c0efdcf0c79868a5693e8 Mon Sep 17 00:00:00 2001 From: "Dylan H. Morris" Date: Sat, 18 Jan 2025 11:38:04 -0500 Subject: [PATCH 31/32] Update test --- tests/test_LatentInfectionProcess.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/test_LatentInfectionProcess.py b/tests/test_LatentInfectionProcess.py index 40dd192f..695fe916 100644 --- a/tests/test_LatentInfectionProcess.py +++ b/tests/test_LatentInfectionProcess.py @@ -34,7 +34,6 @@ def test_LatentInfectionProcess(): infection_feedback_strength_rv = DeterministicVariable("inf_feedback", -2) n_initialization_points = 10 n_days_post_init = 14 - n_weeks_post_init = 2 my_latent_infection_model = LatentInfectionProcess( i0_first_obs_n_rv=i0_first_obs_n_rv, @@ -50,8 +49,7 @@ def test_LatentInfectionProcess(): with numpyro.handlers.seed(rng_seed=223): latent_inf_w_hierarchical_effects = my_latent_infection_model( - n_days_post_init=n_days_post_init, - n_weeks_post_init=n_weeks_post_init, + n_days_post_init=n_days_post_init ) # Calculate latent infections without hierarchical dynamics From 4131e81b62524469b8f4a56a1cfccb895522bd3e Mon Sep 17 00:00:00 2001 From: "Dylan H. Morris" Date: Sat, 18 Jan 2025 12:18:00 -0500 Subject: [PATCH 32/32] Fix bug introduced by copy/paste --- pyrenew_hew/pyrenew_hew_data.py | 2 +- pyrenew_hew/pyrenew_hew_model.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/pyrenew_hew/pyrenew_hew_data.py b/pyrenew_hew/pyrenew_hew_data.py index 7512ebf0..75ac5d1d 100644 --- a/pyrenew_hew/pyrenew_hew_data.py +++ b/pyrenew_hew/pyrenew_hew_data.py @@ -110,7 +110,7 @@ def first_data_date_overall(self): @property def last_data_date_overall(self): - return min(filter(None, self.last_data_dates.values())) + return max(filter(None, self.last_data_dates.values())) @property def n_days_post_init(self): diff --git a/pyrenew_hew/pyrenew_hew_model.py b/pyrenew_hew/pyrenew_hew_model.py index e3b7683d..37f0e408 100644 --- a/pyrenew_hew/pyrenew_hew_model.py +++ b/pyrenew_hew/pyrenew_hew_model.py @@ -209,7 +209,7 @@ def sample(self, n_days_post_init: int): repeats=7, axis=0, )[:n_days_post_init, :] - ) + ) # indexed rel to first post-init day. i0_subpop_rv = DeterministicVariable( "i0_subpop", i_first_obs_over_n_subpop @@ -318,7 +318,7 @@ def sample( iedr = jnp.repeat( transformation.SigmoidTransform()(p_ed_ar + p_ed_mean), repeats=7, - )[:n_datapoints] + )[:n_datapoints] # indexed rel to first ed report day # this is only applied after the ed visits are generated, not to all # the latent infections. This is why we cannot apply the iedr in # compute_delay_ascertained_incidence