diff --git a/pipelines/generate_predictive.py b/pipelines/generate_predictive.py index 57b32542..a2ff3ce6 100644 --- a/pipelines/generate_predictive.py +++ b/pipelines/generate_predictive.py @@ -20,7 +20,12 @@ def generate_and_save_predictions( model_dir = Path(model_run_dir, model_name) if not model_dir.exists(): raise FileNotFoundError(f"The directory {model_dir} does not exist.") - (my_model, my_data) = build_model_from_dir(model_run_dir) + (my_model, my_data) = build_model_from_dir( + model_run_dir, + sample_ed_visits=predict_ed_visits, + sample_hospital_admissions=predict_hospital_admissions, + sample_wastewater=predict_wastewater, + ) my_model._init_model(1, 1) fresh_sampler = my_model.mcmc.sampler @@ -33,6 +38,7 @@ def generate_and_save_predictions( my_model.mcmc.sampler = fresh_sampler forecast_data = my_data.to_forecast_data(n_forecast_points) + print(forecast_data) posterior_predictive = my_model.posterior_predictive( data=forecast_data,