Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve the prior interface for make epi model inference #61

Merged
Merged
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions EpiAware/src/EpiAware.jl
Original file line number Diff line number Diff line change
@@ -38,12 +38,12 @@ export create_discrete_pmf, default_rw_priors, default_delay_obs_priors, spread_
export EpiData, Renewal, ExpGrowthRate, DirectInfections

# Exported Turing model constructors
export make_epi_inference_model, random_walk, delay_observations
export make_epi_inference_model, delay_observations_model, random_walk_process

include("epimodel.jl")
include("utilities.jl")
include("models.jl")
include("latent-processes.jl")
include("observation-processes.jl")
include("models.jl")

end
31 changes: 31 additions & 0 deletions EpiAware/src/latent-processes.jl
Original file line number Diff line number Diff line change
@@ -18,3 +18,34 @@ end
end
return rw, init, (; σ_RW,)
end

"""
struct LatentProcess{F<:Function}
A struct representing a latent process with its priors.
# Fields
- `latent_process`: The latent process function for a `Turing` model.
- `latent_process_priors`: NamedTuple containing the priors for the latent process.
"""
struct LatentProcess{F<:Function}
latent_process::F
latent_process_priors::NamedTuple
end

"""
random_walk_process(; latent_process_priors = default_rw_priors())
Create a `LatentProcess` struct reflecting a random walk process with optional priors.
# Arguments
- `latent_process_priors`: Optional priors for the random walk process.
# Returns
- `LatentProcess`: A random walk process.
"""
function random_walk_process(; latent_process_priors = default_rw_priors())
LatentProcess(random_walk, latent_process_priors)
end
15 changes: 8 additions & 7 deletions EpiAware/src/models.jl
Original file line number Diff line number Diff line change
@@ -1,25 +1,26 @@
@model function make_epi_inference_model(
y_t,
epimodel::AbstractEpiModel,
latent_process,
observation_process;
process_priors,
latent_process_obj::LatentProcess,
observation_process_obj::ObservationModel;
pos_shift = 1e-6,
)
#Latent process
time_steps = epimodel.data.time_horizon
@submodel latent_process, init, latent_process_aux =
latent_process(time_steps; latent_process_priors = process_priors)
@submodel latent_process, init, latent_process_aux = latent_process_obj.latent_process(
time_steps;
latent_process_priors = latent_process_obj.latent_process_priors,
)

#Transform into infections
I_t = epimodel(latent_process, init)

#Predictive distribution of ascerted cases
@submodel generated_y_t, generated_y_t_aux = observation_process(
@submodel generated_y_t, generated_y_t_aux = observation_process_obj.observation_model(
y_t,
I_t,
epimodel::AbstractEpiModel;
observation_process_priors = process_priors,
observation_process_priors = observation_process_obj.observation_model_priors,
pos_shift = pos_shift,
)

32 changes: 32 additions & 0 deletions EpiAware/src/observation-processes.jl
Original file line number Diff line number Diff line change
@@ -22,3 +22,35 @@ end

return y_t, (; neg_bin_cluster_factor,)
end


"""
struct ObservationModel{F<:Function}
A struct representing an observation model with its priors.
# Fields
- `observation_model`: The observation model function for a `Turing` model.
- `observation_model_priors`: NamedTuple containing the priors for the observation model.
"""
struct ObservationModel{F<:Function}
observation_model::F
observation_model_priors::NamedTuple
end

"""
delay_observations_model(; latent_process_priors = default_rw_priors())
Create an `ObservationModel` struct reflecting a delayed observation process with optional priors.
# Arguments
- `latent_process_priors`: Optional priors for the delayed observation process.
# Returns
- `ObservationModel`: An observation model with delayed observations.
"""
function delay_observations_model(; observation_model_priors = default_delay_obs_priors())
ObservationModel(delay_observations, observation_model_priors)
end
21 changes: 5 additions & 16 deletions EpiAware/test/predictive_checking/toy_model_log_infs_RW.jl
Original file line number Diff line number Diff line change
@@ -101,20 +101,16 @@ In this case we use the `DirectInfections` model.
=#

toy_log_infs = DirectInfections(model_data)
rwp = random_walk_process()
obs_mdl = delay_observations_model()

#=
## Generate a `Turing` `Model`
We don't have observed data, so we use `missing` value for `y_t`.
=#

log_infs_model = make_epi_inference_model(
missing,
toy_log_infs,
random_walk,
delay_observations;
process_priors = merge(default_rw_priors(), default_delay_obs_priors()),
pos_shift = 1e-6,
)
log_infs_model =
make_epi_inference_model(missing, toy_log_infs, rwp, obs_mdl; pos_shift = 1e-6)

#=
## Sample from the model
@@ -147,14 +143,7 @@ We treat the generated data as observed data and attempt to infer underlying inf

truth_data = random_epidemic.y_t

model = make_epi_inference_model(
truth_data,
toy_log_infs,
random_walk,
delay_observations;
process_priors = merge(default_rw_priors(), default_delay_obs_priors()),
pos_shift = 1e-6,
)
model = make_epi_inference_model(truth_data, toy_log_infs, rwp, obs_mdl; pos_shift = 1e-6)

@time chn = sample(
model,
2 changes: 1 addition & 1 deletion EpiAware/test/test_latent-processes.jl
Original file line number Diff line number Diff line change
@@ -2,7 +2,7 @@
@testitem "Testing random_walk against theoretical properties" begin
using DynamicPPL, Turing
n = 5
model = random_walk(n)
model = EpiAware.random_walk(n)
fixed_model = fix(model, (σ²_RW = 1.0, init_rw_value = 0.0)) #Fixing the standard deviation of the random walk process
n_samples = 1000
samples_day_5 =
37 changes: 9 additions & 28 deletions EpiAware/test/test_models.jl
Original file line number Diff line number Diff line change
@@ -4,21 +4,14 @@
# Define test inputs
y_t = missing # Data will be generated from the model
data = EpiData([0.2, 0.3, 0.5], [0.1, 0.4, 0.5], 0.8, 10, exp)
process_priors = merge(default_rw_priors(), default_delay_obs_priors())
pos_shift = 1e-6


epimodel = DirectInfections(data)

rwp = random_walk_process()
obs_mdl = delay_observations_model()
# Call the function
test_mdl = make_epi_inference_model(
y_t,
epimodel,
random_walk,
delay_observations;
process_priors,
pos_shift,
)
test_mdl = make_epi_inference_model(y_t, epimodel, rwp, obs_mdl; pos_shift)

# Define expected outputs for a conditional model
# Underlying log-infections are const value 1 for all time steps and
@@ -39,20 +32,14 @@ end
# Define test inputs
y_t = missing # Data will be generated from the model
data = EpiData([0.2, 0.3, 0.5], [0.1, 0.4, 0.5], 0.8, 10, exp)
process_priors = merge(default_rw_priors(), default_delay_obs_priors())
pos_shift = 1e-6

epimodel = ExpGrowthRate(data)
rwp = random_walk_process()
obs_mdl = delay_observations_model()

# Call the function
test_mdl = make_epi_inference_model(
y_t,
epimodel,
random_walk,
delay_observations;
process_priors,
pos_shift,
)
test_mdl = make_epi_inference_model(y_t, epimodel, rwp, obs_mdl; pos_shift)

# Define expected outputs for a conditional model
# Underlying log-infections are const value 1 for all time steps and
@@ -77,16 +64,10 @@ end
pos_shift = 1e-6

epimodel = Renewal(data)

rwp = random_walk_process()
obs_mdl = delay_observations_model()
# Call the function
test_mdl = make_epi_inference_model(
y_t,
epimodel,
random_walk,
delay_observations;
process_priors,
pos_shift,
)
test_mdl = make_epi_inference_model(y_t, epimodel, rwp, obs_mdl; pos_shift)

# Define expected outputs for a conditional model
# Underlying log-infections are const value 1 for all time steps and
2 changes: 1 addition & 1 deletion EpiAware/test/test_observation-processes.jl
Original file line number Diff line number Diff line change
@@ -11,7 +11,7 @@
observation_process_priors = default_delay_obs_priors()

# Call the function
mdl = delay_observations(
mdl = EpiAware.delay_observations(
missing,
I_t,
epimodel;