From 3cfff6007a0c2d35733039c82fc411b2f3236658 Mon Sep 17 00:00:00 2001 From: Samuel Brand Date: Wed, 21 Feb 2024 13:40:58 +0000 Subject: [PATCH 1/4] New `LatentProcess` struct with default constructor for Random walk; along with updated unit tests and inference run checking --- EpiAware/src/EpiAware.jl | 4 +-- EpiAware/src/latent-processes.jl | 31 +++++++++++++++++++ EpiAware/src/models.jl | 8 +++-- .../toy_model_log_infs_RW.jl | 5 +-- EpiAware/test/test_models.jl | 10 +++--- 5 files changed, 47 insertions(+), 11 deletions(-) diff --git a/EpiAware/src/EpiAware.jl b/EpiAware/src/EpiAware.jl index 01dc5bb67..3e3b421a7 100644 --- a/EpiAware/src/EpiAware.jl +++ b/EpiAware/src/EpiAware.jl @@ -38,12 +38,12 @@ export create_discrete_pmf, default_rw_priors, default_delay_obs_priors, spread_ export EpiData, Renewal, ExpGrowthRate, DirectInfections # Exported Turing model constructors -export make_epi_inference_model, random_walk, delay_observations +export make_epi_inference_model, delay_observations, random_walk_process include("epimodel.jl") include("utilities.jl") -include("models.jl") include("latent-processes.jl") include("observation-processes.jl") +include("models.jl") end diff --git a/EpiAware/src/latent-processes.jl b/EpiAware/src/latent-processes.jl index 2ac2de53b..287677a62 100644 --- a/EpiAware/src/latent-processes.jl +++ b/EpiAware/src/latent-processes.jl @@ -18,3 +18,34 @@ end end return rw, init, (; σ_RW,) end + +""" + struct LatentProcess{F<:Function} + +A struct representing a latent process with its priors. + +# Fields +- `latent_process`: The latent process function. +- `latent_process_priors`: NamedTuple containing the priors for the latent process. + +""" +struct LatentProcess{F<:Function} + latent_process::F + latent_process_priors::NamedTuple +end + +""" + random_walk_process(; latent_process_priors = default_rw_priors()) + +Create a `LatentProcess` struct reflecting a random walk process with optional priors. + +# Arguments +- `latent_process_priors`: Optional priors for the random walk process. + +# Returns +- `LatentProcess`: A random walk process. + +""" +function random_walk_process(; latent_process_priors = default_rw_priors()) + LatentProcess(random_walk, latent_process_priors) +end diff --git a/EpiAware/src/models.jl b/EpiAware/src/models.jl index 24bf036fc..5eccadeb0 100644 --- a/EpiAware/src/models.jl +++ b/EpiAware/src/models.jl @@ -1,15 +1,17 @@ @model function make_epi_inference_model( y_t, epimodel::AbstractEpiModel, - latent_process, + latent_process_obj::LatentProcess, observation_process; process_priors, pos_shift = 1e-6, ) #Latent process time_steps = epimodel.data.time_horizon - @submodel latent_process, init, latent_process_aux = - latent_process(time_steps; latent_process_priors = process_priors) + @submodel latent_process, init, latent_process_aux = latent_process_obj.latent_process( + time_steps; + latent_process_priors = latent_process_obj.latent_process_priors, + ) #Transform into infections I_t = epimodel(latent_process, init) diff --git a/EpiAware/test/predictive_checking/toy_model_log_infs_RW.jl b/EpiAware/test/predictive_checking/toy_model_log_infs_RW.jl index 41f0993c4..b8381a8c4 100644 --- a/EpiAware/test/predictive_checking/toy_model_log_infs_RW.jl +++ b/EpiAware/test/predictive_checking/toy_model_log_infs_RW.jl @@ -101,6 +101,7 @@ In this case we use the `DirectInfections` model. =# toy_log_infs = DirectInfections(model_data) +rwp = EpiAware.random_walk_process() #= ## Generate a `Turing` `Model` @@ -110,7 +111,7 @@ We don't have observed data, so we use `missing` value for `y_t`. log_infs_model = make_epi_inference_model( missing, toy_log_infs, - random_walk, + rwp, delay_observations; process_priors = merge(default_rw_priors(), default_delay_obs_priors()), pos_shift = 1e-6, @@ -150,7 +151,7 @@ truth_data = random_epidemic.y_t model = make_epi_inference_model( truth_data, toy_log_infs, - random_walk, + rwp, delay_observations; process_priors = merge(default_rw_priors(), default_delay_obs_priors()), pos_shift = 1e-6, diff --git a/EpiAware/test/test_models.jl b/EpiAware/test/test_models.jl index 61d4493b0..4f39783dc 100644 --- a/EpiAware/test/test_models.jl +++ b/EpiAware/test/test_models.jl @@ -9,12 +9,12 @@ epimodel = DirectInfections(data) - + rwp = EpiAware.random_walk_process() # Call the function test_mdl = make_epi_inference_model( y_t, epimodel, - random_walk, + rwp, delay_observations; process_priors, pos_shift, @@ -43,12 +43,13 @@ end pos_shift = 1e-6 epimodel = ExpGrowthRate(data) + rwp = EpiAware.random_walk_process() # Call the function test_mdl = make_epi_inference_model( y_t, epimodel, - random_walk, + rwp, delay_observations; process_priors, pos_shift, @@ -77,12 +78,13 @@ end pos_shift = 1e-6 epimodel = Renewal(data) + rwp = EpiAware.random_walk_process() # Call the function test_mdl = make_epi_inference_model( y_t, epimodel, - random_walk, + rwp, delay_observations; process_priors, pos_shift, From 349d444cdc7791795a2986e4e211f32a57a686c7 Mon Sep 17 00:00:00 2001 From: Samuel Brand Date: Wed, 21 Feb 2024 13:55:32 +0000 Subject: [PATCH 2/4] Created `ObservationModel` struct along with default constructor, unit tests and inference test --- EpiAware/src/EpiAware.jl | 2 +- EpiAware/src/latent-processes.jl | 2 +- EpiAware/src/models.jl | 7 ++-- EpiAware/src/observation-processes.jl | 32 +++++++++++++++ .../toy_model_log_infs_RW.jl | 22 +++-------- EpiAware/test/test_models.jl | 39 +++++-------------- 6 files changed, 51 insertions(+), 53 deletions(-) diff --git a/EpiAware/src/EpiAware.jl b/EpiAware/src/EpiAware.jl index 3e3b421a7..2d6432e90 100644 --- a/EpiAware/src/EpiAware.jl +++ b/EpiAware/src/EpiAware.jl @@ -38,7 +38,7 @@ export create_discrete_pmf, default_rw_priors, default_delay_obs_priors, spread_ export EpiData, Renewal, ExpGrowthRate, DirectInfections # Exported Turing model constructors -export make_epi_inference_model, delay_observations, random_walk_process +export make_epi_inference_model, delay_observations_model, random_walk_process include("epimodel.jl") include("utilities.jl") diff --git a/EpiAware/src/latent-processes.jl b/EpiAware/src/latent-processes.jl index 287677a62..ed7b4afa1 100644 --- a/EpiAware/src/latent-processes.jl +++ b/EpiAware/src/latent-processes.jl @@ -25,7 +25,7 @@ end A struct representing a latent process with its priors. # Fields -- `latent_process`: The latent process function. +- `latent_process`: The latent process function for a `Turing` model. - `latent_process_priors`: NamedTuple containing the priors for the latent process. """ diff --git a/EpiAware/src/models.jl b/EpiAware/src/models.jl index 5eccadeb0..9605534e4 100644 --- a/EpiAware/src/models.jl +++ b/EpiAware/src/models.jl @@ -2,8 +2,7 @@ y_t, epimodel::AbstractEpiModel, latent_process_obj::LatentProcess, - observation_process; - process_priors, + observation_process_obj::ObservationModel; pos_shift = 1e-6, ) #Latent process @@ -17,11 +16,11 @@ I_t = epimodel(latent_process, init) #Predictive distribution of ascerted cases - @submodel generated_y_t, generated_y_t_aux = observation_process( + @submodel generated_y_t, generated_y_t_aux = observation_process_obj.observation_model( y_t, I_t, epimodel::AbstractEpiModel; - observation_process_priors = process_priors, + observation_process_priors = observation_process_obj.observation_model_priors, pos_shift = pos_shift, ) diff --git a/EpiAware/src/observation-processes.jl b/EpiAware/src/observation-processes.jl index 90d904449..5d6e1e0a7 100644 --- a/EpiAware/src/observation-processes.jl +++ b/EpiAware/src/observation-processes.jl @@ -22,3 +22,35 @@ end return y_t, (; neg_bin_cluster_factor,) end + + +""" + struct ObservationModel{F<:Function} + +A struct representing an observation model with its priors. + +# Fields +- `observation_model`: The observation model function for a `Turing` model. +- `observation_model_priors`: NamedTuple containing the priors for the observation model. + +""" +struct ObservationModel{F<:Function} + observation_model::F + observation_model_priors::NamedTuple +end + +""" + delay_observations_model(; latent_process_priors = default_rw_priors()) + +Create an `ObservationModel` struct reflecting a delayed observation process with optional priors. + +# Arguments +- `latent_process_priors`: Optional priors for the delayed observation process. + +# Returns +- `ObservationModel`: An observation model with delayed observations. + +""" +function delay_observations_model(; observation_model_priors = default_delay_obs_priors()) + ObservationModel(delay_observations, observation_model_priors) +end diff --git a/EpiAware/test/predictive_checking/toy_model_log_infs_RW.jl b/EpiAware/test/predictive_checking/toy_model_log_infs_RW.jl index b8381a8c4..097eaf397 100644 --- a/EpiAware/test/predictive_checking/toy_model_log_infs_RW.jl +++ b/EpiAware/test/predictive_checking/toy_model_log_infs_RW.jl @@ -101,21 +101,16 @@ In this case we use the `DirectInfections` model. =# toy_log_infs = DirectInfections(model_data) -rwp = EpiAware.random_walk_process() +rwp = random_walk_process() +obs_mdl = delay_observations_model() #= ## Generate a `Turing` `Model` We don't have observed data, so we use `missing` value for `y_t`. =# -log_infs_model = make_epi_inference_model( - missing, - toy_log_infs, - rwp, - delay_observations; - process_priors = merge(default_rw_priors(), default_delay_obs_priors()), - pos_shift = 1e-6, -) +log_infs_model = + make_epi_inference_model(missing, toy_log_infs, rwp, obs_mdl; pos_shift = 1e-6) #= ## Sample from the model @@ -148,14 +143,7 @@ We treat the generated data as observed data and attempt to infer underlying inf truth_data = random_epidemic.y_t -model = make_epi_inference_model( - truth_data, - toy_log_infs, - rwp, - delay_observations; - process_priors = merge(default_rw_priors(), default_delay_obs_priors()), - pos_shift = 1e-6, -) +model = make_epi_inference_model(truth_data, toy_log_infs, rwp, obs_mdl; pos_shift = 1e-6) @time chn = sample( model, diff --git a/EpiAware/test/test_models.jl b/EpiAware/test/test_models.jl index 4f39783dc..cd07d44b5 100644 --- a/EpiAware/test/test_models.jl +++ b/EpiAware/test/test_models.jl @@ -4,21 +4,14 @@ # Define test inputs y_t = missing # Data will be generated from the model data = EpiData([0.2, 0.3, 0.5], [0.1, 0.4, 0.5], 0.8, 10, exp) - process_priors = merge(default_rw_priors(), default_delay_obs_priors()) pos_shift = 1e-6 epimodel = DirectInfections(data) - rwp = EpiAware.random_walk_process() + rwp = random_walk_process() + obs_mdl = delay_observations_model() # Call the function - test_mdl = make_epi_inference_model( - y_t, - epimodel, - rwp, - delay_observations; - process_priors, - pos_shift, - ) + test_mdl = make_epi_inference_model(y_t, epimodel, rwp, obs_mdl; pos_shift) # Define expected outputs for a conditional model # Underlying log-infections are const value 1 for all time steps and @@ -39,21 +32,14 @@ end # Define test inputs y_t = missing # Data will be generated from the model data = EpiData([0.2, 0.3, 0.5], [0.1, 0.4, 0.5], 0.8, 10, exp) - process_priors = merge(default_rw_priors(), default_delay_obs_priors()) pos_shift = 1e-6 epimodel = ExpGrowthRate(data) - rwp = EpiAware.random_walk_process() + rwp = random_walk_process() + obs_mdl = delay_observations_model() # Call the function - test_mdl = make_epi_inference_model( - y_t, - epimodel, - rwp, - delay_observations; - process_priors, - pos_shift, - ) + test_mdl = make_epi_inference_model(y_t, epimodel, rwp, obs_mdl; pos_shift) # Define expected outputs for a conditional model # Underlying log-infections are const value 1 for all time steps and @@ -78,17 +64,10 @@ end pos_shift = 1e-6 epimodel = Renewal(data) - rwp = EpiAware.random_walk_process() - + rwp = random_walk_process() + obs_mdl = delay_observations_model() # Call the function - test_mdl = make_epi_inference_model( - y_t, - epimodel, - rwp, - delay_observations; - process_priors, - pos_shift, - ) + test_mdl = make_epi_inference_model(y_t, epimodel, rwp, obs_mdl; pos_shift) # Define expected outputs for a conditional model # Underlying log-infections are const value 1 for all time steps and From 9866027c1efec68bbb75c461f477bc10b6f6e204 Mon Sep 17 00:00:00 2001 From: Samuel Brand Date: Wed, 21 Feb 2024 14:24:41 +0000 Subject: [PATCH 3/4] fixed test for sub processes/models --- EpiAware/test/test_latent-processes.jl | 2 +- EpiAware/test/test_observation-processes.jl | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/EpiAware/test/test_latent-processes.jl b/EpiAware/test/test_latent-processes.jl index 9fb85a236..87b4845f3 100644 --- a/EpiAware/test/test_latent-processes.jl +++ b/EpiAware/test/test_latent-processes.jl @@ -2,7 +2,7 @@ @testitem "Testing random_walk against theoretical properties" begin using DynamicPPL, Turing n = 5 - model = random_walk(n) + model = EpiAware.random_walk(n) fixed_model = fix(model, (σ²_RW = 1.0, init_rw_value = 0.0)) #Fixing the standard deviation of the random walk process n_samples = 1000 samples_day_5 = diff --git a/EpiAware/test/test_observation-processes.jl b/EpiAware/test/test_observation-processes.jl index 238bfd5fa..caa588676 100644 --- a/EpiAware/test/test_observation-processes.jl +++ b/EpiAware/test/test_observation-processes.jl @@ -11,7 +11,7 @@ observation_process_priors = default_delay_obs_priors() # Call the function - mdl = delay_observations( + mdl = EpiAware.delay_observations( missing, I_t, epimodel; From e4f2cc41eca64de194e38535e811cd2cfe522546 Mon Sep 17 00:00:00 2001 From: Samuel Brand Date: Wed, 21 Feb 2024 20:27:28 +0000 Subject: [PATCH 4/4] Update EpiAware README --- EpiAware/README.md | 33 +++++++++++++++++++++++++-------- 1 file changed, 25 insertions(+), 8 deletions(-) diff --git a/EpiAware/README.md b/EpiAware/README.md index 5e699c1b5..80749460b 100644 --- a/EpiAware/README.md +++ b/EpiAware/README.md @@ -6,8 +6,9 @@ - Solid lines indicate implemented features/analysis. - Dashed lines indicate planned features/analysis. +## Proposed `EpiAware` model diagram ```mermaid -flowchart TD +flowchart LR A["Underlying dists. and specify length of sims @@ -26,24 +27,40 @@ C["Observational Data Obs. cases y_t"] D["Latent processes --------------------- -Random Walk"] -E[Turing model constructor] -F["Latent Process priors"] +random_walk"] +E["Turing model constructor +--------------------- +make_epi_inference_model"] +F["Latent Process priors +--------------------- +default_rw_priors"] G[Posterior draws] H[Posterior checking] I[Post-processing] DataW[Data wrangling and QC] -J["Observation Model +J["Observation models --------------------- delay_observations"] +K["Observation model priors +--------------------- +default_delay_obs_priors"] +ObservationModel["ObservationModel +--------------------- +delay_observations_model"] +LatentProcess["LatentProcess +--------------------- +random_walk_process"] A --> EpiModel B --> EpiModel EpiModel -->E C-->E -D-->|random_walk| E -J-->E -F-->|default_rw_priors|E +D-->LatentProcess +F-->LatentProcess +J-->ObservationModel +K-->ObservationModel +LatentProcess-->E +ObservationModel-->E E-->|sample...NUTS...| G G-.->H H-.->I