diff --git a/EpiAware/README.md b/EpiAware/README.md index 32d6b8d6d..d5fee9898 100644 --- a/EpiAware/README.md +++ b/EpiAware/README.md @@ -9,8 +9,9 @@ - Solid lines indicate implemented features/analysis. - Dashed lines indicate planned features/analysis. +## Proposed `EpiAware` model diagram ```mermaid -flowchart TD +flowchart LR A["Underlying dists. and specify length of sims @@ -29,24 +30,40 @@ C["Observational Data Obs. cases y_t"] D["Latent processes --------------------- -Random Walk"] -E[Turing model constructor] -F["Latent Process priors"] +random_walk"] +E["Turing model constructor +--------------------- +make_epi_inference_model"] +F["Latent Process priors +--------------------- +default_rw_priors"] G[Posterior draws] H[Posterior checking] I[Post-processing] DataW[Data wrangling and QC] -J["Observation Model +J["Observation models --------------------- delay_observations"] +K["Observation model priors +--------------------- +default_delay_obs_priors"] +ObservationModel["ObservationModel +--------------------- +delay_observations_model"] +LatentProcess["LatentProcess +--------------------- +random_walk_process"] A --> EpiModel B --> EpiModel EpiModel -->E C-->E -D-->|random_walk| E -J-->E -F-->|default_rw_priors|E +D-->LatentProcess +F-->LatentProcess +J-->ObservationModel +K-->ObservationModel +LatentProcess-->E +ObservationModel-->E E-->|sample...NUTS...| G G-.->H H-.->I diff --git a/EpiAware/src/EpiAware.jl b/EpiAware/src/EpiAware.jl index f9235431e..55a37e03f 100644 --- a/EpiAware/src/EpiAware.jl +++ b/EpiAware/src/EpiAware.jl @@ -38,12 +38,12 @@ export create_discrete_pmf, default_rw_priors, default_delay_obs_priors, spread_ export EpiData, Renewal, ExpGrowthRate, DirectInfections # Exported Turing model constructors -export make_epi_inference_model, random_walk, delay_observations +export make_epi_inference_model, delay_observations_model, random_walk_process include("epimodel.jl") include("utilities.jl") -include("models.jl") include("latent-processes.jl") include("observation-processes.jl") +include("models.jl") end diff --git a/EpiAware/src/latent-processes.jl b/EpiAware/src/latent-processes.jl index eb43bd332..b85a775eb 100644 --- a/EpiAware/src/latent-processes.jl +++ b/EpiAware/src/latent-processes.jl @@ -18,3 +18,34 @@ end end return rw, init, (; σ_RW,) end + +""" + struct LatentProcess{F<:Function} + +A struct representing a latent process with its priors. + +# Fields +- `latent_process`: The latent process function for a `Turing` model. +- `latent_process_priors`: NamedTuple containing the priors for the latent process. + +""" +struct LatentProcess{F <: Function} + latent_process::F + latent_process_priors::NamedTuple +end + +""" + random_walk_process(; latent_process_priors = default_rw_priors()) + +Create a `LatentProcess` struct reflecting a random walk process with optional priors. + +# Arguments +- `latent_process_priors`: Optional priors for the random walk process. + +# Returns +- `LatentProcess`: A random walk process. + +""" +function random_walk_process(; latent_process_priors = default_rw_priors()) + LatentProcess(random_walk, latent_process_priors) +end diff --git a/EpiAware/src/models.jl b/EpiAware/src/models.jl index 9396e6fd3..8d896de85 100644 --- a/EpiAware/src/models.jl +++ b/EpiAware/src/models.jl @@ -1,25 +1,26 @@ @model function make_epi_inference_model( y_t, epimodel::AbstractEpiModel, - latent_process, - observation_process; - process_priors, + latent_process_obj::LatentProcess, + observation_process_obj::ObservationModel; pos_shift = 1e-6 ) #Latent process time_steps = epimodel.data.time_horizon - @submodel latent_process, init, latent_process_aux = latent_process( - time_steps; latent_process_priors = process_priors) + @submodel latent_process, init, latent_process_aux = latent_process_obj.latent_process( + time_steps; + latent_process_priors = latent_process_obj.latent_process_priors + ) #Transform into infections I_t = epimodel(latent_process, init) #Predictive distribution of ascerted cases - @submodel generated_y_t, generated_y_t_aux = observation_process( + @submodel generated_y_t, generated_y_t_aux = observation_process_obj.observation_model( y_t, I_t, epimodel::AbstractEpiModel; - observation_process_priors = process_priors, + observation_process_priors = observation_process_obj.observation_model_priors, pos_shift = pos_shift ) diff --git a/EpiAware/src/observation-processes.jl b/EpiAware/src/observation-processes.jl index 379e0701b..3389a5cfd 100644 --- a/EpiAware/src/observation-processes.jl +++ b/EpiAware/src/observation-processes.jl @@ -21,3 +21,34 @@ end return y_t, (; neg_bin_cluster_factor,) end + +""" + struct ObservationModel{F<:Function} + +A struct representing an observation model with its priors. + +# Fields +- `observation_model`: The observation model function for a `Turing` model. +- `observation_model_priors`: NamedTuple containing the priors for the observation model. + +""" +struct ObservationModel{F <: Function} + observation_model::F + observation_model_priors::NamedTuple +end + +""" + delay_observations_model(; latent_process_priors = default_rw_priors()) + +Create an `ObservationModel` struct reflecting a delayed observation process with optional priors. + +# Arguments +- `latent_process_priors`: Optional priors for the delayed observation process. + +# Returns +- `ObservationModel`: An observation model with delayed observations. + +""" +function delay_observations_model(; observation_model_priors = default_delay_obs_priors()) + ObservationModel(delay_observations, observation_model_priors) +end diff --git a/EpiAware/test/predictive_checking/toy_model_log_infs_RW.jl b/EpiAware/test/predictive_checking/toy_model_log_infs_RW.jl index d2a46b632..3de96a657 100644 --- a/EpiAware/test/predictive_checking/toy_model_log_infs_RW.jl +++ b/EpiAware/test/predictive_checking/toy_model_log_infs_RW.jl @@ -101,6 +101,8 @@ In this case we use the `DirectInfections` model. =# toy_log_infs = DirectInfections(model_data) +rwp = random_walk_process() +obs_mdl = delay_observations_model() #= ## Generate a `Turing` `Model` @@ -108,13 +110,7 @@ We don't have observed data, so we use `missing` value for `y_t`. =# log_infs_model = make_epi_inference_model( - missing, - toy_log_infs, - random_walk, - delay_observations; - process_priors = merge(default_rw_priors(), default_delay_obs_priors()), - pos_shift = 1e-6 -) + missing, toy_log_infs, rwp, obs_mdl; pos_shift = 1e-6) #= ## Sample from the model @@ -147,14 +143,7 @@ We treat the generated data as observed data and attempt to infer underlying inf truth_data = random_epidemic.y_t -model = make_epi_inference_model( - truth_data, - toy_log_infs, - random_walk, - delay_observations; - process_priors = merge(default_rw_priors(), default_delay_obs_priors()), - pos_shift = 1e-6 -) +model = make_epi_inference_model(truth_data, toy_log_infs, rwp, obs_mdl; pos_shift = 1e-6) @time chn = sample( model, diff --git a/EpiAware/test/test_latent-processes.jl b/EpiAware/test/test_latent-processes.jl index b8c7f7f00..5541e515a 100644 --- a/EpiAware/test/test_latent-processes.jl +++ b/EpiAware/test/test_latent-processes.jl @@ -2,7 +2,7 @@ @testitem "Testing random_walk against theoretical properties" begin using DynamicPPL, Turing n = 5 - model = random_walk(n) + model = EpiAware.random_walk(n) fixed_model = fix(model, (σ²_RW = 1.0, init_rw_value = 0.0)) #Fixing the standard deviation of the random walk process n_samples = 1000 samples_day_5 = sample(fixed_model, Prior(), n_samples) |> diff --git a/EpiAware/test/test_models.jl b/EpiAware/test/test_models.jl index 7e94bccb8..61e56b2e3 100644 --- a/EpiAware/test/test_models.jl +++ b/EpiAware/test/test_models.jl @@ -4,20 +4,13 @@ # Define test inputs y_t = missing # Data will be generated from the model data = EpiData([0.2, 0.3, 0.5], [0.1, 0.4, 0.5], 0.8, 10, exp) - process_priors = merge(default_rw_priors(), default_delay_obs_priors()) pos_shift = 1e-6 epimodel = DirectInfections(data) - + rwp = random_walk_process() + obs_mdl = delay_observations_model() # Call the function - test_mdl = make_epi_inference_model( - y_t, - epimodel, - random_walk, - delay_observations; - process_priors, - pos_shift - ) + test_mdl = make_epi_inference_model(y_t, epimodel, rwp, obs_mdl; pos_shift) # Define expected outputs for a conditional model # Underlying log-infections are const value 1 for all time steps and @@ -38,20 +31,14 @@ end # Define test inputs y_t = missing # Data will be generated from the model data = EpiData([0.2, 0.3, 0.5], [0.1, 0.4, 0.5], 0.8, 10, exp) - process_priors = merge(default_rw_priors(), default_delay_obs_priors()) pos_shift = 1e-6 epimodel = ExpGrowthRate(data) + rwp = random_walk_process() + obs_mdl = delay_observations_model() # Call the function - test_mdl = make_epi_inference_model( - y_t, - epimodel, - random_walk, - delay_observations; - process_priors, - pos_shift - ) + test_mdl = make_epi_inference_model(y_t, epimodel, rwp, obs_mdl; pos_shift) # Define expected outputs for a conditional model # Underlying log-infections are const value 1 for all time steps and @@ -76,16 +63,10 @@ end pos_shift = 1e-6 epimodel = Renewal(data) - + rwp = random_walk_process() + obs_mdl = delay_observations_model() # Call the function - test_mdl = make_epi_inference_model( - y_t, - epimodel, - random_walk, - delay_observations; - process_priors, - pos_shift - ) + test_mdl = make_epi_inference_model(y_t, epimodel, rwp, obs_mdl; pos_shift) # Define expected outputs for a conditional model # Underlying log-infections are const value 1 for all time steps and diff --git a/EpiAware/test/test_observation-processes.jl b/EpiAware/test/test_observation-processes.jl index b30266f65..1386d3bb5 100644 --- a/EpiAware/test/test_observation-processes.jl +++ b/EpiAware/test/test_observation-processes.jl @@ -11,7 +11,7 @@ observation_process_priors = default_delay_obs_priors() # Call the function - mdl = delay_observations( + mdl = EpiAware.delay_observations( missing, I_t, epimodel;