From 141dcd126975aac88a5b9c6698a7873a2d81469f Mon Sep 17 00:00:00 2001 From: Claude Date: Mon, 10 Nov 2025 20:12:49 +0000 Subject: [PATCH 01/51] Standardise S3 interface for estimate_secondary (#1142) This commit implements the S3 design standardisation described in inst/dev/design_s3.md for estimate_secondary(), completing the work started in #1144 for estimate_infections(). - Return structure now matches estimate_infections: fit, args, observations - Added S3 class hierarchy: c("epinowfit", "estimate_secondary", class(ret)) - Updated documentation to reflect new accessor-based interface - Added get_samples.estimate_secondary() to extract posterior samples - Added get_predictions.estimate_secondary() to get model predictions - Added summary.estimate_secondary() to summarise results - Updated plot.estimate_secondary() to use get_predictions() accessor - Updated forecast_secondary() to use new accessor methods - Updated all tests for estimate_secondary() to use new structure - Tests now use accessor methods (get_predictions(), get_samples(), summary()) - All existing functionality preserved through accessor methods This is a breaking change. Code that directly accesses the return structure will need to be updated: - Old: result$predictions -> New: get_predictions(result) - Old: result$posterior -> New: get_samples(result) - Old: result$data -> New: result$args The estimate_secondary function now shares a consistent S3 interface with estimate_infections, following the design specification. Note: Using get_predictions() instead of predict() to avoid confusion with standard R predict() methods which typically take newdata for out-of-sample prediction. --- R/estimate_secondary.R | 43 ++++++++------------ R/get.R | 52 ++++++++++++++++++++++++ R/summarise.R | 17 ++++++++ tests/testthat/test-estimate_secondary.R | 35 ++++++++++------ 4 files changed, 109 insertions(+), 38 deletions(-) diff --git a/R/estimate_secondary.R b/R/estimate_secondary.R index 915994181..5a9521e22 100644 --- a/R/estimate_secondary.R +++ b/R/estimate_secondary.R @@ -56,11 +56,11 @@ #' @param verbose Logical, should model fitting progress be returned. Defaults #' to [interactive()]. #' -#' @return A list containing: `predictions` (a `` ordered by date -#' with the primary, and secondary observations, and a summary of the model -#' estimated secondary observations), `posterior` which contains a summary of -#' the entire model posterior, `data` (a list of data used to fit the -#' model), and `fit` (the `stanfit` object). +#' @return An `` object which is a list of outputs +#' including: the stan object (`fit`), arguments used to fit the model +#' (`args`), and the observed data (`observations`). Use `summary()` to access +#' estimates, `get_samples()` to extract posterior samples, and +#' `get_predictions()` to access predictions. #' @export #' @inheritParams estimate_infections #' @inheritParams update_secondary_args @@ -266,22 +266,15 @@ estimate_secondary <- function(data, fit <- fit_model(stan_, id = "estimate_secondary") - out <- list() - out$predictions <- extract_stan_param(fit, "sim_secondary", CrIs = CrIs) - out$predictions <- out$predictions[, lapply(.SD, round, 1)] - out$predictions <- out$predictions[, date := reports[(burn_in + 1):.N]$date] - out$predictions <- data.table::merge.data.table( - reports, out$predictions, - all = TRUE, by = "date" - ) - out$posterior <- extract_stan_param( - fit, - CrIs = CrIs + # Create standardized S3 return structure + ret <- list( + fit = fit, + args = stan_data, + observations = reports ) - out$data <- stan_data - out$fit <- fit - class(out) <- c("estimate_secondary", class(out)) - return(out) + + class(ret) <- c("epinowfit", "estimate_secondary", class(ret)) + return(ret) } #' Update estimate_secondary default priors @@ -381,7 +374,7 @@ plot.estimate_secondary <- function(x, primary = FALSE, from = NULL, to = NULL, new_obs = NULL, ...) { - predictions <- data.table::copy(x$predictions) + predictions <- data.table::copy(get_predictions(x)) if (!is.null(new_obs)) { new_obs <- data.table::as.data.table(new_obs) @@ -623,7 +616,7 @@ forecast_secondary <- function(estimate, if (inherits(primary, "estimate_infections")) { primary_samples <- get_samples(primary) primary <- primary_samples[variable == primary_variable] - primary <- primary[date > max(estimate$predictions$date, na.rm = TRUE)] + primary <- primary[date > max(get_predictions(estimate)$date, na.rm = TRUE)] primary <- primary[, .(date, sample, value)] if (!is.null(samples)) { primary <- primary[sample(.N, samples, replace = TRUE)] @@ -641,10 +634,10 @@ forecast_secondary <- function(estimate, include = FALSE ) # extract data from stanfit - stan_data <- estimate$data + stan_data <- estimate$args # combined primary from data and input primary - primary_fit <- estimate$predictions[ + primary_fit <- get_predictions(estimate)[ , .(date, value = primary, sample = list(unique(updated_primary$sample))) ] @@ -721,7 +714,7 @@ forecast_secondary <- function(estimate, # link previous prediction observations with forecast observations forecast_obs <- data.table::rbindlist( list( - estimate$predictions[, .(date, primary, secondary)], + get_predictions(estimate)[, .(date, primary, secondary)], data.table::copy(primary)[, .(primary = median(value)), by = "date"] ), use.names = TRUE, fill = TRUE diff --git a/R/get.R b/R/get.R index 850c68f49..7b2a98f4d 100644 --- a/R/get.R +++ b/R/get.R @@ -265,3 +265,55 @@ get_samples.estimate_infections <- function(object, ...) { get_samples.forecast_infections <- function(object, ...) { object$samples } + +#' @rdname get_samples +#' @export +get_samples.estimate_secondary <- function(object, ...) { + # Extract posterior samples from the fit + extract_stan_param(object$fit, CrIs = c(0.2, 0.5, 0.9)) +} + +#' Get predictions from a fitted secondary model +#' +#' @description `r lifecycle::badge("stable")` +#' Extracts predictions from a fitted secondary model, combining observations +#' with model estimates. +#' +#' @param object A fitted model object from `estimate_secondary()` +#' @param CrIs Numeric vector of credible intervals to return. Defaults to +#' c(0.2, 0.5, 0.9). +#' @param ... Additional arguments (currently unused) +#' +#' @return A `data.table` with columns: date, primary, secondary, and summary +#' statistics (mean, sd, credible intervals) for the model predictions. +#' +#' @export +#' @examples +#' \dontrun{ +#' # After fitting a model +#' predictions <- get_predictions(fit) +#' } +get_predictions <- function(object, ...) { + UseMethod("get_predictions") +} + +#' @rdname get_predictions +#' @export +get_predictions.estimate_secondary <- function(object, CrIs = c(0.2, 0.5, 0.9), ...) { + # Extract predictions from the fit + predictions <- extract_stan_param(object$fit, "sim_secondary", CrIs = CrIs) + predictions <- predictions[, lapply(.SD, round, 1)] + + # Add dates based on burn_in + burn_in <- object$args$burn_in + predictions <- predictions[, date := object$observations[(burn_in + 1):.N]$date] + + # Merge with observations + predictions <- data.table::merge.data.table( + object$observations, predictions, + all = TRUE, by = "date" + ) + + return(predictions) +>>>>>>> 84e3c31e (Standardise S3 interface for estimate_secondary (#1142)) +} diff --git a/R/summarise.R b/R/summarise.R index cefa6b7f3..ff3894ad4 100644 --- a/R/summarise.R +++ b/R/summarise.R @@ -885,3 +885,20 @@ summary.forecast_infections <- function(object, print.epinowfit <- function(x, ...) { print(summary(x)) } + +#' Summarise results from estimate_secondary +#' +#' @description `r lifecycle::badge("stable")` +#' Returns a summary of the fitted secondary model including posterior +#' parameter estimates. +#' +#' @param object A fitted model object from `estimate_secondary()` +#' @param CrIs Numeric vector of credible intervals to return. Defaults to +#' c(0.2, 0.5, 0.9). +#' @param ... Additional arguments (currently unused) +#' +#' @return A `` of summary output +#' @export +summary.estimate_secondary <- function(object, CrIs = c(0.2, 0.5, 0.9), ...) { + get_samples(object) +} diff --git a/tests/testthat/test-estimate_secondary.R b/tests/testthat/test-estimate_secondary.R index e0e019f5a..44fc19110 100644 --- a/tests/testthat/test-estimate_secondary.R +++ b/tests/testthat/test-estimate_secondary.R @@ -35,7 +35,7 @@ params <- c( "scaling" = "params[1]" ) -inc_posterior <- inc$posterior[variable %in% params] +inc_posterior <- get_samples(inc)[variable %in% params] # fit model to example data with a fixed delay inc_fixed <- estimate_secondary(inc_cases[1:60], @@ -66,20 +66,29 @@ prev <- estimate_secondary(prev_cases[1:100], ) # extract posterior parameters of interest -prev_posterior <- prev$posterior[variable %in% params] +prev_posterior <- get_samples(prev)[variable %in% params] # Test output test_that("estimate_secondary can return values from simulated data and plot them", { - expect_equal(names(inc), c("predictions", "posterior", "data", "fit")) + expect_equal(names(inc), c("fit", "args", "observations")) + expect_s3_class(inc, "estimate_secondary") + expect_s3_class(inc, "epinowfit") + + # Test accessor methods + predictions <- get_predictions(inc) expect_equal( - names(inc$predictions), + names(predictions), c( "date", "primary", "secondary", "accumulate", "mean", "se_mean", "sd", "lower_90", "lower_50", "lower_20", "median", "upper_20", "upper_50", "upper_90" ) ) - expect_true(is.list(inc$data)) + + posterior <- get_samples(inc) + expect_true(is.data.frame(posterior)) + + expect_true(is.list(inc$args)) # validation plot of observations vs estimates expect_error(plot(inc, primary = TRUE), NA) }) @@ -105,8 +114,8 @@ test_that("estimate_secondary successfully returns estimates when passed NA valu obs = obs_opts(scale = Normal(mean = 0.2, sd = 0.2), week_effect = FALSE), verbose = FALSE ) - expect_true(is.list(inc_na$data)) - expect_true(is.list(prev_na$data)) + expect_true(is.list(inc_na$args)) + expect_true(is.list(prev_na$args)) }) test_that("estimate_secondary successfully returns estimates when accumulating to weekly", { @@ -132,7 +141,7 @@ test_that("estimate_secondary successfully returns estimates when accumulating t scale = Normal(mean = 0.4, sd = 0.05), week_effect = FALSE ), verbose = FALSE ) - expect_true(is.list(inc_weekly$data)) + expect_true(is.list(inc_weekly$args)) }) test_that("estimate_secondary works when only estimating scaling", { @@ -141,7 +150,7 @@ test_that("estimate_secondary works when only estimating scaling", { delay = delay_opts(), verbose = FALSE ) - expect_equal(names(inc), c("predictions", "posterior", "data", "fit")) + expect_equal(names(inc), c("fit", "args", "observations")) }) test_that("estimate_secondary can recover simulated parameters", { @@ -172,7 +181,7 @@ test_that("estimate_secondary can recover simulated parameters with the verbose = FALSE, stan = stan_opts(backend = "cmdstanr") ) ))) - inc_posterior_cmdstanr <- inc_cmdstanr$posterior[variable %in% params] + inc_posterior_cmdstanr <- get_samples(inc_cmdstanr)[variable %in% params] expect_equal( inc_posterior_cmdstanr[, mean], c(1.8, 0.5, 0.4), tolerance = 0.1 @@ -243,8 +252,8 @@ test_that("estimate_secondary works with filter_leading_zeros set", { verbose = FALSE )) expect_s3_class(out, "estimate_secondary") - expect_named(out, c("predictions", "posterior", "data", "fit")) - expect_equal(out$predictions$primary, modified_data$primary) + expect_named(out, c("fit", "args", "observations")) + expect_equal(get_predictions(out)$primary, modified_data$primary) }) test_that("estimate_secondary works with zero_threshold set", { @@ -264,5 +273,5 @@ test_that("estimate_secondary works with zero_threshold set", { verbose = FALSE ) expect_s3_class(out, "estimate_secondary") - expect_named(out, c("predictions", "posterior", "data", "fit")) + expect_named(out, c("fit", "args", "observations")) }) From b7761f611894f492819b715680e6e3aab05a4caf Mon Sep 17 00:00:00 2001 From: Claude Date: Mon, 10 Nov 2025 22:09:55 +0000 Subject: [PATCH 02/51] Add get_predictions() for estimate_infections to complete S3 interface (#1142) This commit adds get_predictions() for estimate_infections() to provide a consistent interface across both estimate_infections() and estimate_secondary(). ## Changes ### New S3 Method - Added get_predictions.estimate_infections() - extracts predicted reported cases from the model, merging them with observations ### Documentation Updates - Updated get_predictions() generic documentation to cover both model types - Updated estimate_infections() @return documentation to mention get_predictions() alongside get_samples() and summary() ### Tests - Updated test_estimate_infections() helper to test get_predictions() - Verifies it returns expected columns (date, confirm, mean) ## Consistent Interface Now both estimate_infections() and estimate_secondary() share the same accessor pattern: - get_samples() - raw posterior samples - get_predictions() - predicted observable outcomes - summary() - summarized estimates This provides a unified S3 interface as specified in the design document. --- R/estimate_infections.R | 5 +-- R/get.R | 38 ++++++++++++++++++++--- tests/testthat/test-estimate_infections.R | 8 +++++ 3 files changed, 44 insertions(+), 7 deletions(-) diff --git a/R/estimate_infections.R b/R/estimate_infections.R index 1a29395d7..bb3ddfce0 100644 --- a/R/estimate_infections.R +++ b/R/estimate_infections.R @@ -73,8 +73,9 @@ #' @export #' @return An `` object which is a list of outputs #' including: the stan object (`fit`), arguments used to fit the model -#' (`args`), and the observed data (`observations`). The estimates included in -#' the fit can be accessed using the `summary()` function. +#' (`args`), and the observed data (`observations`). Use `summary()` to access +#' estimates, `get_samples()` to extract posterior samples, and +#' `get_predictions()` to access predicted reported cases. #' #' @seealso [epinow()] [regional_epinow()] [forecast_infections()] #' [estimate_truncation()] diff --git a/R/get.R b/R/get.R index 7b2a98f4d..973c0ce5d 100644 --- a/R/get.R +++ b/R/get.R @@ -273,18 +273,20 @@ get_samples.estimate_secondary <- function(object, ...) { extract_stan_param(object$fit, CrIs = c(0.2, 0.5, 0.9)) } -#' Get predictions from a fitted secondary model +#' Get predictions from a fitted model #' #' @description `r lifecycle::badge("stable")` -#' Extracts predictions from a fitted secondary model, combining observations -#' with model estimates. +#' Extracts predictions from a fitted model, combining observations with model +#' estimates. For `estimate_infections()` returns predicted reported cases, for +#' `estimate_secondary()` returns predicted secondary observations. #' -#' @param object A fitted model object from `estimate_secondary()` +#' @param object A fitted model object (e.g., from `estimate_infections()` or +#' `estimate_secondary()`) #' @param CrIs Numeric vector of credible intervals to return. Defaults to #' c(0.2, 0.5, 0.9). #' @param ... Additional arguments (currently unused) #' -#' @return A `data.table` with columns: date, primary, secondary, and summary +#' @return A `data.table` with columns including date, observations, and summary #' statistics (mean, sd, credible intervals) for the model predictions. #' #' @export @@ -297,6 +299,32 @@ get_predictions <- function(object, ...) { UseMethod("get_predictions") } +#' @rdname get_predictions +#' @export +get_predictions.estimate_infections <- function(object, CrIs = c(0.2, 0.5, 0.9), ...) { + # Get samples for reported cases + samples <- get_samples(object) + reported_samples <- samples[variable == "reported_cases"] + + # Calculate summary measures + predictions <- calc_summary_measures( + reported_samples, + summarise_by = "date", + order_by = "date", + CrIs = CrIs + ) + + # Merge with observations + predictions <- data.table::merge.data.table( + object$observations[, .(date, confirm)], + predictions, + by = "date", + all = TRUE + ) + + return(predictions) +} + #' @rdname get_predictions #' @export get_predictions.estimate_secondary <- function(object, CrIs = c(0.2, 0.5, 0.9), ...) { diff --git a/tests/testthat/test-estimate_infections.R b/tests/testthat/test-estimate_infections.R index 83d1a4c0e..94a4db312 100644 --- a/tests/testthat/test-estimate_infections.R +++ b/tests/testthat/test-estimate_infections.R @@ -29,6 +29,14 @@ test_estimate_infections <- function(...) { expect_true(nrow(get_samples(out)) > 0) expect_true(nrow(summary(out, type = "parameters")) > 0) expect_true(nrow(out$observations) > 0) + + # Test get_predictions accessor + predictions <- get_predictions(out) + expect_true(nrow(predictions) > 0) + expect_true("date" %in% names(predictions)) + expect_true("confirm" %in% names(predictions)) + expect_true("mean" %in% names(predictions)) + invisible(out) } From 4eef3572ba2599ab10b0e89bf95e7bb74bd648fa Mon Sep 17 00:00:00 2001 From: Claude Date: Mon, 10 Nov 2025 22:11:59 +0000 Subject: [PATCH 03/51] Fix summary.estimate_secondary() to return summary statistics instead of raw samples Changed summary.estimate_secondary() to call extract_stan_param() which returns summary statistics (mean, sd, median, credible intervals) for all model parameters, rather than forwarding to get_samples() which returns raw posterior samples. This makes the behavior consistent with summary.estimate_infections() and provides a clearer distinction between the accessor methods: - get_samples() - raw posterior samples (individual draws) - summary() - summary statistics of parameters (mean, sd, CrIs) - get_predictions() - predictions merged with observations --- R/summarise.R | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/R/summarise.R b/R/summarise.R index ff3894ad4..ae0a740f5 100644 --- a/R/summarise.R +++ b/R/summarise.R @@ -890,15 +890,17 @@ print.epinowfit <- function(x, ...) { #' #' @description `r lifecycle::badge("stable")` #' Returns a summary of the fitted secondary model including posterior -#' parameter estimates. +#' parameter estimates with credible intervals. #' #' @param object A fitted model object from `estimate_secondary()` #' @param CrIs Numeric vector of credible intervals to return. Defaults to #' c(0.2, 0.5, 0.9). #' @param ... Additional arguments (currently unused) #' -#' @return A `` of summary output +#' @return A `` with summary statistics (mean, sd, median, +#' credible intervals) for all model parameters #' @export summary.estimate_secondary <- function(object, CrIs = c(0.2, 0.5, 0.9), ...) { - get_samples(object) + # Extract all parameters with summary statistics + extract_stan_param(object$fit, CrIs = CrIs) } From 9c4a6d39ab822ffe49cf0884ba60a69da5b767cd Mon Sep 17 00:00:00 2001 From: Claude Date: Mon, 10 Nov 2025 22:16:14 +0000 Subject: [PATCH 04/51] Add type and params arguments to summary.estimate_secondary() for consistency Modified summary.estimate_secondary() to match the interface of summary.estimate_infections(): - Added 'type' argument with options "snapshot" (default) and "parameters" - Default (type = "snapshot") returns only key parameters (delay_params, params, frac_obs) for a concise summary - type = "parameters" returns all parameters - Added 'params' argument to filter specific parameters when type = "parameters" This provides consistent default behavior across both summary methods: - estimate_infections(): defaults to concise snapshot (~6 lines) - estimate_secondary(): now also defaults to concise snapshot (key params only) Users can get full parameter tables with summary(fit, type = "parameters"). --- R/summarise.R | 33 ++++++++++++++++++++++++++++++--- 1 file changed, 30 insertions(+), 3 deletions(-) diff --git a/R/summarise.R b/R/summarise.R index ae0a740f5..dfb639dec 100644 --- a/R/summarise.R +++ b/R/summarise.R @@ -893,14 +893,41 @@ print.epinowfit <- function(x, ...) { #' parameter estimates with credible intervals. #' #' @param object A fitted model object from `estimate_secondary()` +#' @param type Character string indicating the type of summary to return. +#' Options are "snapshot" (default, key parameters only) or "parameters" +#' (all parameters). +#' @param params Character vector of parameter names to include. Only used +#' when `type = "parameters"`. If NULL (default), returns all parameters. #' @param CrIs Numeric vector of credible intervals to return. Defaults to #' c(0.2, 0.5, 0.9). #' @param ... Additional arguments (currently unused) #' #' @return A `` with summary statistics (mean, sd, median, -#' credible intervals) for all model parameters +#' credible intervals) for model parameters. When `type = "snapshot"`, +#' returns only key parameters (delays, scaling). When `type = "parameters"`, +#' returns all or filtered parameters. +#' @importFrom rlang arg_match #' @export -summary.estimate_secondary <- function(object, CrIs = c(0.2, 0.5, 0.9), ...) { +summary.estimate_secondary <- function(object, + type = c("snapshot", "parameters"), + params = NULL, + CrIs = c(0.2, 0.5, 0.9), ...) { + type <- arg_match(type) + # Extract all parameters with summary statistics - extract_stan_param(object$fit, CrIs = CrIs) + out <- extract_stan_param(object$fit, CrIs = CrIs) + + if (type == "snapshot") { + # Return only key parameters for a concise summary + # Typical parameters: delay_params (distribution parameters), params (scaling, etc.) + key_vars <- c("delay_params", "params", "frac_obs") + out <- out[grepl(paste(key_vars, collapse = "|"), variable)] + } else if (type == "parameters") { + # Optional filtering by parameter name + if (!is.null(params)) { + out <- out[variable %in% params] + } + } + + return(out[]) } From b6ae503b77cb1e9c9965e6ce1891818fd4313ddb Mon Sep 17 00:00:00 2001 From: Claude Date: Mon, 10 Nov 2025 22:20:37 +0000 Subject: [PATCH 05/51] Rename summary type from 'snapshot' to 'estimates' for estimate_secondary The term 'snapshot' is specific to real-time epidemic monitoring in estimate_infections() (showing R, growth rate, doubling time). For estimate_secondary(), we're showing estimated model parameters (delay distributions and scaling factors), so 'estimates' is more appropriate. --- R/summarise.R | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/R/summarise.R b/R/summarise.R index dfb639dec..522c3063e 100644 --- a/R/summarise.R +++ b/R/summarise.R @@ -894,8 +894,8 @@ print.epinowfit <- function(x, ...) { #' #' @param object A fitted model object from `estimate_secondary()` #' @param type Character string indicating the type of summary to return. -#' Options are "snapshot" (default, key parameters only) or "parameters" -#' (all parameters). +#' Options are "estimates" (default, key estimated parameters only) or +#' "parameters" (all parameters or filtered set). #' @param params Character vector of parameter names to include. Only used #' when `type = "parameters"`. If NULL (default), returns all parameters. #' @param CrIs Numeric vector of credible intervals to return. Defaults to @@ -903,13 +903,14 @@ print.epinowfit <- function(x, ...) { #' @param ... Additional arguments (currently unused) #' #' @return A `` with summary statistics (mean, sd, median, -#' credible intervals) for model parameters. When `type = "snapshot"`, -#' returns only key parameters (delays, scaling). When `type = "parameters"`, -#' returns all or filtered parameters. +#' credible intervals) for model parameters. When `type = "estimates"`, +#' returns only key estimated parameters (delay distribution parameters and +#' scaling factors). When `type = "parameters"`, returns all or filtered +#' parameters. #' @importFrom rlang arg_match #' @export summary.estimate_secondary <- function(object, - type = c("snapshot", "parameters"), + type = c("estimates", "parameters"), params = NULL, CrIs = c(0.2, 0.5, 0.9), ...) { type <- arg_match(type) @@ -917,9 +918,10 @@ summary.estimate_secondary <- function(object, # Extract all parameters with summary statistics out <- extract_stan_param(object$fit, CrIs = CrIs) - if (type == "snapshot") { - # Return only key parameters for a concise summary - # Typical parameters: delay_params (distribution parameters), params (scaling, etc.) + if (type == "estimates") { + # Return only key estimated parameters for a concise summary + # Typical parameters: delay_params (distribution parameters), + # params (scaling factors) key_vars <- c("delay_params", "params", "frac_obs") out <- out[grepl(paste(key_vars, collapse = "|"), variable)] } else if (type == "parameters") { From cee929063b425dd2f6de43daaa7cfcf74adccac4 Mon Sep 17 00:00:00 2001 From: Claude Date: Tue, 11 Nov 2025 09:33:08 +0000 Subject: [PATCH 06/51] Rename summary type from 'estimates' to 'compact' for estimate_secondary Use 'compact' instead of 'estimates' to better distinguish the two modes: - type = 'compact': Shows only key parameters (default, concise view) - type = 'parameters': Shows all parameters or filtered subset The term 'compact' clearly describes the format difference rather than implying semantic distinction between types of estimates. --- R/summarise.R | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/R/summarise.R b/R/summarise.R index 522c3063e..c6bfe0a8d 100644 --- a/R/summarise.R +++ b/R/summarise.R @@ -894,8 +894,8 @@ print.epinowfit <- function(x, ...) { #' #' @param object A fitted model object from `estimate_secondary()` #' @param type Character string indicating the type of summary to return. -#' Options are "estimates" (default, key estimated parameters only) or -#' "parameters" (all parameters or filtered set). +#' Options are "compact" (default, key parameters only) or "parameters" +#' (all parameters or filtered set). #' @param params Character vector of parameter names to include. Only used #' when `type = "parameters"`. If NULL (default), returns all parameters. #' @param CrIs Numeric vector of credible intervals to return. Defaults to @@ -903,14 +903,13 @@ print.epinowfit <- function(x, ...) { #' @param ... Additional arguments (currently unused) #' #' @return A `` with summary statistics (mean, sd, median, -#' credible intervals) for model parameters. When `type = "estimates"`, -#' returns only key estimated parameters (delay distribution parameters and -#' scaling factors). When `type = "parameters"`, returns all or filtered -#' parameters. +#' credible intervals) for model parameters. When `type = "compact"`, +#' returns only key parameters (delay distribution parameters and scaling +#' factors). When `type = "parameters"`, returns all or filtered parameters. #' @importFrom rlang arg_match #' @export summary.estimate_secondary <- function(object, - type = c("estimates", "parameters"), + type = c("compact", "parameters"), params = NULL, CrIs = c(0.2, 0.5, 0.9), ...) { type <- arg_match(type) @@ -918,8 +917,8 @@ summary.estimate_secondary <- function(object, # Extract all parameters with summary statistics out <- extract_stan_param(object$fit, CrIs = CrIs) - if (type == "estimates") { - # Return only key estimated parameters for a concise summary + if (type == "compact") { + # Return only key parameters for a compact summary # Typical parameters: delay_params (distribution parameters), # params (scaling factors) key_vars <- c("delay_params", "params", "frac_obs") From 4e8fe0f8c81bdcff78eea85cdc5aae51afec171b Mon Sep 17 00:00:00 2001 From: Claude Date: Tue, 11 Nov 2025 10:21:25 +0000 Subject: [PATCH 07/51] Export S3 methods for estimate_secondary accessors and summary Add missing NAMESPACE exports for the new S3 interface: - get_samples.estimate_secondary() - get_predictions generic and methods for both estimate_infections and estimate_secondary - summary.estimate_secondary() These methods were defined but not exported, causing S3 dispatch to fail. --- NAMESPACE | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/NAMESPACE b/NAMESPACE index c51c4e62f..5e8d00071 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -10,8 +10,11 @@ S3method(discretise,dist_spec) S3method(discretise,multi_dist_spec) S3method(fix_parameters,dist_spec) S3method(fix_parameters,multi_dist_spec) +S3method(get_predictions,estimate_infections) +S3method(get_predictions,estimate_secondary) S3method(get_samples,estimate_infections) S3method(get_samples,forecast_infections) +S3method(get_samples,estimate_secondary) S3method(is_constrained,dist_spec) S3method(is_constrained,multi_dist_spec) S3method(max,dist_spec) @@ -32,6 +35,7 @@ S3method(sd,multi_dist_spec) S3method(summary,epinow) S3method(summary,estimate_infections) S3method(summary,forecast_infections) +S3method(summary,estimate_secondary) export(Fixed) export(Gamma) export(LogNormal) @@ -79,6 +83,7 @@ export(generation_time_opts) export(get_distribution) export(get_parameters) export(get_pmf) +export(get_predictions) export(get_regional_results) export(get_samples) export(gp_opts) From c26cc1d065a11b384f954724590510bf9fb3831e Mon Sep 17 00:00:00 2001 From: Claude Date: Tue, 11 Nov 2025 11:32:10 +0000 Subject: [PATCH 08/51] Add @method tag for summary.estimate_secondary Add the @method roxygen tag so roxygen2 correctly generates S3method(summary,estimate_secondary) in NAMESPACE instead of exporting it as a regular function. --- R/get.R | 1 - R/summarise.R | 1 + 2 files changed, 1 insertion(+), 1 deletion(-) diff --git a/R/get.R b/R/get.R index 973c0ce5d..a7c366e46 100644 --- a/R/get.R +++ b/R/get.R @@ -343,5 +343,4 @@ get_predictions.estimate_secondary <- function(object, CrIs = c(0.2, 0.5, 0.9), ) return(predictions) ->>>>>>> 84e3c31e (Standardise S3 interface for estimate_secondary (#1142)) } diff --git a/R/summarise.R b/R/summarise.R index c6bfe0a8d..7c3edb645 100644 --- a/R/summarise.R +++ b/R/summarise.R @@ -907,6 +907,7 @@ print.epinowfit <- function(x, ...) { #' returns only key parameters (delay distribution parameters and scaling #' factors). When `type = "parameters"`, returns all or filtered parameters. #' @importFrom rlang arg_match +#' @method summary estimate_secondary #' @export summary.estimate_secondary <- function(object, type = c("compact", "parameters"), From 3433201cd5aee06373df50f88e18a9892552e370 Mon Sep 17 00:00:00 2001 From: Claude Date: Thu, 13 Nov 2025 15:19:50 +0000 Subject: [PATCH 09/51] Fix linting issues: line length and unnecessary nesting - Break long function signatures across multiple lines - Break long data.table assignment across lines - Combine nested if conditions in summary method --- R/get.R | 12 +++++++++--- R/summarise.R | 6 ++---- 2 files changed, 11 insertions(+), 7 deletions(-) diff --git a/R/get.R b/R/get.R index a7c366e46..768a30ed9 100644 --- a/R/get.R +++ b/R/get.R @@ -301,7 +301,9 @@ get_predictions <- function(object, ...) { #' @rdname get_predictions #' @export -get_predictions.estimate_infections <- function(object, CrIs = c(0.2, 0.5, 0.9), ...) { +get_predictions.estimate_infections <- function(object, + CrIs = c(0.2, 0.5, 0.9), + ...) { # Get samples for reported cases samples <- get_samples(object) reported_samples <- samples[variable == "reported_cases"] @@ -327,14 +329,18 @@ get_predictions.estimate_infections <- function(object, CrIs = c(0.2, 0.5, 0.9), #' @rdname get_predictions #' @export -get_predictions.estimate_secondary <- function(object, CrIs = c(0.2, 0.5, 0.9), ...) { +get_predictions.estimate_secondary <- function(object, + CrIs = c(0.2, 0.5, 0.9), + ...) { # Extract predictions from the fit predictions <- extract_stan_param(object$fit, "sim_secondary", CrIs = CrIs) predictions <- predictions[, lapply(.SD, round, 1)] # Add dates based on burn_in burn_in <- object$args$burn_in - predictions <- predictions[, date := object$observations[(burn_in + 1):.N]$date] + predictions <- predictions[ + , date := object$observations[(burn_in + 1):.N]$date + ] # Merge with observations predictions <- data.table::merge.data.table( diff --git a/R/summarise.R b/R/summarise.R index 7c3edb645..f3d26c5dc 100644 --- a/R/summarise.R +++ b/R/summarise.R @@ -924,11 +924,9 @@ summary.estimate_secondary <- function(object, # params (scaling factors) key_vars <- c("delay_params", "params", "frac_obs") out <- out[grepl(paste(key_vars, collapse = "|"), variable)] - } else if (type == "parameters") { + } else if (type == "parameters" && !is.null(params)) { # Optional filtering by parameter name - if (!is.null(params)) { - out <- out[variable %in% params] - } + out <- out[variable %in% params] } return(out[]) From 7d59d4f178ab68c78d40cb54b26e6ebf5416ea2d Mon Sep 17 00:00:00 2001 From: Claude Date: Thu, 13 Nov 2025 15:26:16 +0000 Subject: [PATCH 10/51] Fix indentation linting issues in get_predictions functions Adjust hanging indent to match linter expectations: - get_predictions.estimate_infections: 49 -> 48 spaces - get_predictions.estimate_secondary: 48 -> 47 spaces --- R/get.R | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/R/get.R b/R/get.R index 768a30ed9..32500090e 100644 --- a/R/get.R +++ b/R/get.R @@ -302,8 +302,8 @@ get_predictions <- function(object, ...) { #' @rdname get_predictions #' @export get_predictions.estimate_infections <- function(object, - CrIs = c(0.2, 0.5, 0.9), - ...) { + CrIs = c(0.2, 0.5, 0.9), + ...) { # Get samples for reported cases samples <- get_samples(object) reported_samples <- samples[variable == "reported_cases"] @@ -330,8 +330,8 @@ get_predictions.estimate_infections <- function(object, #' @rdname get_predictions #' @export get_predictions.estimate_secondary <- function(object, - CrIs = c(0.2, 0.5, 0.9), - ...) { + CrIs = c(0.2, 0.5, 0.9), + ...) { # Extract predictions from the fit predictions <- extract_stan_param(object$fit, "sim_secondary", CrIs = CrIs) predictions <- predictions[, lapply(.SD, round, 1)] From 85bec8bdb34895c94c34cf44b483d0f1fbeb3867 Mon Sep 17 00:00:00 2001 From: "epiforecasts-workflows[bot]" Date: Thu, 13 Nov 2025 15:32:53 +0000 Subject: [PATCH 11/51] Update documentation --- NAMESPACE | 4 ++-- man/estimate_infections.Rd | 5 ++-- man/estimate_secondary.Rd | 10 ++++---- man/get_predictions.Rd | 39 ++++++++++++++++++++++++++++++ man/get_samples.Rd | 3 +++ man/summary.estimate_secondary.Rd | 40 +++++++++++++++++++++++++++++++ 6 files changed, 92 insertions(+), 9 deletions(-) create mode 100644 man/get_predictions.Rd create mode 100644 man/summary.estimate_secondary.Rd diff --git a/NAMESPACE b/NAMESPACE index 5e8d00071..60185bd3c 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -13,8 +13,8 @@ S3method(fix_parameters,multi_dist_spec) S3method(get_predictions,estimate_infections) S3method(get_predictions,estimate_secondary) S3method(get_samples,estimate_infections) -S3method(get_samples,forecast_infections) S3method(get_samples,estimate_secondary) +S3method(get_samples,forecast_infections) S3method(is_constrained,dist_spec) S3method(is_constrained,multi_dist_spec) S3method(max,dist_spec) @@ -34,8 +34,8 @@ S3method(sd,dist_spec) S3method(sd,multi_dist_spec) S3method(summary,epinow) S3method(summary,estimate_infections) -S3method(summary,forecast_infections) S3method(summary,estimate_secondary) +S3method(summary,forecast_infections) export(Fixed) export(Gamma) export(LogNormal) diff --git a/man/estimate_infections.Rd b/man/estimate_infections.Rd index c16820ab7..c15390315 100644 --- a/man/estimate_infections.Rd +++ b/man/estimate_infections.Rd @@ -106,8 +106,9 @@ horizon} \value{ An \verb{} object which is a list of outputs including: the stan object (\code{fit}), arguments used to fit the model -(\code{args}), and the observed data (\code{observations}). The estimates included in -the fit can be accessed using the \code{summary()} function. +(\code{args}), and the observed data (\code{observations}). Use \code{summary()} to access +estimates, \code{get_samples()} to extract posterior samples, and +\code{get_predictions()} to access predicted reported cases. } \description{ \ifelse{html}{\href{https://lifecycle.r-lib.org/articles/stages.html#maturing}{\figure{lifecycle-maturing.svg}{options: alt='[Maturing]'}}}{\strong{[Maturing]}} diff --git a/man/estimate_secondary.Rd b/man/estimate_secondary.Rd index d003b4d04..4788d2f27 100644 --- a/man/estimate_secondary.Rd +++ b/man/estimate_secondary.Rd @@ -91,11 +91,11 @@ number of cases based on the 7-day average. If the average is above this threshold then the zero is replaced using \code{fill}.} } \value{ -A list containing: \code{predictions} (a \verb{} ordered by date -with the primary, and secondary observations, and a summary of the model -estimated secondary observations), \code{posterior} which contains a summary of -the entire model posterior, \code{data} (a list of data used to fit the -model), and \code{fit} (the \code{stanfit} object). +An \verb{} object which is a list of outputs +including: the stan object (\code{fit}), arguments used to fit the model +(\code{args}), and the observed data (\code{observations}). Use \code{summary()} to access +estimates, \code{get_samples()} to extract posterior samples, and +\code{get_predictions()} to access predictions. } \description{ \ifelse{html}{\href{https://lifecycle.r-lib.org/articles/stages.html#stable}{\figure{lifecycle-stable.svg}{options: alt='[Stable]'}}}{\strong{[Stable]}} diff --git a/man/get_predictions.Rd b/man/get_predictions.Rd new file mode 100644 index 000000000..4bc942117 --- /dev/null +++ b/man/get_predictions.Rd @@ -0,0 +1,39 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/get.R +\name{get_predictions} +\alias{get_predictions} +\alias{get_predictions.estimate_infections} +\alias{get_predictions.estimate_secondary} +\title{Get predictions from a fitted model} +\usage{ +get_predictions(object, ...) + +\method{get_predictions}{estimate_infections}(object, CrIs = c(0.2, 0.5, 0.9), ...) + +\method{get_predictions}{estimate_secondary}(object, CrIs = c(0.2, 0.5, 0.9), ...) +} +\arguments{ +\item{object}{A fitted model object (e.g., from \code{estimate_infections()} or +\code{estimate_secondary()})} + +\item{...}{Additional arguments (currently unused)} + +\item{CrIs}{Numeric vector of credible intervals to return. Defaults to +c(0.2, 0.5, 0.9).} +} +\value{ +A \code{data.table} with columns including date, observations, and summary +statistics (mean, sd, credible intervals) for the model predictions. +} +\description{ +\ifelse{html}{\href{https://lifecycle.r-lib.org/articles/stages.html#stable}{\figure{lifecycle-stable.svg}{options: alt='[Stable]'}}}{\strong{[Stable]}} +Extracts predictions from a fitted model, combining observations with model +estimates. For \code{estimate_infections()} returns predicted reported cases, for +\code{estimate_secondary()} returns predicted secondary observations. +} +\examples{ +\dontrun{ +# After fitting a model +predictions <- get_predictions(fit) +} +} diff --git a/man/get_samples.Rd b/man/get_samples.Rd index d4ed3b890..cdb6ecaf6 100644 --- a/man/get_samples.Rd +++ b/man/get_samples.Rd @@ -4,6 +4,7 @@ \alias{get_samples} \alias{get_samples.estimate_infections} \alias{get_samples.forecast_infections} +\alias{get_samples.estimate_secondary} \title{Get posterior samples from a fitted model} \usage{ get_samples(object, ...) @@ -11,6 +12,8 @@ get_samples(object, ...) \method{get_samples}{estimate_infections}(object, ...) \method{get_samples}{forecast_infections}(object, ...) + +\method{get_samples}{estimate_secondary}(object, ...) } \arguments{ \item{object}{A fitted model object (e.g., from \code{estimate_infections()})} diff --git a/man/summary.estimate_secondary.Rd b/man/summary.estimate_secondary.Rd new file mode 100644 index 000000000..0ea9db8a6 --- /dev/null +++ b/man/summary.estimate_secondary.Rd @@ -0,0 +1,40 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/summarise.R +\name{summary.estimate_secondary} +\alias{summary.estimate_secondary} +\title{Summarise results from estimate_secondary} +\usage{ +\method{summary}{estimate_secondary}( + object, + type = c("compact", "parameters"), + params = NULL, + CrIs = c(0.2, 0.5, 0.9), + ... +) +} +\arguments{ +\item{object}{A fitted model object from \code{estimate_secondary()}} + +\item{type}{Character string indicating the type of summary to return. +Options are "compact" (default, key parameters only) or "parameters" +(all parameters or filtered set).} + +\item{params}{Character vector of parameter names to include. Only used +when \code{type = "parameters"}. If NULL (default), returns all parameters.} + +\item{CrIs}{Numeric vector of credible intervals to return. Defaults to +c(0.2, 0.5, 0.9).} + +\item{...}{Additional arguments (currently unused)} +} +\value{ +A \verb{} with summary statistics (mean, sd, median, +credible intervals) for model parameters. When \code{type = "compact"}, +returns only key parameters (delay distribution parameters and scaling +factors). When \code{type = "parameters"}, returns all or filtered parameters. +} +\description{ +\ifelse{html}{\href{https://lifecycle.r-lib.org/articles/stages.html#stable}{\figure{lifecycle-stable.svg}{options: alt='[Stable]'}}}{\strong{[Stable]}} +Returns a summary of the fitted secondary model including posterior +parameter estimates with credible intervals. +} From e611234199a09218f4fbe6fdc6852cf0cbe07c9b Mon Sep 17 00:00:00 2001 From: Claude Date: Thu, 13 Nov 2025 16:37:19 +0000 Subject: [PATCH 12/51] Add backward compatibility for forecast_secondary in get_predictions forecast_secondary() returns objects with the old structure ($predictions) rather than the new standardized structure ($fit, $args, $observations). Add check to handle both cases: return $predictions directly for forecast objects, extract from fit for fitted estimate_secondary objects. --- R/get.R | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/R/get.R b/R/get.R index 32500090e..a7c2d7abb 100644 --- a/R/get.R +++ b/R/get.R @@ -332,6 +332,11 @@ get_predictions.estimate_infections <- function(object, get_predictions.estimate_secondary <- function(object, CrIs = c(0.2, 0.5, 0.9), ...) { + # Handle forecast_secondary objects (old structure with $predictions) + if (!is.null(object$predictions) && is.null(object$fit)) { + return(object$predictions) + } + # Extract predictions from the fit predictions <- extract_stan_param(object$fit, "sim_secondary", CrIs = CrIs) predictions <- predictions[, lapply(.SD, round, 1)] From 0b92ea8eab543963d32f8a25a429a0d3b2cd4c93 Mon Sep 17 00:00:00 2001 From: Claude Date: Thu, 13 Nov 2025 16:39:35 +0000 Subject: [PATCH 13/51] Give forecast_secondary its own S3 class and methods forecast_secondary() now returns a 'forecast_secondary' object instead of 'estimate_secondary', matching the pattern used for forecast_infections(). Changes: - Change class from estimate_secondary to forecast_secondary - Add get_samples.forecast_secondary() method - Add get_predictions.forecast_secondary() method - Remove backward compatibility hack from get_predictions.estimate_secondary() - Export new S3 methods in NAMESPACE This properly separates fitted objects (estimate_secondary) from forecast objects (forecast_secondary), each with appropriate accessor methods. --- NAMESPACE | 2 ++ R/estimate_secondary.R | 2 +- R/get.R | 17 ++++++++++++----- 3 files changed, 15 insertions(+), 6 deletions(-) diff --git a/NAMESPACE b/NAMESPACE index 60185bd3c..30d3944d3 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -12,9 +12,11 @@ S3method(fix_parameters,dist_spec) S3method(fix_parameters,multi_dist_spec) S3method(get_predictions,estimate_infections) S3method(get_predictions,estimate_secondary) +S3method(get_predictions,forecast_secondary) S3method(get_samples,estimate_infections) S3method(get_samples,estimate_secondary) S3method(get_samples,forecast_infections) +S3method(get_samples,forecast_secondary) S3method(is_constrained,dist_spec) S3method(is_constrained,multi_dist_spec) S3method(max,dist_spec) diff --git a/R/estimate_secondary.R b/R/estimate_secondary.R index 5a9521e22..b007e95cb 100644 --- a/R/estimate_secondary.R +++ b/R/estimate_secondary.R @@ -728,6 +728,6 @@ forecast_secondary <- function(estimate, data.table::setcolorder( out$predictions, c("date", "primary", "secondary", "mean", "sd") ) - class(out) <- c("estimate_secondary", class(out)) + class(out) <- c("forecast_secondary", class(out)) return(out) } diff --git a/R/get.R b/R/get.R index a7c2d7abb..a6034190f 100644 --- a/R/get.R +++ b/R/get.R @@ -273,6 +273,12 @@ get_samples.estimate_secondary <- function(object, ...) { extract_stan_param(object$fit, CrIs = c(0.2, 0.5, 0.9)) } +#' @rdname get_samples +#' @export +get_samples.forecast_secondary <- function(object, ...) { + object$samples +} + #' Get predictions from a fitted model #' #' @description `r lifecycle::badge("stable")` @@ -332,11 +338,6 @@ get_predictions.estimate_infections <- function(object, get_predictions.estimate_secondary <- function(object, CrIs = c(0.2, 0.5, 0.9), ...) { - # Handle forecast_secondary objects (old structure with $predictions) - if (!is.null(object$predictions) && is.null(object$fit)) { - return(object$predictions) - } - # Extract predictions from the fit predictions <- extract_stan_param(object$fit, "sim_secondary", CrIs = CrIs) predictions <- predictions[, lapply(.SD, round, 1)] @@ -355,3 +356,9 @@ get_predictions.estimate_secondary <- function(object, return(predictions) } + +#' @rdname get_predictions +#' @export +get_predictions.forecast_secondary <- function(object, ...) { + object$predictions +} From f4bab074d4d9654b8c84d1e7331175a163cbe9bd Mon Sep 17 00:00:00 2001 From: "epiforecasts-workflows[bot]" Date: Thu, 13 Nov 2025 16:45:58 +0000 Subject: [PATCH 14/51] Update documentation --- man/get_predictions.Rd | 3 +++ man/get_samples.Rd | 3 +++ 2 files changed, 6 insertions(+) diff --git a/man/get_predictions.Rd b/man/get_predictions.Rd index 4bc942117..9aa7d4bc2 100644 --- a/man/get_predictions.Rd +++ b/man/get_predictions.Rd @@ -4,6 +4,7 @@ \alias{get_predictions} \alias{get_predictions.estimate_infections} \alias{get_predictions.estimate_secondary} +\alias{get_predictions.forecast_secondary} \title{Get predictions from a fitted model} \usage{ get_predictions(object, ...) @@ -11,6 +12,8 @@ get_predictions(object, ...) \method{get_predictions}{estimate_infections}(object, CrIs = c(0.2, 0.5, 0.9), ...) \method{get_predictions}{estimate_secondary}(object, CrIs = c(0.2, 0.5, 0.9), ...) + +\method{get_predictions}{forecast_secondary}(object, ...) } \arguments{ \item{object}{A fitted model object (e.g., from \code{estimate_infections()} or diff --git a/man/get_samples.Rd b/man/get_samples.Rd index cdb6ecaf6..e4c9d7736 100644 --- a/man/get_samples.Rd +++ b/man/get_samples.Rd @@ -5,6 +5,7 @@ \alias{get_samples.estimate_infections} \alias{get_samples.forecast_infections} \alias{get_samples.estimate_secondary} +\alias{get_samples.forecast_secondary} \title{Get posterior samples from a fitted model} \usage{ get_samples(object, ...) @@ -14,6 +15,8 @@ get_samples(object, ...) \method{get_samples}{forecast_infections}(object, ...) \method{get_samples}{estimate_secondary}(object, ...) + +\method{get_samples}{forecast_secondary}(object, ...) } \arguments{ \item{object}{A fitted model object (e.g., from \code{estimate_infections()})} From dcd983b1bf684e095a810b50a54b6e641f71ce56 Mon Sep 17 00:00:00 2001 From: Claude Date: Fri, 14 Nov 2025 08:59:07 +0000 Subject: [PATCH 15/51] Add plot method for forecast_secondary objects forecast_secondary objects now have their own plot.forecast_secondary() method, matching the pattern for forecast_infections. The implementation is identical to plot.estimate_secondary() since both use get_predictions() which works for both classes. --- NAMESPACE | 1 + R/estimate_secondary.R | 59 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 60 insertions(+) diff --git a/NAMESPACE b/NAMESPACE index 30d3944d3..a5d2f664f 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -29,6 +29,7 @@ S3method(plot,estimate_infections) S3method(plot,estimate_secondary) S3method(plot,estimate_truncation) S3method(plot,forecast_infections) +S3method(plot,forecast_secondary) S3method(print,dist_spec) S3method(print,epinowfit) S3method(sd,default) diff --git a/R/estimate_secondary.R b/R/estimate_secondary.R index b007e95cb..1cafd1d9c 100644 --- a/R/estimate_secondary.R +++ b/R/estimate_secondary.R @@ -421,6 +421,65 @@ plot.estimate_secondary <- function(x, primary = FALSE, ggplot2::theme(axis.text.x = ggplot2::element_text(angle = 90)) } +#' Plot method for forecast_secondary objects +#' +#' @description `r lifecycle::badge("stable")` +#' Plot method for forecast secondary observations. +#' +#' @inheritParams plot.estimate_secondary +#' @method plot forecast_secondary +#' @export +plot.forecast_secondary <- function(x, primary = FALSE, + from = NULL, to = NULL, + new_obs = NULL, + ...) { + predictions <- data.table::copy(get_predictions(x)) + + if (!is.null(new_obs)) { + new_obs <- data.table::as.data.table(new_obs) + new_obs <- new_obs[, .(date, secondary)] + predictions <- predictions[, secondary := NULL] + predictions <- data.table::merge.data.table( + predictions, new_obs, + all = TRUE, by = "date" + ) + } + if (!is.null(from)) { + predictions <- predictions[date >= from] + } + if (!is.null(to)) { + predictions <- predictions[date <= to] + } + + p <- ggplot2::ggplot(predictions, ggplot2::aes(x = date, y = secondary)) + + ggplot2::geom_col( + fill = "grey", col = "white", + show.legend = FALSE, na.rm = TRUE + ) + + if (primary) { + p <- p + + ggplot2::geom_point( + data = predictions, + ggplot2::aes(y = primary), + alpha = 0.4, size = 0.8 + ) + + ggplot2::geom_line( + data = predictions, + ggplot2::aes(y = primary), alpha = 0.4 + ) + } + p <- plot_CrIs(p, extract_CrIs(predictions), + alpha = 0.6, linewidth = 1 + ) + p + + ggplot2::theme_bw() + + ggplot2::labs(y = "Reports per day", x = "Date") + + ggplot2::scale_x_date(date_breaks = "week", date_labels = "%b %d") + + ggplot2::scale_y_continuous(labels = scales::comma) + + ggplot2::theme(axis.text.x = ggplot2::element_text(angle = 90)) +} + #' Convolve and scale a time series #' #' This applies a lognormal convolution with given, potentially time-varying From f36f0d07ab6bcce1da2862ed32a650d7c8513239 Mon Sep 17 00:00:00 2001 From: "epiforecasts-workflows[bot]" Date: Fri, 14 Nov 2025 09:05:52 +0000 Subject: [PATCH 16/51] Update documentation --- man/plot.forecast_secondary.Rd | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) create mode 100644 man/plot.forecast_secondary.Rd diff --git a/man/plot.forecast_secondary.Rd b/man/plot.forecast_secondary.Rd new file mode 100644 index 000000000..7fdec4e94 --- /dev/null +++ b/man/plot.forecast_secondary.Rd @@ -0,0 +1,28 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/estimate_secondary.R +\name{plot.forecast_secondary} +\alias{plot.forecast_secondary} +\title{Plot method for forecast_secondary objects} +\usage{ +\method{plot}{forecast_secondary}(x, primary = FALSE, from = NULL, to = NULL, new_obs = NULL, ...) +} +\arguments{ +\item{x}{A list of output as produced by \code{estimate_secondary}} + +\item{primary}{Logical, defaults to \code{FALSE}. Should \code{primary} reports also +be plot?} + +\item{from}{Date object indicating when to plot from.} + +\item{to}{Date object indicating when to plot up to.} + +\item{new_obs}{A \verb{} containing the columns \code{date} and \code{secondary} +which replace the secondary observations stored in the \code{estimate_secondary} +output.} + +\item{...}{Pass additional arguments to plot function. Not currently in use.} +} +\description{ +\ifelse{html}{\href{https://lifecycle.r-lib.org/articles/stages.html#stable}{\figure{lifecycle-stable.svg}{options: alt='[Stable]'}}}{\strong{[Stable]}} +Plot method for forecast secondary observations. +} From b0c3f039f4b0e6f73cfcbf70c2bd1e5281f7459f Mon Sep 17 00:00:00 2001 From: Claude Date: Fri, 14 Nov 2025 09:58:15 +0000 Subject: [PATCH 17/51] Pass CrIs parameter through in get_samples.estimate_secondary Allow CrIs to be customized when calling get_samples(), matching the old behavior where CrIs could be specified. --- R/get.R | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/R/get.R b/R/get.R index a6034190f..f74f2253d 100644 --- a/R/get.R +++ b/R/get.R @@ -268,9 +268,10 @@ get_samples.forecast_infections <- function(object, ...) { #' @rdname get_samples #' @export -get_samples.estimate_secondary <- function(object, ...) { +get_samples.estimate_secondary <- function(object, CrIs = c(0.2, 0.5, 0.9), + ...) { # Extract posterior samples from the fit - extract_stan_param(object$fit, CrIs = c(0.2, 0.5, 0.9)) + extract_stan_param(object$fit, CrIs = CrIs) } #' @rdname get_samples From 7b4036f1ee546772b32db46c03cfc5ffbbae8ded Mon Sep 17 00:00:00 2001 From: "epiforecasts-workflows[bot]" Date: Fri, 14 Nov 2025 10:04:46 +0000 Subject: [PATCH 18/51] Update documentation --- man/get_samples.Rd | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/man/get_samples.Rd b/man/get_samples.Rd index e4c9d7736..4d62b04a7 100644 --- a/man/get_samples.Rd +++ b/man/get_samples.Rd @@ -14,7 +14,7 @@ get_samples(object, ...) \method{get_samples}{forecast_infections}(object, ...) -\method{get_samples}{estimate_secondary}(object, ...) +\method{get_samples}{estimate_secondary}(object, CrIs = c(0.2, 0.5, 0.9), ...) \method{get_samples}{forecast_secondary}(object, ...) } From f33c208d6b0e7ecc9169c48a80d43ab22204d6d9 Mon Sep 17 00:00:00 2001 From: Claude Date: Fri, 14 Nov 2025 10:52:48 +0000 Subject: [PATCH 19/51] Clarify pull request workflow in CLAUDE.md --- CLAUDE.md | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/CLAUDE.md b/CLAUDE.md index 22734a31e..a75a98f75 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -26,9 +26,29 @@ ## Pull Requests +### Workflow Overview +When working on a task, follow this sequence: + +1. **During development**: Make commits with descriptive messages (NO issue numbers) +2. **Before opening PR**: Add a NEWS.md item for user-facing changes +3. **When opening PR**: Link the issue in PR description using "This PR closes #XXX" +4. **In PR template**: Complete all applicable checklist items + +### Key Distinctions +- **Commit messages**: Describe what the code does, never reference issues + - Good: "Standardise return structure for estimate_secondary" + - Bad: "Standardise return structure for estimate_secondary (#1142)" +- **PR descriptions**: Link to issues here using the template + - Format: "This PR closes #1142." +- **NEWS items**: Describe user-facing changes, never reference issues or PRs + - Good: "Added S3 methods for estimate_secondary objects" + - Bad: "Added S3 methods for estimate_secondary (#1142)" + +### PR Template Requirements - Follow the pull request template in `.github/PULL_REQUEST_TEMPLATE.md` - Complete the checklist items in the template - Link to the related issue in the PR description (not in commit messages) +- Ensure you've added a NEWS.md item before submitting ## Testing From 9b13ce3a20adb89c6d54c6bef3f6abc8b8e29ce9 Mon Sep 17 00:00:00 2001 From: Claude Date: Fri, 14 Nov 2025 10:52:55 +0000 Subject: [PATCH 20/51] Add NEWS item for estimate_secondary S3 standardisation --- NEWS.md | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/NEWS.md b/NEWS.md index f459481cb..1387d8b39 100644 --- a/NEWS.md +++ b/NEWS.md @@ -15,6 +15,13 @@ - **Deprecated**: `summary(object, type = "samples")` now issues a deprecation warning. Use `get_samples(object)` instead. - **Deprecated**: Internal function `extract_parameter_samples()` renamed to `format_simulation_output()` for clarity. - `forecast_infections()` now returns an independent S3 class `"forecast_infections"` instead of inheriting from `"estimate_infections"`. This clarifies the distinction between fitted models (which contain a Stan fit for diagnostics) and forecast simulations (which contain pre-computed samples). Dedicated `summary()`, `plot()`, and `get_samples()` methods are provided. +- `estimate_secondary()` now returns an S3 object of class `c("epinowfit", "estimate_secondary", "list")` with elements `fit`, `args`, and `observations`, matching the structure of `estimate_infections()`. + - Use `get_samples(object)` to extract formatted posterior samples for delay and scaling parameters. + - Use `get_predictions(object)` to get predicted secondary observations with credible intervals merged with observations. + - Use `summary(object)` to get summarised parameter estimates. Use `type = "compact"` for key parameters only, or `type = "parameters"` with a `params` argument to select specific parameters. + - Access the Stan fit directly via `object$fit`, model arguments via `object$args`, and observations via `object$observations`. + - **Breaking change**: The previous return structure with `predictions`, `posterior`, and `data` elements is no longer supported. Use the accessor methods instead. +- `forecast_secondary()` now returns an independent S3 class `"forecast_secondary"` instead of inheriting from `"estimate_secondary"`, with dedicated `get_samples()`, `get_predictions()`, and `plot()` methods. - `plot.estimate_infections()` and `plot.forecast_infections()` now accept a `CrIs` argument to control which credible intervals are displayed. ## Model changes From 983a271f193ed795470662e9c12588a3ae161ad5 Mon Sep 17 00:00:00 2001 From: Claude Date: Fri, 14 Nov 2025 10:53:04 +0000 Subject: [PATCH 21/51] Fix get_samples to return raw posterior samples get_samples.estimate_secondary() now returns individual MCMC draws rather than summary statistics, consistent with the method's purpose. Tests updated to calculate summary statistics from the raw samples. --- R/get.R | 58 ++++++++++++++++++++++-- tests/testthat/test-estimate_secondary.R | 29 +++++++++--- 2 files changed, 77 insertions(+), 10 deletions(-) diff --git a/R/get.R b/R/get.R index f74f2253d..ad0cb3c70 100644 --- a/R/get.R +++ b/R/get.R @@ -268,10 +268,60 @@ get_samples.forecast_infections <- function(object, ...) { #' @rdname get_samples #' @export -get_samples.estimate_secondary <- function(object, CrIs = c(0.2, 0.5, 0.9), - ...) { - # Extract posterior samples from the fit - extract_stan_param(object$fit, CrIs = CrIs) +get_samples.estimate_secondary <- function(object, ...) { + # Extract raw posterior samples from the fit + raw_samples <- extract_samples(object$fit) + + # Convert to long format data.table + samples_list <- list() + + # Extract delay_params if present + if (!is.null(raw_samples$delay_params)) { + n_delay_params <- ncol(raw_samples$delay_params) + for (i in seq_len(n_delay_params)) { + param_name <- paste0("delay_params[", i, "]") + samples_list[[param_name]] <- data.table::data.table( + variable = param_name, + sample = seq_along(raw_samples$delay_params[, i]), + value = raw_samples$delay_params[, i] + ) + } + } + + # Extract params if present + if (!is.null(raw_samples$params)) { + n_params <- ncol(raw_samples$params) + for (i in seq_len(n_params)) { + param_name <- paste0("params[", i, "]") + samples_list[[param_name]] <- data.table::data.table( + variable = param_name, + sample = seq_along(raw_samples$params[, i]), + value = raw_samples$params[, i] + ) + } + } + + # Combine all samples + samples <- data.table::rbindlist(samples_list, fill = TRUE) + + # Add placeholder columns for consistency with estimate_infections format + if (!("date" %in% names(samples))) { + samples[, date := as.Date(NA)] + } + if (!("strat" %in% names(samples))) { + samples[, strat := NA_character_] + } + if (!("type" %in% names(samples))) { + samples[, type := NA_character_] + } + if (!("time" %in% names(samples))) { + samples[, time := NA_integer_] + } + + # Reorder columns + data.table::setcolorder(samples, c("date", "variable", "strat", "sample", "time", "value", "type")) + + return(samples[]) } #' @rdname get_samples diff --git a/tests/testthat/test-estimate_secondary.R b/tests/testthat/test-estimate_secondary.R index 44fc19110..0854c089a 100644 --- a/tests/testthat/test-estimate_secondary.R +++ b/tests/testthat/test-estimate_secondary.R @@ -154,20 +154,30 @@ test_that("estimate_secondary works when only estimating scaling", { }) test_that("estimate_secondary can recover simulated parameters", { + # Calculate summary statistics from raw samples + inc_summary <- inc_posterior[, .( + mean = mean(value), + median = stats::median(value) + ), by = variable] + prev_summary <- prev_posterior[, .( + mean = mean(value), + median = stats::median(value) + ), by = variable] + expect_equal( - inc_posterior[, mean], c(1.8, 0.5, 0.4), + inc_summary$mean, c(1.8, 0.5, 0.4), tolerance = 0.1 ) expect_equal( - inc_posterior[, median], c(1.8, 0.5, 0.4), + inc_summary$median, c(1.8, 0.5, 0.4), tolerance = 0.1 ) expect_equal( - prev_posterior[, mean], c(1.6, 0.8, 0.3), + prev_summary$mean, c(1.6, 0.8, 0.3), tolerance = 0.2 ) expect_equal( - prev_posterior[, median], c(1.6, 0.8, 0.3), + prev_summary$median, c(1.6, 0.8, 0.3), tolerance = 0.2 ) }) @@ -182,12 +192,19 @@ test_that("estimate_secondary can recover simulated parameters with the ) ))) inc_posterior_cmdstanr <- get_samples(inc_cmdstanr)[variable %in% params] + + # Calculate summary statistics from raw samples + inc_summary_cmdstanr <- inc_posterior_cmdstanr[, .( + mean = mean(value), + median = stats::median(value) + ), by = variable] + expect_equal( - inc_posterior_cmdstanr[, mean], c(1.8, 0.5, 0.4), + inc_summary_cmdstanr$mean, c(1.8, 0.5, 0.4), tolerance = 0.1 ) expect_equal( - inc_posterior_cmdstanr[, median], c(1.8, 0.5, 0.4), + inc_summary_cmdstanr$median, c(1.8, 0.5, 0.4), tolerance = 0.1 ) }) From 5bd96c840cf0951e0a155bebcd4fc54fd60e0089 Mon Sep 17 00:00:00 2001 From: "epiforecasts-workflows[bot]" Date: Fri, 14 Nov 2025 11:00:16 +0000 Subject: [PATCH 22/51] Update documentation --- man/get_samples.Rd | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/man/get_samples.Rd b/man/get_samples.Rd index 4d62b04a7..e4c9d7736 100644 --- a/man/get_samples.Rd +++ b/man/get_samples.Rd @@ -14,7 +14,7 @@ get_samples(object, ...) \method{get_samples}{forecast_infections}(object, ...) -\method{get_samples}{estimate_secondary}(object, CrIs = c(0.2, 0.5, 0.9), ...) +\method{get_samples}{estimate_secondary}(object, ...) \method{get_samples}{forecast_secondary}(object, ...) } From d1452e7843edc7d13f77a58df3bbed10c99b328d Mon Sep 17 00:00:00 2001 From: Claude Date: Fri, 14 Nov 2025 15:00:38 +0000 Subject: [PATCH 23/51] Add extract_array_parameter helper for consistent param extraction Created extract_array_parameter() to standardise extraction of matrix parameters (delay_params, params) into indexed format (e.g., params[1], params[2]). Now used consistently by both estimate_infections (via format_samples_with_dates) and estimate_secondary (via get_samples.estimate_secondary). --- R/extract.R | 28 +++++++++++++++++++++++++ R/format.R | 29 ++++---------------------- R/get.R | 59 ++++++++++++++++------------------------------------- 3 files changed, 50 insertions(+), 66 deletions(-) diff --git a/R/extract.R b/R/extract.R index 6456b04fe..84600eac1 100644 --- a/R/extract.R +++ b/R/extract.R @@ -76,6 +76,34 @@ extract_parameter <- function(param, samples) { ) } +#' Extract samples from an array parameter +#' +#' Extracts samples from a parameter stored as a matrix (samples x dimension) +#' and returns them in long format with indexed parameter names. +#' +#' @param param_name Character string for the parameter name (e.g., +#' "delay_params", "params") +#' @param param_array A matrix of samples where rows are MCMC samples and +#' columns are parameter dimensions +#' @return A `` with columns: parameter, sample, value, or NULL if +#' param_array is NULL +#' @keywords internal +extract_array_parameter <- function(param_name, param_array) { + if (is.null(param_array)) { + return(NULL) + } + + n_cols <- ncol(param_array) + samples_list <- lapply(seq_len(n_cols), function(i) { + data.table::data.table( + parameter = paste0(param_name, "[", i, "]"), + sample = seq_along(param_array[, i]), + value = param_array[, i] + ) + }) + data.table::rbindlist(samples_list) +} + #' Extract all samples from a stan fit #' #' If the `object` argument is a `` object, it simply returns the diff --git a/R/format.R b/R/format.R index 30006813e..152e19768 100644 --- a/R/format.R +++ b/R/format.R @@ -278,34 +278,13 @@ format_samples_with_dates <- function(raw_samples, args, observations) { # Delay parameters if (args$delay_params_length > 0) { - delay_params <- extract_latent_state( - "delay_params", raw_samples, seq_len(args$delay_params_length) + out$delay_params <- extract_array_parameter( + "delay_params", raw_samples$delay_params ) - if (!is.null(delay_params)) { - out$delay_params <- delay_params[, strat := as.character(time)][ - , time := NULL - ][, date := NULL] - } } - # Auto-detect and extract all static parameters from params matrix - param_id_names <- names(raw_samples)[ - startsWith(names(raw_samples), "param_id_") - ] - param_names <- sub("^param_id_", "", param_id_names) - - for (param in param_names) { - result <- extract_parameter(param, raw_samples) - if (!is.null(result)) { - # Use standard naming conventions - param_name <- switch(param, - "dispersion" = "reporting_overdispersion", - "frac_obs" = "fraction_observed", - param # default: use param name as-is - ) - out[[param_name]] <- result - } - } + # Params matrix + out$params <- extract_array_parameter("params", raw_samples$params) # Combine all parameters into single data.table combined <- data.table::rbindlist(out, fill = TRUE, idcol = "variable") diff --git a/R/get.R b/R/get.R index ad0cb3c70..c2dadd6eb 100644 --- a/R/get.R +++ b/R/get.R @@ -272,54 +272,31 @@ get_samples.estimate_secondary <- function(object, ...) { # Extract raw posterior samples from the fit raw_samples <- extract_samples(object$fit) - # Convert to long format data.table - samples_list <- list() - - # Extract delay_params if present - if (!is.null(raw_samples$delay_params)) { - n_delay_params <- ncol(raw_samples$delay_params) - for (i in seq_len(n_delay_params)) { - param_name <- paste0("delay_params[", i, "]") - samples_list[[param_name]] <- data.table::data.table( - variable = param_name, - sample = seq_along(raw_samples$delay_params[, i]), - value = raw_samples$delay_params[, i] - ) - } - } - - # Extract params if present - if (!is.null(raw_samples$params)) { - n_params <- ncol(raw_samples$params) - for (i in seq_len(n_params)) { - param_name <- paste0("params[", i, "]") - samples_list[[param_name]] <- data.table::data.table( - variable = param_name, - sample = seq_along(raw_samples$params[, i]), - value = raw_samples$params[, i] - ) - } - } + # Extract delay_params and params using helper function + samples_list <- list( + extract_array_parameter("delay_params", raw_samples$delay_params), + extract_array_parameter("params", raw_samples$params) + ) # Combine all samples samples <- data.table::rbindlist(samples_list, fill = TRUE) + # Rename 'parameter' column to 'variable' for consistency + data.table::setnames(samples, "parameter", "variable") + # Add placeholder columns for consistency with estimate_infections format - if (!("date" %in% names(samples))) { - samples[, date := as.Date(NA)] - } - if (!("strat" %in% names(samples))) { - samples[, strat := NA_character_] - } - if (!("type" %in% names(samples))) { - samples[, type := NA_character_] - } - if (!("time" %in% names(samples))) { - samples[, time := NA_integer_] - } + samples[, `:=`( + date = as.Date(NA), + strat = NA_character_, + time = NA_integer_, + type = NA_character_ + )] # Reorder columns - data.table::setcolorder(samples, c("date", "variable", "strat", "sample", "time", "value", "type")) + data.table::setcolorder( + samples, + c("date", "variable", "strat", "sample", "time", "value", "type") + ) return(samples[]) } From c6eea247c6f7727c3c33c35ea664ed2a7c02872d Mon Sep 17 00:00:00 2001 From: Claude Date: Fri, 14 Nov 2025 15:00:48 +0000 Subject: [PATCH 24/51] Simplify plot.forecast_secondary to call plot.estimate_secondary Follows same pattern as plot.forecast_infections which delegates to plot.estimate_infections. Reduces code duplication. --- R/estimate_secondary.R | 46 ++---------------------------------------- 1 file changed, 2 insertions(+), 44 deletions(-) diff --git a/R/estimate_secondary.R b/R/estimate_secondary.R index 1cafd1d9c..7eae844b3 100644 --- a/R/estimate_secondary.R +++ b/R/estimate_secondary.R @@ -433,51 +433,9 @@ plot.forecast_secondary <- function(x, primary = FALSE, from = NULL, to = NULL, new_obs = NULL, ...) { - predictions <- data.table::copy(get_predictions(x)) - - if (!is.null(new_obs)) { - new_obs <- data.table::as.data.table(new_obs) - new_obs <- new_obs[, .(date, secondary)] - predictions <- predictions[, secondary := NULL] - predictions <- data.table::merge.data.table( - predictions, new_obs, - all = TRUE, by = "date" - ) - } - if (!is.null(from)) { - predictions <- predictions[date >= from] - } - if (!is.null(to)) { - predictions <- predictions[date <= to] - } - - p <- ggplot2::ggplot(predictions, ggplot2::aes(x = date, y = secondary)) + - ggplot2::geom_col( - fill = "grey", col = "white", - show.legend = FALSE, na.rm = TRUE - ) - - if (primary) { - p <- p + - ggplot2::geom_point( - data = predictions, - ggplot2::aes(y = primary), - alpha = 0.4, size = 0.8 - ) + - ggplot2::geom_line( - data = predictions, - ggplot2::aes(y = primary), alpha = 0.4 - ) - } - p <- plot_CrIs(p, extract_CrIs(predictions), - alpha = 0.6, linewidth = 1 + plot.estimate_secondary( + x, primary = primary, from = from, to = to, new_obs = new_obs, ... ) - p + - ggplot2::theme_bw() + - ggplot2::labs(y = "Reports per day", x = "Date") + - ggplot2::scale_x_date(date_breaks = "week", date_labels = "%b %d") + - ggplot2::scale_y_continuous(labels = scales::comma) + - ggplot2::theme(axis.text.x = ggplot2::element_text(angle = 90)) } #' Convolve and scale a time series From b8fbfcc093b2a2b8f8cb9bea70275cf0e992efc5 Mon Sep 17 00:00:00 2001 From: "epiforecasts-workflows[bot]" Date: Wed, 19 Nov 2025 13:36:40 +0000 Subject: [PATCH 25/51] Update documentation --- man/extract_array_parameter.Rd | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) create mode 100644 man/extract_array_parameter.Rd diff --git a/man/extract_array_parameter.Rd b/man/extract_array_parameter.Rd new file mode 100644 index 000000000..4b7273110 --- /dev/null +++ b/man/extract_array_parameter.Rd @@ -0,0 +1,24 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/extract.R +\name{extract_array_parameter} +\alias{extract_array_parameter} +\title{Extract samples from an array parameter} +\usage{ +extract_array_parameter(param_name, param_array) +} +\arguments{ +\item{param_name}{Character string for the parameter name (e.g., +"delay_params", "params")} + +\item{param_array}{A matrix of samples where rows are MCMC samples and +columns are parameter dimensions} +} +\value{ +A \verb{} with columns: parameter, sample, value, or NULL if +param_array is NULL +} +\description{ +Extracts samples from a parameter stored as a matrix (samples x dimension) +and returns them in long format with indexed parameter names. +} +\keyword{internal} From a106aee7368074438912ee178c02a2a1c8b6f8a1 Mon Sep 17 00:00:00 2001 From: Sebastian Funk Date: Wed, 19 Nov 2025 17:50:46 +0000 Subject: [PATCH 26/51] Refactor estimate_secondary to use modern extraction pattern MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Updated get_samples.estimate_secondary() to extract all parameters including sim_secondary using extract_latent_state(). Refactored summary.estimate_secondary() and get_predictions.estimate_secondary() to use get_samples() and calc_summary_measures() instead of extract_stan_param(), following the same pattern as estimate_infections. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- R/get.R | 53 ++++++++++++++++-------- R/summarise.R | 16 ++++++- tests/testthat/test-estimate_secondary.R | 4 +- 3 files changed, 52 insertions(+), 21 deletions(-) diff --git a/R/get.R b/R/get.R index c2dadd6eb..85b9177a0 100644 --- a/R/get.R +++ b/R/get.R @@ -272,25 +272,39 @@ get_samples.estimate_secondary <- function(object, ...) { # Extract raw posterior samples from the fit raw_samples <- extract_samples(object$fit) - # Extract delay_params and params using helper function + # Extract array parameters (delay_params and params) samples_list <- list( extract_array_parameter("delay_params", raw_samples$delay_params), extract_array_parameter("params", raw_samples$params) ) + # Extract time-varying generated quantities + # Get dates for time-indexed parameters (post burn-in) + burn_in <- object$args$burn_in + dates <- object$observations[(burn_in + 1):.N]$date + + # Extract sim_secondary (generated quantity, post burn-in) if available + sim_secondary_samples <- extract_latent_state( + "sim_secondary", raw_samples, dates + ) + if (!is.null(sim_secondary_samples)) { + samples_list <- c(samples_list, list(sim_secondary_samples)) + } + # Combine all samples samples <- data.table::rbindlist(samples_list, fill = TRUE) - # Rename 'parameter' column to 'variable' for consistency - data.table::setnames(samples, "parameter", "variable") + # Rename 'parameter' column to 'variable' for consistency if needed + if ("parameter" %in% names(samples)) { + data.table::setnames(samples, "parameter", "variable") + } # Add placeholder columns for consistency with estimate_infections format - samples[, `:=`( - date = as.Date(NA), - strat = NA_character_, - time = NA_integer_, - type = NA_character_ - )] + # Only add if not already present + if (!"date" %in% names(samples)) samples[, date := as.Date(NA)] + if (!"strat" %in% names(samples)) samples[, strat := NA_character_] + if (!"time" %in% names(samples)) samples[, time := NA_integer_] + if (!"type" %in% names(samples)) samples[, type := NA_character_] # Reorder columns data.table::setcolorder( @@ -366,15 +380,20 @@ get_predictions.estimate_infections <- function(object, get_predictions.estimate_secondary <- function(object, CrIs = c(0.2, 0.5, 0.9), ...) { - # Extract predictions from the fit - predictions <- extract_stan_param(object$fit, "sim_secondary", CrIs = CrIs) - predictions <- predictions[, lapply(.SD, round, 1)] + # Get samples for simulated secondary observations + samples <- get_samples(object) + sim_secondary_samples <- samples[variable == "sim_secondary"] - # Add dates based on burn_in - burn_in <- object$args$burn_in - predictions <- predictions[ - , date := object$observations[(burn_in + 1):.N]$date - ] + # Calculate summary measures + predictions <- calc_summary_measures( + sim_secondary_samples, + summarise_by = "date", + order_by = "date", + CrIs = CrIs + ) + + # Round predictions + predictions <- predictions[, lapply(.SD, round, 1)] # Merge with observations predictions <- data.table::merge.data.table( diff --git a/R/summarise.R b/R/summarise.R index f3d26c5dc..36c94a539 100644 --- a/R/summarise.R +++ b/R/summarise.R @@ -915,8 +915,20 @@ summary.estimate_secondary <- function(object, CrIs = c(0.2, 0.5, 0.9), ...) { type <- arg_match(type) - # Extract all parameters with summary statistics - out <- extract_stan_param(object$fit, CrIs = CrIs) + # Get all posterior samples + samples <- get_samples(object) + + # Filter to non-time-varying parameters (delay_params and params) + # Time-varying parameters like secondary and sim_secondary have dates + param_samples <- samples[is.na(date)] + + # Calculate summary statistics + out <- calc_summary_measures( + param_samples, + summarise_by = "variable", + order_by = "variable", + CrIs = CrIs + ) if (type == "compact") { # Return only key parameters for a compact summary diff --git a/tests/testthat/test-estimate_secondary.R b/tests/testthat/test-estimate_secondary.R index 0854c089a..bf271ca0d 100644 --- a/tests/testthat/test-estimate_secondary.R +++ b/tests/testthat/test-estimate_secondary.R @@ -80,8 +80,8 @@ test_that("estimate_secondary can return values from simulated data and plot expect_equal( names(predictions), c( - "date", "primary", "secondary", "accumulate", "mean", "se_mean", "sd", - "lower_90", "lower_50", "lower_20", "median", "upper_20", "upper_50", "upper_90" + "date", "primary", "secondary", "accumulate", "median", "mean", "sd", + "lower_90", "lower_50", "lower_20", "upper_20", "upper_50", "upper_90" ) ) From 134691cc37e595483925ede4346df993a5802462 Mon Sep 17 00:00:00 2001 From: Sebastian Funk Date: Wed, 19 Nov 2025 20:15:39 +0000 Subject: [PATCH 27/51] Unify parameter extraction with extract_parameters() and extract_delays() MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Simplified and unified the parameter extraction API: - extract_parameters(): extracts all params with named lookups - extract_delays(): extracts all delays with named lookups (ready for future delay ID system similar to param IDs) Both functions are plural, both always use lookups, both extract everything. This makes the API consistent and prepares for implementing delay_id_* variables in Stan (e.g., delay_id_gt, delay_id_reporting, delay_id_truncation). Removed extract_parameter() and extract_array_parameter() as they're no longer needed. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- R/extract.R | 124 +++++++++++++++++++++++++-------- R/format.R | 34 ++++----- R/get.R | 6 +- man/extract_array_parameter.Rd | 24 ------- man/extract_delays.Rd | 21 ++++++ man/extract_parameter.Rd | 21 ------ man/extract_parameters.Rd | 19 +++++ 7 files changed, 154 insertions(+), 95 deletions(-) delete mode 100644 man/extract_array_parameter.Rd create mode 100644 man/extract_delays.Rd delete mode 100644 man/extract_parameter.Rd create mode 100644 man/extract_parameters.Rd diff --git a/R/extract.R b/R/extract.R index 84600eac1..2b17f8e3d 100644 --- a/R/extract.R +++ b/R/extract.R @@ -51,59 +51,123 @@ extract_latent_state <- function(param, samples, dates) { } -#' Extract Samples from a Parameter with a Single Dimension +#' Extract Samples from All Parameters #' -#' @param param Character string indicating the parameter to extract #' @param samples Extracted stan model (using [rstan::extract()]) -#' @return A `` containing the parameter name, sample id and sample -#' value, or NULL if the parameter doesn't exist in the samples +#' @return A `` containing the parameter name, sample id and sample +#' value, or NULL if parameters don't exist in the samples #' @keywords internal -extract_parameter <- function(param, samples) { - id_name <- paste("param_id", param, sep = "_") - - # Return NULL if parameter ID doesn't exist - if (!(id_name %in% names(samples))) { +extract_parameters <- function(samples) { + # Check if params exist + if (!("params" %in% names(samples))) { return(NULL) } - id <- samples[[id_name]] + # Extract all parameters + param_array <- samples[["params"]] + n_cols <- ncol(param_array) - lookup <- samples[["params_variable_lookup"]][id] - data.table::data.table( - parameter = param, - sample = seq_along(samples[["params"]][, lookup]), - value = samples[["params"]][, lookup] - ) + # Build reverse lookup: column index -> parameter name + param_names <- rep(NA_character_, n_cols) + + # Check all param_id_* variables to build the mapping + id_vars <- grep("^param_id_", names(samples), value = TRUE) + for (id_var in id_vars) { + param_name <- sub("^param_id_", "", id_var) + id <- samples[[id_var]] + lookup_idx <- samples[["params_variable_lookup"]][id] + if (lookup_idx > 0 && lookup_idx <= n_cols) { + param_names[lookup_idx] <- param_name + } + } + + # Extract all columns + samples_list <- lapply(seq_len(n_cols), function(i) { + # Use named parameter if available, otherwise use indexed name + par_name <- if (!is.na(param_names[i])) { + param_names[i] + } else { + paste0("params[", i, "]") + } + + data.table::data.table( + parameter = par_name, + sample = seq_along(param_array[, i]), + value = param_array[, i] + ) + }) + + data.table::rbindlist(samples_list) } -#' Extract samples from an array parameter +#' Extract Samples from All Delay Parameters #' -#' Extracts samples from a parameter stored as a matrix (samples x dimension) -#' and returns them in long format with indexed parameter names. +#' Extracts samples from all delay parameters using the delay ID lookup system. +#' Similar to extract_parameters(), this extracts all delay distribution +#' parameters and uses the delay_id_* variables to assign meaningful names. #' -#' @param param_name Character string for the parameter name (e.g., -#' "delay_params", "params") -#' @param param_array A matrix of samples where rows are MCMC samples and -#' columns are parameter dimensions +#' @param samples Extracted stan model (using [rstan::extract()]) #' @return A `` with columns: parameter, sample, value, or NULL if -#' param_array is NULL +#' delay parameters don't exist in the samples #' @keywords internal -extract_array_parameter <- function(param_name, param_array) { - if (is.null(param_array)) { +extract_delays <- function(samples) { + # Check if delay_params exist + if (!("delay_params" %in% names(samples))) { return(NULL) } - n_cols <- ncol(param_array) + # Extract all delay parameters + delay_params <- samples[["delay_params"]] + n_cols <- ncol(delay_params) + + # Build reverse lookup: column index -> delay name + delay_names <- rep(NA_character_, n_cols) + + # Check all delay_id_* variables to build the mapping + id_vars <- grep("^delay_id_", names(samples), value = TRUE) + if (length(id_vars) > 0 && "delay_params_groups" %in% names(samples)) { + delay_params_groups <- samples[["delay_params_groups"]] + + for (id_var in id_vars) { + delay_name <- sub("^delay_id_", "", id_var) + delay_id <- samples[[id_var]] + + # Check if this delay exists (ID > 0) + if (delay_id > 0 && delay_id < length(delay_params_groups)) { + start_idx <- delay_params_groups[delay_id] + end_idx <- delay_params_groups[delay_id + 1] - 1 + + # Mark columns for this delay + for (i in seq_along(start_idx:end_idx)) { + col_idx <- start_idx + i - 1 + if (col_idx <= n_cols) { + delay_names[col_idx] <- paste0("delay_", delay_name, "[", i, "]") + } + } + } + } + } + + # Extract all columns samples_list <- lapply(seq_len(n_cols), function(i) { + # Use named delay if available, otherwise use indexed name + par_name <- if (!is.na(delay_names[i])) { + delay_names[i] + } else { + paste0("delay_params[", i, "]") + } + data.table::data.table( - parameter = paste0(param_name, "[", i, "]"), - sample = seq_along(param_array[, i]), - value = param_array[, i] + parameter = par_name, + sample = seq_along(delay_params[, i]), + value = delay_params[, i] ) }) + data.table::rbindlist(samples_list) } + #' Extract all samples from a stan fit #' #' If the `object` argument is a `` object, it simply returns the diff --git a/R/format.R b/R/format.R index 152e19768..b22a711b2 100644 --- a/R/format.R +++ b/R/format.R @@ -191,20 +191,22 @@ format_simulation_output <- function(stan_fit, data, reported_dates, ] } # Auto-detect and extract all static parameters from params matrix - # Find all parameter IDs (names starting with "param_id_") - param_id_names <- names(samples)[startsWith(names(samples), "param_id_")] - param_names <- sub("^param_id_", "", param_id_names) + all_params <- extract_parameters(samples) + if (!is.null(all_params)) { + # Get unique parameter names + param_names <- unique(all_params$parameter) - for (param in param_names) { - result <- extract_parameter(param, samples) - if (!is.null(result)) { - # Use standard naming conventions - param_name <- switch(param, - "dispersion" = "reporting_overdispersion", - "frac_obs" = "fraction_observed", - param # default: use param name as-is - ) - out[[param_name]] <- result + for (param in param_names) { + result <- all_params[parameter == param] + if (nrow(result) > 0) { + # Use standard naming conventions + param_name <- switch(param, + "dispersion" = "reporting_overdispersion", + "frac_obs" = "fraction_observed", + param # default: use param name as-is + ) + out[[param_name]] <- result + } } } return(out) @@ -278,13 +280,11 @@ format_samples_with_dates <- function(raw_samples, args, observations) { # Delay parameters if (args$delay_params_length > 0) { - out$delay_params <- extract_array_parameter( - "delay_params", raw_samples$delay_params - ) + out$delay_params <- extract_delays(raw_samples) } # Params matrix - out$params <- extract_array_parameter("params", raw_samples$params) + out$params <- extract_parameters(raw_samples) # Combine all parameters into single data.table combined <- data.table::rbindlist(out, fill = TRUE, idcol = "variable") diff --git a/R/get.R b/R/get.R index 85b9177a0..96e9709c2 100644 --- a/R/get.R +++ b/R/get.R @@ -272,10 +272,10 @@ get_samples.estimate_secondary <- function(object, ...) { # Extract raw posterior samples from the fit raw_samples <- extract_samples(object$fit) - # Extract array parameters (delay_params and params) + # Extract parameters (delays and params) samples_list <- list( - extract_array_parameter("delay_params", raw_samples$delay_params), - extract_array_parameter("params", raw_samples$params) + extract_delays(raw_samples), + extract_parameters(raw_samples) ) # Extract time-varying generated quantities diff --git a/man/extract_array_parameter.Rd b/man/extract_array_parameter.Rd deleted file mode 100644 index 4b7273110..000000000 --- a/man/extract_array_parameter.Rd +++ /dev/null @@ -1,24 +0,0 @@ -% Generated by roxygen2: do not edit by hand -% Please edit documentation in R/extract.R -\name{extract_array_parameter} -\alias{extract_array_parameter} -\title{Extract samples from an array parameter} -\usage{ -extract_array_parameter(param_name, param_array) -} -\arguments{ -\item{param_name}{Character string for the parameter name (e.g., -"delay_params", "params")} - -\item{param_array}{A matrix of samples where rows are MCMC samples and -columns are parameter dimensions} -} -\value{ -A \verb{} with columns: parameter, sample, value, or NULL if -param_array is NULL -} -\description{ -Extracts samples from a parameter stored as a matrix (samples x dimension) -and returns them in long format with indexed parameter names. -} -\keyword{internal} diff --git a/man/extract_delays.Rd b/man/extract_delays.Rd new file mode 100644 index 000000000..6edecad6a --- /dev/null +++ b/man/extract_delays.Rd @@ -0,0 +1,21 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/extract.R +\name{extract_delays} +\alias{extract_delays} +\title{Extract Samples from All Delay Parameters} +\usage{ +extract_delays(samples) +} +\arguments{ +\item{samples}{Extracted stan model (using \code{\link[rstan:stanfit-method-extract]{rstan::extract()}})} +} +\value{ +A \verb{} with columns: parameter, sample, value, or NULL if +delay parameters don't exist in the samples +} +\description{ +Extracts samples from all delay parameters using the delay ID lookup system. +Similar to extract_parameters(), this extracts all delay distribution +parameters and uses the delay_id_* variables to assign meaningful names. +} +\keyword{internal} diff --git a/man/extract_parameter.Rd b/man/extract_parameter.Rd deleted file mode 100644 index 4c9b00a52..000000000 --- a/man/extract_parameter.Rd +++ /dev/null @@ -1,21 +0,0 @@ -% Generated by roxygen2: do not edit by hand -% Please edit documentation in R/extract.R -\name{extract_parameter} -\alias{extract_parameter} -\title{Extract Samples from a Parameter with a Single Dimension} -\usage{ -extract_parameter(param, samples) -} -\arguments{ -\item{param}{Character string indicating the parameter to extract} - -\item{samples}{Extracted stan model (using \code{\link[rstan:stanfit-method-extract]{rstan::extract()}})} -} -\value{ -A \verb{} containing the parameter name, sample id and sample -value, or NULL if the parameter doesn't exist in the samples -} -\description{ -Extract Samples from a Parameter with a Single Dimension -} -\keyword{internal} diff --git a/man/extract_parameters.Rd b/man/extract_parameters.Rd new file mode 100644 index 000000000..681e49931 --- /dev/null +++ b/man/extract_parameters.Rd @@ -0,0 +1,19 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/extract.R +\name{extract_parameters} +\alias{extract_parameters} +\title{Extract Samples from All Parameters} +\usage{ +extract_parameters(samples) +} +\arguments{ +\item{samples}{Extracted stan model (using \code{\link[rstan:stanfit-method-extract]{rstan::extract()}})} +} +\value{ +A \verb{} containing the parameter name, sample id and sample +value, or NULL if parameters don't exist in the samples +} +\description{ +Extract Samples from All Parameters +} +\keyword{internal} From f6e5c59da4f393e61f55f33efc7200f26682a7c0 Mon Sep 17 00:00:00 2001 From: Sebastian Funk Date: Wed, 19 Nov 2025 20:28:25 +0000 Subject: [PATCH 28/51] Fix function titles to use sentence case per CLAUDE.md --- R/extract.R | 10 +++++----- man/extract_delays.Rd | 2 +- man/extract_latent_state.Rd | 2 +- man/extract_parameter_samples.Rd | 2 +- man/extract_parameters.Rd | 4 ++-- man/extract_stan_param.Rd | 2 +- 6 files changed, 11 insertions(+), 11 deletions(-) diff --git a/R/extract.R b/R/extract.R index 2b17f8e3d..b90518900 100644 --- a/R/extract.R +++ b/R/extract.R @@ -1,4 +1,4 @@ -#' Extract Samples for a Latent State from a Stan model +#' Extract samples for a latent state from a Stan model #' #' @description `r lifecycle::badge("stable")` #' Extracts a time-varying latent state from a list of stan output and returns @@ -51,7 +51,7 @@ extract_latent_state <- function(param, samples, dates) { } -#' Extract Samples from All Parameters +#' Extract samples from all parameters #' #' @param samples Extracted stan model (using [rstan::extract()]) #' @return A `` containing the parameter name, sample id and sample @@ -100,7 +100,7 @@ extract_parameters <- function(samples) { data.table::rbindlist(samples_list) } -#' Extract Samples from All Delay Parameters +#' Extract samples from all delay parameters #' #' Extracts samples from all delay parameters using the delay ID lookup system. #' Similar to extract_parameters(), this extracts all delay distribution @@ -243,7 +243,7 @@ extract_samples <- function(stan_fit, pars = NULL, include = TRUE) { return(samples) } -#' Extract Parameter Samples from a Stan Model +#' Extract parameter samples from a Stan model #' #' @description `r lifecycle::badge("deprecated")` #' This function has been deprecated. Use [format_simulation_output()] for @@ -272,7 +272,7 @@ extract_parameter_samples <- function(stan_fit, data, reported_dates, ) } -#' Extract a Parameter Summary from a Stan Object +#' Extract a parameter summary from a Stan object #' #' @description `r lifecycle::badge("stable")` #' Extracts summarised parameter posteriors from a `stanfit` object using diff --git a/man/extract_delays.Rd b/man/extract_delays.Rd index 6edecad6a..29618577f 100644 --- a/man/extract_delays.Rd +++ b/man/extract_delays.Rd @@ -2,7 +2,7 @@ % Please edit documentation in R/extract.R \name{extract_delays} \alias{extract_delays} -\title{Extract Samples from All Delay Parameters} +\title{Extract samples from all delay parameters} \usage{ extract_delays(samples) } diff --git a/man/extract_latent_state.Rd b/man/extract_latent_state.Rd index 89aaff260..eb38b5b52 100644 --- a/man/extract_latent_state.Rd +++ b/man/extract_latent_state.Rd @@ -2,7 +2,7 @@ % Please edit documentation in R/extract.R \name{extract_latent_state} \alias{extract_latent_state} -\title{Extract Samples for a Latent State from a Stan model} +\title{Extract samples for a latent state from a Stan model} \usage{ extract_latent_state(param, samples, dates) } diff --git a/man/extract_parameter_samples.Rd b/man/extract_parameter_samples.Rd index cea50c37f..8b35932d9 100644 --- a/man/extract_parameter_samples.Rd +++ b/man/extract_parameter_samples.Rd @@ -2,7 +2,7 @@ % Please edit documentation in R/extract.R \name{extract_parameter_samples} \alias{extract_parameter_samples} -\title{Extract Parameter Samples from a Stan Model} +\title{Extract parameter samples from a Stan model} \usage{ extract_parameter_samples( stan_fit, diff --git a/man/extract_parameters.Rd b/man/extract_parameters.Rd index 681e49931..b87c02b0c 100644 --- a/man/extract_parameters.Rd +++ b/man/extract_parameters.Rd @@ -2,7 +2,7 @@ % Please edit documentation in R/extract.R \name{extract_parameters} \alias{extract_parameters} -\title{Extract Samples from All Parameters} +\title{Extract samples from all parameters} \usage{ extract_parameters(samples) } @@ -14,6 +14,6 @@ A \verb{} containing the parameter name, sample id and sample value, or NULL if parameters don't exist in the samples } \description{ -Extract Samples from All Parameters +Extract samples from all parameters } \keyword{internal} diff --git a/man/extract_stan_param.Rd b/man/extract_stan_param.Rd index 9cf9c7b9c..e4f132b9c 100644 --- a/man/extract_stan_param.Rd +++ b/man/extract_stan_param.Rd @@ -2,7 +2,7 @@ % Please edit documentation in R/extract.R \name{extract_stan_param} \alias{extract_stan_param} -\title{Extract a Parameter Summary from a Stan Object} +\title{Extract a parameter summary from a Stan object} \usage{ extract_stan_param( fit, From 5271cf1ebfa42c9b6b5333f7b6b7167dfdca0f22 Mon Sep 17 00:00:00 2001 From: Sebastian Funk Date: Wed, 19 Nov 2025 20:40:55 +0000 Subject: [PATCH 29/51] Implement delay lookup system in extract_delays Update extract_delays() to use the delay ID lookup system similar to extract_parameters(). The function now searches for *_id variables (e.g., delay_id, trunc_id) and uses delay_params_groups to map them to meaningful parameter names. --- R/extract.R | 21 ++++++++++++--------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/R/extract.R b/R/extract.R index b90518900..0e6348003 100644 --- a/R/extract.R +++ b/R/extract.R @@ -75,9 +75,11 @@ extract_parameters <- function(samples) { for (id_var in id_vars) { param_name <- sub("^param_id_", "", id_var) id <- samples[[id_var]] - lookup_idx <- samples[["params_variable_lookup"]][id] - if (lookup_idx > 0 && lookup_idx <= n_cols) { - param_names[lookup_idx] <- param_name + if (!is.na(id) && id > 0) { + lookup_idx <- samples[["params_variable_lookup"]][id] + if (!is.na(lookup_idx) && lookup_idx > 0 && lookup_idx <= n_cols) { + param_names[lookup_idx] <- param_name + } } } @@ -104,7 +106,8 @@ extract_parameters <- function(samples) { #' #' Extracts samples from all delay parameters using the delay ID lookup system. #' Similar to extract_parameters(), this extracts all delay distribution -#' parameters and uses the delay_id_* variables to assign meaningful names. +#' parameters and uses the *_id variables (e.g., delay_id, trunc_id) to assign +#' meaningful names. #' #' @param samples Extracted stan model (using [rstan::extract()]) #' @return A `` with columns: parameter, sample, value, or NULL if @@ -123,14 +126,14 @@ extract_delays <- function(samples) { # Build reverse lookup: column index -> delay name delay_names <- rep(NA_character_, n_cols) - # Check all delay_id_* variables to build the mapping - id_vars <- grep("^delay_id_", names(samples), value = TRUE) + # Check all *_id variables to build the mapping (e.g., delay_id, trunc_id) + id_vars <- grep("_id$", names(samples), value = TRUE) if (length(id_vars) > 0 && "delay_params_groups" %in% names(samples)) { delay_params_groups <- samples[["delay_params_groups"]] for (id_var in id_vars) { - delay_name <- sub("^delay_id_", "", id_var) - delay_id <- samples[[id_var]] + delay_name <- sub("_id$", "", id_var) + delay_id <- samples[[id_var]][1] # Take first value (same across samples) # Check if this delay exists (ID > 0) if (delay_id > 0 && delay_id < length(delay_params_groups)) { @@ -141,7 +144,7 @@ extract_delays <- function(samples) { for (i in seq_along(start_idx:end_idx)) { col_idx <- start_idx + i - 1 if (col_idx <= n_cols) { - delay_names[col_idx] <- paste0("delay_", delay_name, "[", i, "]") + delay_names[col_idx] <- paste0(delay_name, "[", i, "]") } } } From 5b299326e6896898f4981a5bbc2a7c82edaaca56 Mon Sep 17 00:00:00 2001 From: "epiforecasts-workflows[bot]" Date: Wed, 19 Nov 2025 20:47:02 +0000 Subject: [PATCH 30/51] Update documentation --- man/extract_delays.Rd | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/man/extract_delays.Rd b/man/extract_delays.Rd index 29618577f..7256b217c 100644 --- a/man/extract_delays.Rd +++ b/man/extract_delays.Rd @@ -16,6 +16,7 @@ delay parameters don't exist in the samples \description{ Extracts samples from all delay parameters using the delay ID lookup system. Similar to extract_parameters(), this extracts all delay distribution -parameters and uses the delay_id_* variables to assign meaningful names. +parameters and uses the *_id variables (e.g., delay_id, trunc_id) to assign +meaningful names. } \keyword{internal} From 180fe5e3aa5290f823b280bbc9e78ec6edf5791a Mon Sep 17 00:00:00 2001 From: Sebastian Funk Date: Wed, 19 Nov 2025 20:54:50 +0000 Subject: [PATCH 31/51] Implement delay_id_* naming system with semantic delay names MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Update the delay parameter system to use semantic naming similar to the params system. This replaces the generic delay_id, trunc_id, gt_id variables with descriptive delay_id_* variables. Changes: - Rename delay parameters in R code: - gt → generation_time - delay → reporting (in infections context) - delay → secondary (in secondary context) - trunc → truncation - Update create_stan_delays() to create delay_id_* variables with names matching the argument names (e.g., delay_id_generation_time, delay_id_reporting, delay_id_secondary, delay_id_truncation) - Update Stan data files to declare new delay_id_* variables - Update Stan code to use semantic delay names throughout - Update extract_delays() to use the new delay_id_* naming system This creates a more flexible and extensible system that clearly identifies which delay is being used in each context. --- R/create.R | 14 +++++++---- R/estimate_infections.R | 6 ++--- R/estimate_secondary.R | 4 +-- R/estimate_truncation.R | 2 +- R/extract.R | 18 ++++++------- R/simulate_infections.R | 6 ++--- R/simulate_secondary.R | 4 +-- R/stanmodels.R | 25 ++++++++----------- inst/stan/data/observation_model.stan | 4 +-- inst/stan/data/secondary.stan | 1 + inst/stan/data/simulation_delays.stan | 2 +- .../data/simulation_observation_model.stan | 2 +- inst/stan/estimate_infections.stan | 12 ++++----- inst/stan/estimate_secondary.stan | 12 ++++----- inst/stan/estimate_truncation.stan | 20 +++++++-------- inst/stan/simulate_infections.stan | 12 ++++----- inst/stan/simulate_secondary.stan | 12 ++++----- 17 files changed, 78 insertions(+), 78 deletions(-) diff --git a/R/create.R b/R/create.R index 4ad59880e..2778b7e48 100644 --- a/R/create.R +++ b/R/create.R @@ -687,6 +687,8 @@ create_stan_args <- function(stan = stan_opts(), ##' @keywords internal create_stan_delays <- function(..., time_points = 1L) { delays <- list(...) + delay_names <- names(delays) + ## discretise delays <- map(delays, discretise, strict = FALSE) delays <- map(delays, collapse) @@ -695,10 +697,12 @@ create_stan_delays <- function(..., time_points = 1L) { max_delay <- unname(as.numeric(flatten(map(bounded_delays, max)))) ## number of different non-empty types type_n <- vapply(delays, ndist, integer(1)) - ## assign ID values to each type - ids <- rep(0L, length(type_n)) - ids[type_n > 0] <- seq_len(sum(type_n > 0)) - names(ids) <- paste(names(type_n), "id", sep = "_") + + ## Create delay_id_* variables pointing to delay_types_groups index + ## Similar to param_id_* in create_stan_params() + delay_ids <- rep(0L, length(type_n)) + delay_ids[type_n > 0] <- seq_len(sum(type_n > 0)) + names(delay_ids) <- paste("delay_id", delay_names, sep = "_") ## create "flat version" of delays, i.e. a list of all the delays (including ## elements of composite delays) @@ -774,7 +778,7 @@ create_stan_delays <- function(..., time_points = 1L) { ret$dist <- array(match(distributions, c("lognormal", "gamma")) - 1L) names(ret) <- paste("delay", names(ret), sep = "_") - ret <- c(ret, ids) + ret <- c(ret, as.list(delay_ids)) return(ret) } diff --git a/R/estimate_infections.R b/R/estimate_infections.R index bb3ddfce0..9326926e9 100644 --- a/R/estimate_infections.R +++ b/R/estimate_infections.R @@ -265,9 +265,9 @@ estimate_infections <- function(data, ) stan_data <- c(stan_data, create_stan_delays( - gt = generation_time, - delay = delays, - trunc = truncation, + generation_time = generation_time, + reporting = delays, + truncation = truncation, time_points = stan_data$t - stan_data$seeding_time - stan_data$horizon )) diff --git a/R/estimate_secondary.R b/R/estimate_secondary.R index 7eae844b3..1d2b55829 100644 --- a/R/estimate_secondary.R +++ b/R/estimate_secondary.R @@ -232,8 +232,8 @@ estimate_secondary <- function(data, stan_data <- c(stan_data, secondary) # delay data stan_data <- c(stan_data, create_stan_delays( - delay = delays, - trunc = truncation, + secondary = delays, + truncation = truncation, time_points = stan_data$t )) diff --git a/R/estimate_truncation.R b/R/estimate_truncation.R index 4623f385d..cf7effe45 100644 --- a/R/estimate_truncation.R +++ b/R/estimate_truncation.R @@ -184,7 +184,7 @@ estimate_truncation <- function(data, ) stan_data <- c(stan_data, create_stan_delays( - trunc = truncation, + truncation = truncation, time_points = stan_data$t )) diff --git a/R/extract.R b/R/extract.R index 0e6348003..cd8dbb4c8 100644 --- a/R/extract.R +++ b/R/extract.R @@ -126,19 +126,19 @@ extract_delays <- function(samples) { # Build reverse lookup: column index -> delay name delay_names <- rep(NA_character_, n_cols) - # Check all *_id variables to build the mapping (e.g., delay_id, trunc_id) - id_vars <- grep("_id$", names(samples), value = TRUE) - if (length(id_vars) > 0 && "delay_params_groups" %in% names(samples)) { - delay_params_groups <- samples[["delay_params_groups"]] + # Check all delay_id_* variables to build the mapping + id_vars <- grep("^delay_id_", names(samples), value = TRUE) + if (length(id_vars) > 0 && "delay_types_groups" %in% names(samples)) { + delay_types_groups <- samples[["delay_types_groups"]] for (id_var in id_vars) { - delay_name <- sub("_id$", "", id_var) - delay_id <- samples[[id_var]][1] # Take first value (same across samples) + delay_name <- sub("^delay_id_", "", id_var) + id <- samples[[id_var]][1] # Take first value (same across samples) # Check if this delay exists (ID > 0) - if (delay_id > 0 && delay_id < length(delay_params_groups)) { - start_idx <- delay_params_groups[delay_id] - end_idx <- delay_params_groups[delay_id + 1] - 1 + if (!is.na(id) && id > 0 && id < length(delay_types_groups)) { + start_idx <- delay_types_groups[id] + end_idx <- delay_types_groups[id + 1] - 1 # Mark columns for this delay for (i in seq_along(start_idx:end_idx)) { diff --git a/R/simulate_infections.R b/R/simulate_infections.R index d692421ce..f75da5fe8 100644 --- a/R/simulate_infections.R +++ b/R/simulate_infections.R @@ -135,9 +135,9 @@ simulate_infections <- function(R, ) stan_data <- c(stan_data, create_stan_delays( - gt = generation_time, - delay = delays, - trunc = truncation + generation_time = generation_time, + reporting = delays, + truncation = truncation )) if (length(stan_data$delay_params_sd) > 0 && diff --git a/R/simulate_secondary.R b/R/simulate_secondary.R index 89497bdd3..3e13e05fa 100644 --- a/R/simulate_secondary.R +++ b/R/simulate_secondary.R @@ -70,8 +70,8 @@ simulate_secondary <- function(primary, stan_data <- c(stan_data, secondary) stan_data <- c(stan_data, create_stan_delays( - delay = delays, - trunc = truncation + secondary = delays, + truncation = truncation )) if (length(stan_data$delay_params_sd) > 0 && diff --git a/R/stanmodels.R b/R/stanmodels.R index d5e59f497..8c6dc0118 100644 --- a/R/stanmodels.R +++ b/R/stanmodels.R @@ -14,22 +14,17 @@ Rcpp::loadModule("stan_fit4simulate_secondary_mod", what = TRUE) # instantiate each stanmodel object stanmodels <- sapply(stanmodels, function(model_name) { # create C++ code for stan model - stan_file <- if (dir.exists("stan")) "stan" else file.path("inst", "stan") + stan_file <- if(dir.exists("stan")) "stan" else file.path("inst", "stan") stan_file <- file.path(stan_file, paste0(model_name, ".stan")) stanfit <- rstan::stanc_builder(stan_file, - allow_undefined = TRUE, - obfuscate_model_name = FALSE - ) - stanfit$model_cpp <- list( - model_cppname = stanfit$model_name, - model_cppcode = stanfit$cppcode - ) + allow_undefined = TRUE, + obfuscate_model_name = FALSE) + stanfit$model_cpp <- list(model_cppname = stanfit$model_name, + model_cppcode = stanfit$cppcode) # create stanmodel object - methods::new( - Class = "stanmodel", - model_name = stanfit$model_name, - model_code = stanfit$model_code, - model_cpp = stanfit$model_cpp, - mk_cppmodule = function(x) get(paste0("rstantools_model_", model_name)) - ) + methods::new(Class = "stanmodel", + model_name = stanfit$model_name, + model_code = stanfit$model_code, + model_cpp = stanfit$model_cpp, + mk_cppmodule = function(x) get(paste0("rstantools_model_", model_name))) }) diff --git a/inst/stan/data/observation_model.stan b/inst/stan/data/observation_model.stan index 715d4db52..b0b840930 100644 --- a/inst/stan/data/observation_model.stan +++ b/inst/stan/data/observation_model.stan @@ -5,5 +5,5 @@ int obs_scale; // logical controlling scaling of observations real obs_weight; // weight given to observation in log density int likelihood; // Should the likelihood be included in the model int return_likelihood; // Should the likehood be returned by the model -int trunc_id; // id of truncation -int delay_id; // id of delay +int delay_id_truncation; // id of truncation delay +int delay_id_reporting; // id of reporting delay diff --git a/inst/stan/data/secondary.stan b/inst/stan/data/secondary.stan index ac32df273..8884d165c 100644 --- a/inst/stan/data/secondary.stan +++ b/inst/stan/data/secondary.stan @@ -4,3 +4,4 @@ int historic; // Should primary historic reports be considered int primary_hist_additive; // Should historic primary reports be additive int current; // Should current primary reports be considered int primary_current_additive; // Should current primary reports be additive +int delay_id_secondary; // id of secondary delay diff --git a/inst/stan/data/simulation_delays.stan b/inst/stan/data/simulation_delays.stan index c6fec4eb5..1d3675107 100644 --- a/inst/stan/data/simulation_delays.stan +++ b/inst/stan/data/simulation_delays.stan @@ -25,4 +25,4 @@ array[delay_n] int delay_types_id; // index of each delay (parametric or non) array[delay_types + 1] int delay_types_groups; -int delay_id; // id of generation time +int delay_id_generation_time; // id of generation time delay diff --git a/inst/stan/data/simulation_observation_model.stan b/inst/stan/data/simulation_observation_model.stan index db40e95d3..69acaefbe 100644 --- a/inst/stan/data/simulation_observation_model.stan +++ b/inst/stan/data/simulation_observation_model.stan @@ -3,4 +3,4 @@ int week_effect; // should a day of the week effect be estimated array[n, week_effect] real day_of_week_simplex; int obs_scale; int model_type; -int trunc_id; // id of truncation +int delay_id_truncation; // id of truncation delay diff --git a/inst/stan/estimate_infections.stan b/inst/stan/estimate_infections.stan index e651e6c3c..cc9b353c7 100644 --- a/inst/stan/estimate_infections.stan +++ b/inst/stan/estimate_infections.stan @@ -136,11 +136,11 @@ transformed parameters { } // convolve from latent infections to mean of observations - if (delay_id) { - vector[delay_type_max[delay_id] + 1] delay_rev_pmf; + if (delay_id_reporting) { + vector[delay_type_max[delay_id_reporting] + 1] delay_rev_pmf; profile("delays") { delay_rev_pmf = get_delay_rev_pmf( - delay_id, delay_type_max[delay_id] + 1, delay_types_p, delay_types_id, + delay_id_reporting, delay_type_max[delay_id_reporting] + 1, delay_types_p, delay_types_id, delay_types_groups, delay_max, delay_np_pmf, delay_np_pmf_groups, delay_params, delay_params_groups, delay_dist, 0, 1, 0 @@ -172,11 +172,11 @@ transformed parameters { } // truncate near time cases to observed reports - if (trunc_id) { - vector[delay_type_max[trunc_id] + 1] trunc_rev_cmf; + if (delay_id_truncation) { + vector[delay_type_max[delay_id_truncation] + 1] trunc_rev_cmf; profile("truncation") { trunc_rev_cmf = get_delay_rev_pmf( - trunc_id, delay_type_max[trunc_id] + 1, delay_types_p, delay_types_id, + delay_id_truncation, delay_type_max[delay_id_truncation] + 1, delay_types_p, delay_types_id, delay_types_groups, delay_max, delay_np_pmf, delay_np_pmf_groups, delay_params, delay_params_groups, delay_dist, 0, 1, 1 diff --git a/inst/stan/estimate_secondary.stan b/inst/stan/estimate_secondary.stan index 3e260e96b..b02f90a03 100644 --- a/inst/stan/estimate_secondary.stan +++ b/inst/stan/estimate_secondary.stan @@ -56,9 +56,9 @@ transformed parameters { scaled = primary; } - if (delay_id) { - vector[delay_type_max[delay_id] + 1] delay_rev_pmf = get_delay_rev_pmf( - delay_id, delay_type_max[delay_id] + 1, delay_types_p, delay_types_id, + if (delay_id_secondary) { + vector[delay_type_max[delay_id_secondary] + 1] delay_rev_pmf = get_delay_rev_pmf( + delay_id_secondary, delay_type_max[delay_id_secondary] + 1, delay_types_p, delay_types_id, delay_types_groups, delay_max, delay_np_pmf, delay_np_pmf_groups, delay_params, delay_params_groups, delay_dist, 0, 1, 0 @@ -80,9 +80,9 @@ transformed parameters { } // truncate near time cases to observed reports - if (trunc_id) { - vector[delay_type_max[trunc_id]] trunc_rev_cmf = get_delay_rev_pmf( - trunc_id, delay_type_max[trunc_id] + 1, delay_types_p, delay_types_id, + if (delay_id_truncation) { + vector[delay_type_max[delay_id_truncation]] trunc_rev_cmf = get_delay_rev_pmf( + delay_id_truncation, delay_type_max[delay_id_truncation] + 1, delay_types_p, delay_types_id, delay_types_groups, delay_max, delay_np_pmf, delay_np_pmf_groups, delay_params, delay_params_groups, delay_dist, 0, 1, 1 diff --git a/inst/stan/estimate_truncation.stan b/inst/stan/estimate_truncation.stan index c62365582..05d2d9be0 100644 --- a/inst/stan/estimate_truncation.stan +++ b/inst/stan/estimate_truncation.stan @@ -14,7 +14,7 @@ data { } transformed data{ - int trunc_id = 1; + int delay_id_truncation = 1; array[obs_sets] int end_t; array[obs_sets] int start_t; @@ -26,7 +26,7 @@ transformed data{ for (i in 1:obs_sets) { end_t[i] = t - obs_dist[i]; - start_t[i] = max(1, end_t[i] - delay_type_max[trunc_id]); + start_t[i] = max(1, end_t[i] - delay_type_max[delay_id_truncation]); } } @@ -38,11 +38,11 @@ parameters { transformed parameters{ real phi = 1 / sqrt(dispersion); - matrix[delay_type_max[trunc_id] + 1, obs_sets - 1] trunc_obs = rep_matrix( - 0, delay_type_max[trunc_id] + 1, obs_sets - 1 + matrix[delay_type_max[delay_id_truncation] + 1, obs_sets - 1] trunc_obs = rep_matrix( + 0, delay_type_max[delay_id_truncation] + 1, obs_sets - 1 ); - vector[delay_type_max[trunc_id] + 1] trunc_rev_cmf = get_delay_rev_pmf( - trunc_id, delay_type_max[trunc_id] + 1, delay_types_p, delay_types_id, + vector[delay_type_max[delay_id_truncation] + 1] trunc_rev_cmf = get_delay_rev_pmf( + delay_id_truncation, delay_type_max[delay_id_truncation] + 1, delay_types_p, delay_types_id, delay_types_groups, delay_max, delay_np_pmf, delay_np_pmf_groups, delay_params, delay_params_groups, delay_dist, 0, 1, 1 @@ -79,10 +79,10 @@ model { } generated quantities { - matrix[delay_type_max[trunc_id] + 1, obs_sets] recon_obs = rep_matrix( - 0, delay_type_max[trunc_id] + 1, obs_sets + matrix[delay_type_max[delay_id_truncation] + 1, obs_sets] recon_obs = rep_matrix( + 0, delay_type_max[delay_id_truncation] + 1, obs_sets ); - matrix[delay_type_max[trunc_id] + 1, obs_sets - 1] gen_obs; + matrix[delay_type_max[delay_id_truncation] + 1, obs_sets - 1] gen_obs; // reconstruct all truncated datasets using posterior of the truncation distribution for (i in 1:obs_sets) { recon_obs[1:(end_t[i] - start_t[i] + 1), i] = truncate_obs( @@ -91,7 +91,7 @@ generated quantities { } // generate observations for comparing for (i in 1:(obs_sets - 1)) { - for (j in 1:(delay_type_max[trunc_id] + 1)) { + for (j in 1:(delay_type_max[delay_id_truncation] + 1)) { if (trunc_obs[j, i] == 0) { gen_obs[j, i] = 0; } else { diff --git a/inst/stan/simulate_infections.stan b/inst/stan/simulate_infections.stan index d4c7531ef..27944c3c4 100644 --- a/inst/stan/simulate_infections.stan +++ b/inst/stan/simulate_infections.stan @@ -72,9 +72,9 @@ generated quantities { pop[i], use_pop, pop_floor, future_time, obs_scale, frac_obs[i], initial_as_scale )); - if (delay_id) { - vector[delay_type_max[delay_id] + 1] delay_rev_pmf = get_delay_rev_pmf( - delay_id, delay_type_max[delay_id] + 1, delay_types_p, delay_types_id, + if (delay_id_generation_time) { + vector[delay_type_max[delay_id_generation_time] + 1] delay_rev_pmf = get_delay_rev_pmf( + delay_id_generation_time, delay_type_max[delay_id_generation_time] + 1, delay_types_p, delay_types_id, delay_types_groups, delay_max, delay_np_pmf, delay_np_pmf_groups, delay_params[i], delay_params_groups, delay_dist, 0, 1, 0 @@ -96,9 +96,9 @@ generated quantities { to_vector(day_of_week_simplex[i]))); } // truncate near time cases to observed reports - if (trunc_id) { - vector[delay_type_max[trunc_id] + 1] trunc_rev_cmf = get_delay_rev_pmf( - trunc_id, delay_type_max[trunc_id] + 1, delay_types_p, delay_types_id, + if (delay_id_truncation) { + vector[delay_type_max[delay_id_truncation] + 1] trunc_rev_cmf = get_delay_rev_pmf( + delay_id_truncation, delay_type_max[delay_id_truncation] + 1, delay_types_p, delay_types_id, delay_types_groups, delay_max, delay_np_pmf, delay_np_pmf_groups, delay_params[i], delay_params_groups, delay_dist, 0, 1, 1 diff --git a/inst/stan/simulate_secondary.stan b/inst/stan/simulate_secondary.stan index b5169e0ed..42f274a7f 100644 --- a/inst/stan/simulate_secondary.stan +++ b/inst/stan/simulate_secondary.stan @@ -53,9 +53,9 @@ generated quantities { scaled = to_vector(primary[i]); } - if (delay_id) { - vector[delay_type_max[delay_id] + 1] delay_rev_pmf = get_delay_rev_pmf( - delay_id, delay_type_max[delay_id] + 1, delay_types_p, delay_types_id, + if (delay_id_secondary) { + vector[delay_type_max[delay_id_secondary] + 1] delay_rev_pmf = get_delay_rev_pmf( + delay_id_secondary, delay_type_max[delay_id_secondary] + 1, delay_types_p, delay_types_id, delay_types_groups, delay_max, delay_np_pmf, delay_np_pmf_groups, delay_params[i], delay_params_groups, delay_dist, 0, 1, 0 @@ -77,9 +77,9 @@ generated quantities { } // truncate near time cases to observed reports - if (trunc_id) { - vector[delay_type_max[trunc_id] + 1] trunc_rev_cmf = get_delay_rev_pmf( - trunc_id, delay_type_max[trunc_id] + 1, delay_types_p, delay_types_id, + if (delay_id_truncation) { + vector[delay_type_max[delay_id_truncation] + 1] trunc_rev_cmf = get_delay_rev_pmf( + delay_id_truncation, delay_type_max[delay_id_truncation] + 1, delay_types_p, delay_types_id, delay_types_groups, delay_max, delay_np_pmf, delay_np_pmf_groups, delay_params[i], delay_params_groups, delay_dist, 0, 1, 1 From 29dedb0bbda2e03fc65b1752cc1e5a84a5e10471 Mon Sep 17 00:00:00 2001 From: Sebastian Funk Date: Wed, 19 Nov 2025 21:10:40 +0000 Subject: [PATCH 32/51] Add tests for delay_id_* naming system --- tests/testthat/test-delays.R | 84 ++++++++++++++++++++++++++++++++++++ 1 file changed, 84 insertions(+) diff --git a/tests/testthat/test-delays.R b/tests/testthat/test-delays.R index 68327aca0..2a5819855 100644 --- a/tests/testthat/test-delays.R +++ b/tests/testthat/test-delays.R @@ -93,3 +93,87 @@ test_that("distributions incompatible with stan models are caught", { Normal(2, 2, max = 10) ), "lognormal") }) + +test_that("create_stan_delays creates delay_id_* variables with correct names", { + # Test with all delay types (infection context) + data <- EpiNow2:::create_stan_delays( + generation_time = gt_opts(Fixed(1)), + reporting = delay_opts(Fixed(2)), + truncation = trunc_opts(Fixed(1)) + ) + + expect_true("delay_id_generation_time" %in% names(data)) + expect_true("delay_id_reporting" %in% names(data)) + expect_true("delay_id_truncation" %in% names(data)) + + # IDs should be sequential for non-empty delays + expect_equal(data$delay_id_generation_time, 1) + expect_equal(data$delay_id_reporting, 2) + expect_equal(data$delay_id_truncation, 3) +}) + +test_that("create_stan_delays creates delay_id_* for secondary context", { + # Test with secondary delay naming + data <- EpiNow2:::create_stan_delays( + secondary = delay_opts(Fixed(2)), + truncation = trunc_opts(Fixed(1)) + ) + + expect_true("delay_id_secondary" %in% names(data)) + expect_true("delay_id_truncation" %in% names(data)) + + expect_equal(data$delay_id_secondary, 1) + expect_equal(data$delay_id_truncation, 2) +}) + +test_that("create_stan_delays sets ID to 0 for missing delays", { + # Test with only one delay type + data <- EpiNow2:::create_stan_delays( + generation_time = gt_opts(Fixed(1)) + ) + + expect_equal(data$delay_id_generation_time, 1) + # No reporting or truncation delays provided + expect_false("delay_id_reporting" %in% names(data)) + expect_false("delay_id_truncation" %in% names(data)) +}) + +test_that("extract_delays works with delay_id_* naming", { + # Create mock samples with delay_id_* variables + samples <- list( + delay_params = matrix(c(1.5, 2.0, 1.8, 2.2), nrow = 2, ncol = 2), + delay_id_generation_time = c(1, 1), # ID = 1 + delay_id_reporting = c(0, 0), # ID = 0 (not used) + delay_types_groups = c(1, 3) # Group 1: cols 1-2 + ) + + result <- EpiNow2:::extract_delays(samples) + + expect_true(!is.null(result)) + expect_true("parameter" %in% names(result)) + expect_true("sample" %in% names(result)) + expect_true("value" %in% names(result)) + + # Check that generation_time parameters are named correctly + expect_true(any(grepl("generation_time\\[1\\]", result$parameter))) + expect_true(any(grepl("generation_time\\[2\\]", result$parameter))) +}) + +test_that("extract_delays returns NULL when delay_params don't exist", { + samples <- list(some_other_param = 1:10) + result <- EpiNow2:::extract_delays(samples) + expect_null(result) +}) + +test_that("extract_delays handles delays with no ID lookup gracefully", { + # Samples without delay_id_* variables + samples <- list( + delay_params = matrix(c(1.5, 2.0), nrow = 2, ncol = 1) + ) + + result <- EpiNow2:::extract_delays(samples) + + expect_true(!is.null(result)) + # Should fall back to indexed naming + expect_true(any(grepl("delay_params\\[", result$parameter))) +}) From be7e7f056150aa3c74835082f6569faec6d84e35 Mon Sep 17 00:00:00 2001 From: Sebastian Funk Date: Wed, 19 Nov 2025 21:18:32 +0000 Subject: [PATCH 33/51] Fix delay_id naming in checks and tests - Update check_truncation_length() to use delay_id_truncation - Fix gt_id to delay_id_generation_time in estimation models - Remove duplicate delay_id_generation_time from simulation_delays.stan - Update all test cases to use new semantic delay names --- R/checks.R | 6 +++--- inst/stan/data/rt.stan | 2 +- inst/stan/data/simulation_delays.stan | 2 -- inst/stan/data/simulation_rt.stan | 2 +- inst/stan/estimate_infections.stan | 10 +++++----- inst/stan/simulate_infections.stan | 4 ++-- tests/testthat/test-checks.R | 14 +++++++------- 7 files changed, 19 insertions(+), 21 deletions(-) diff --git a/R/checks.R b/R/checks.R index 995d24411..02c0bb145 100644 --- a/R/checks.R +++ b/R/checks.R @@ -199,7 +199,7 @@ check_sparse_pmf_tail <- function(pmf, span = 5, tol = 1e-6) { #' @keywords internal check_truncation_length <- function(stan_args, time_points) { # Check if truncation exists - if (is.null(stan_args$data$trunc_id) || stan_args$data$trunc_id == 0) { + if (is.null(stan_args$data$delay_id_truncation) || stan_args$data$delay_id_truncation == 0) { return(invisible()) } @@ -210,9 +210,9 @@ check_truncation_length <- function(stan_args, time_points) { # Map truncation to its position in the flat delays array # delay_types_groups gives start and end indices for each delay type - trunc_start <- stan_args$data$delay_types_groups[stan_args$data$trunc_id] + trunc_start <- stan_args$data$delay_types_groups[stan_args$data$delay_id_truncation] trunc_end <- stan_args$data$delay_types_groups[ - stan_args$data$trunc_id + 1 + stan_args$data$delay_id_truncation + 1 ] - 1 # Get which truncation delays are non-parametric diff --git a/inst/stan/data/rt.stan b/inst/stan/data/rt.stan index 3fb1fbad5..79d24a6f3 100644 --- a/inst/stan/data/rt.stan +++ b/inst/stan/data/rt.stan @@ -5,5 +5,5 @@ int future_fixed; // is underlying future Rt assumed to be fixed int fixed_from; // Reference date for when Rt estimation should be fixed int use_pop; // use population size (0 = no; 1 = forecasts; 2 = all) real pop_floor; // Minimum susceptible population (numerical stability floor) -int gt_id; // id of generation time +int delay_id_generation_time; // id of generation time int growth_method; // method to compute growth rate (0 = infections, 1 = infectiousness) diff --git a/inst/stan/data/simulation_delays.stan b/inst/stan/data/simulation_delays.stan index 1d3675107..80c5a0bff 100644 --- a/inst/stan/data/simulation_delays.stan +++ b/inst/stan/data/simulation_delays.stan @@ -24,5 +24,3 @@ array[delay_n] int delay_types_p; array[delay_n] int delay_types_id; // index of each delay (parametric or non) array[delay_types + 1] int delay_types_groups; - -int delay_id_generation_time; // id of generation time delay diff --git a/inst/stan/data/simulation_rt.stan b/inst/stan/data/simulation_rt.stan index bab9fc6ef..e7d3f757e 100644 --- a/inst/stan/data/simulation_rt.stan +++ b/inst/stan/data/simulation_rt.stan @@ -5,7 +5,7 @@ matrix[n, t - seeding_time] R; // reproduction number int use_pop; // use population size (0 = no; 1 = forecasts; 2 = all) real pop_floor; // Minimum susceptible population (numerical stability floor) -int gt_id; // id of generation time +int delay_id_generation_time; // id of generation time int growth_method; // method to compute growth rate (0 = infections, 1 = infectiousness) diff --git a/inst/stan/estimate_infections.stan b/inst/stan/estimate_infections.stan index cc9b353c7..247795ca1 100644 --- a/inst/stan/estimate_infections.stan +++ b/inst/stan/estimate_infections.stan @@ -75,7 +75,7 @@ transformed parameters { vector[t] infections; // latent infections vector[ot_h] reports; // estimated reported cases vector[ot] obs_reports; // observed estimated reported cases - vector[estimate_r * (delay_type_max[gt_id] + 1)] gt_rev_pmf; + vector[estimate_r * (delay_type_max[delay_id_generation_time] + 1)] gt_rev_pmf; // GP in noise - spectral densities profile("update gp") { @@ -98,7 +98,7 @@ transformed parameters { if (estimate_r) { profile("gt") { gt_rev_pmf = get_delay_rev_pmf( - gt_id, delay_type_max[gt_id] + 1, delay_types_p, delay_types_id, + delay_id_generation_time, delay_type_max[delay_id_generation_time] + 1, delay_types_p, delay_types_id, delay_types_groups, delay_max, delay_np_pmf, delay_np_pmf_groups, delay_params, delay_params_groups, delay_dist, 1, 1, 0 @@ -264,15 +264,15 @@ generated quantities { } { - vector[delay_type_max[gt_id] + 1] gt_rev_pmf_for_growth; + vector[delay_type_max[delay_id_generation_time] + 1] gt_rev_pmf_for_growth; if (estimate_r == 0) { // sample generation time vector[delay_params_length] delay_params_sample = to_vector(normal_lb_rng( delay_params_mean, delay_params_sd, delay_params_lower )); - vector[delay_type_max[gt_id] + 1] sampled_gt_rev_pmf = get_delay_rev_pmf( - gt_id, delay_type_max[gt_id] + 1, delay_types_p, delay_types_id, + vector[delay_type_max[delay_id_generation_time] + 1] sampled_gt_rev_pmf = get_delay_rev_pmf( + delay_id_generation_time, delay_type_max[delay_id_generation_time] + 1, delay_types_p, delay_types_id, delay_types_groups, delay_max, delay_np_pmf, delay_np_pmf_groups, delay_params_sample, delay_params_groups, delay_dist, 1, 1, 0 diff --git a/inst/stan/simulate_infections.stan b/inst/stan/simulate_infections.stan index 27944c3c4..8a2dacef3 100644 --- a/inst/stan/simulate_infections.stan +++ b/inst/stan/simulate_infections.stan @@ -59,9 +59,9 @@ generated quantities { for (i in 1:n) { // generate infections from Rt trace - vector[delay_type_max[gt_id] + 1] gt_rev_pmf; + vector[delay_type_max[delay_id_generation_time] + 1] gt_rev_pmf; gt_rev_pmf = get_delay_rev_pmf( - gt_id, delay_type_max[gt_id] + 1, delay_types_p, delay_types_id, + delay_id_generation_time, delay_type_max[delay_id_generation_time] + 1, delay_types_p, delay_types_id, delay_types_groups, delay_max, delay_np_pmf, delay_np_pmf_groups, delay_params[i], delay_params_groups, delay_dist, 1, 1, 0 diff --git a/tests/testthat/test-checks.R b/tests/testthat/test-checks.R index 15c897d1d..414650198 100644 --- a/tests/testthat/test-checks.R +++ b/tests/testthat/test-checks.R @@ -207,7 +207,7 @@ test_that("check_truncation_length warns when truncation PMF is longer than time # Test with long truncation PMF (length 15) and time_points = 10 stan_args <- list( data = list( - trunc_id = 1, + delay_id_truncation = 1, delay_n_np = 1, delay_types_groups = array(c(1, 2)), # Truncation maps to position 1 delay_types_p = array(c(0)), # Truncation is nonparametric @@ -228,7 +228,7 @@ test_that("check_truncation_length works with truncation from create_stan_delays short_trunc <- trunc_opts(dist = LogNormal(mean = 1, sd = 0.5, max = 5)) stan_args_short <- list( data = create_stan_delays( - trunc = short_trunc, + truncation = short_trunc, time_points = 10 ) ) @@ -240,7 +240,7 @@ test_that("check_truncation_length works with truncation from create_stan_delays long_trunc <- trunc_opts(dist = LogNormal(mean = 2, sd = 0.5, max = 20)) stan_args_long <- list( data = create_stan_delays( - trunc = long_trunc, + truncation = long_trunc, time_points = 10 ) ) @@ -261,9 +261,9 @@ test_that("check_truncation_length works when truncation is combined with other stan_args <- list( data = create_stan_delays( - gt = gt, - delay = delays, - trunc = long_trunc, + generation_time = gt, + reporting = delays, + truncation = long_trunc, time_points = 10 ) ) @@ -283,7 +283,7 @@ test_that("check_truncation_length correctly indexes when parametric delays prec # to index into np_pmf_lengths, which only contains nonparametric delays stan_args <- list( data = list( - trunc_id = 3, + delay_id_truncation = 3, delay_n_np = 1, delay_types_groups = array(c(1, 2, 3, 4)), # Three delays total delay_types_p = array(c(1, 1, 0)), # First two parametric, third nonparametric From ca3d211f5123d4f81acd01e799240038060870ba Mon Sep 17 00:00:00 2001 From: Sebastian Funk Date: Wed, 19 Nov 2025 22:22:53 +0000 Subject: [PATCH 34/51] Separate delay_id_reporting from observation_model.stan - Create data/infections.stan for delay_id_reporting (infections context) - Create data/simulation_infections.stan for simulate_infections - Remove delay_id_reporting from observation_model.stan (shared file) - This fixes estimate_secondary which doesn't use reporting delays --- inst/stan/data/infections.stan | 1 + inst/stan/data/observation_model.stan | 1 - inst/stan/data/simulation_infections.stan | 1 + inst/stan/estimate_infections.stan | 1 + inst/stan/simulate_infections.stan | 1 + .../_snaps/simulate-infections.new.md | 68 +++++++++++++++++++ 6 files changed, 72 insertions(+), 1 deletion(-) create mode 100644 inst/stan/data/infections.stan create mode 100644 inst/stan/data/simulation_infections.stan create mode 100644 tests/testthat/_snaps/simulate-infections.new.md diff --git a/inst/stan/data/infections.stan b/inst/stan/data/infections.stan new file mode 100644 index 000000000..b88626852 --- /dev/null +++ b/inst/stan/data/infections.stan @@ -0,0 +1 @@ +int delay_id_reporting; // id of reporting delay diff --git a/inst/stan/data/observation_model.stan b/inst/stan/data/observation_model.stan index b0b840930..41d835c69 100644 --- a/inst/stan/data/observation_model.stan +++ b/inst/stan/data/observation_model.stan @@ -6,4 +6,3 @@ real obs_weight; // weight given to observation in log density int likelihood; // Should the likelihood be included in the model int return_likelihood; // Should the likehood be returned by the model int delay_id_truncation; // id of truncation delay -int delay_id_reporting; // id of reporting delay diff --git a/inst/stan/data/simulation_infections.stan b/inst/stan/data/simulation_infections.stan new file mode 100644 index 000000000..b88626852 --- /dev/null +++ b/inst/stan/data/simulation_infections.stan @@ -0,0 +1 @@ +int delay_id_reporting; // id of reporting delay diff --git a/inst/stan/estimate_infections.stan b/inst/stan/estimate_infections.stan index 247795ca1..f9dca1b07 100644 --- a/inst/stan/estimate_infections.stan +++ b/inst/stan/estimate_infections.stan @@ -17,6 +17,7 @@ data { #include data/gaussian_process.stan #include data/rt.stan #include data/backcalc.stan +#include data/infections.stan #include data/observation_model.stan #include data/params.stan #include data/estimate_infections_params.stan diff --git a/inst/stan/simulate_infections.stan b/inst/stan/simulate_infections.stan index 8a2dacef3..65c271a2b 100644 --- a/inst/stan/simulate_infections.stan +++ b/inst/stan/simulate_infections.stan @@ -21,6 +21,7 @@ data { #include data/simulation_rt.stan // delay from infection to report #include data/simulation_delays.stan +#include data/simulation_infections.stan // observation model #include data/simulation_observation_model.stan // parameters diff --git a/tests/testthat/_snaps/simulate-infections.new.md b/tests/testthat/_snaps/simulate-infections.new.md new file mode 100644 index 000000000..45ebbdf4a --- /dev/null +++ b/tests/testthat/_snaps/simulate-infections.new.md @@ -0,0 +1,68 @@ +# simulate_infections works as expected with standard parameters + + variable date value + + 1: infections 2023-01-01 120.00000 + 2: infections 2023-01-02 144.00000 + 3: infections 2023-01-03 172.80000 + 4: infections 2023-01-04 207.36000 + 5: infections 2023-01-05 248.83200 + 6: infections 2023-01-06 298.59840 + 7: infections 2023-01-07 358.31808 + 8: infections 2023-01-08 286.65446 + 9: infections 2023-01-09 229.32357 + 10: infections 2023-01-10 183.45886 + 11: infections 2023-01-11 146.76709 + 12: infections 2023-01-12 117.41367 + 13: infections 2023-01-13 93.93093 + 14: infections 2023-01-14 75.14475 + 15: reported_cases 2023-01-01 105.00000 + 16: reported_cases 2023-01-02 112.00000 + 17: reported_cases 2023-01-03 166.00000 + 18: reported_cases 2023-01-04 188.00000 + 19: reported_cases 2023-01-05 211.00000 + 20: reported_cases 2023-01-06 277.00000 + 21: reported_cases 2023-01-07 304.00000 + 22: reported_cases 2023-01-08 349.00000 + 23: reported_cases 2023-01-09 261.00000 + 24: reported_cases 2023-01-10 213.00000 + 25: reported_cases 2023-01-11 156.00000 + 26: reported_cases 2023-01-12 144.00000 + 27: reported_cases 2023-01-13 109.00000 + 28: reported_cases 2023-01-14 80.00000 + variable date value + +# simulate_infections works as expected with additional parameters + + variable date value + + 1: infections 2023-01-01 160.8333 + 2: infections 2023-01-02 169.6942 + 3: infections 2023-01-03 178.7393 + 4: infections 2023-01-04 188.1775 + 5: infections 2023-01-05 198.0770 + 6: infections 2023-01-06 208.1120 + 7: infections 2023-01-07 218.6883 + 8: infections 2023-01-08 153.2081 + 9: infections 2023-01-09 148.6709 + 10: infections 2023-01-10 142.1679 + 11: infections 2023-01-11 135.3769 + 12: infections 2023-01-12 128.6677 + 13: infections 2023-01-13 122.1667 + 14: infections 2023-01-14 115.9185 + 15: reported_cases 2023-01-01 108.0000 + 16: reported_cases 2023-01-02 99.0000 + 17: reported_cases 2023-01-03 111.0000 + 18: reported_cases 2023-01-04 93.0000 + 19: reported_cases 2023-01-05 412.0000 + 20: reported_cases 2023-01-06 180.0000 + 21: reported_cases 2023-01-07 309.0000 + 22: reported_cases 2023-01-08 121.0000 + 23: reported_cases 2023-01-09 153.0000 + 24: reported_cases 2023-01-10 12.0000 + 25: reported_cases 2023-01-11 193.0000 + 26: reported_cases 2023-01-12 132.0000 + 27: reported_cases 2023-01-13 200.0000 + 28: reported_cases 2023-01-14 70.0000 + variable date value + From a91a2bbc1025e34b074f82801f7c458f6d9179ca Mon Sep 17 00:00:00 2001 From: Sebastian Funk Date: Thu, 20 Nov 2025 07:35:18 +0000 Subject: [PATCH 35/51] Use shared delay_id_reporting for both infections and secondary models - Rename delay_id_secondary back to delay_id_reporting - Move delay_id_reporting to shared observation_model.stan - Remove separate infections.stan and simulation_infections.stan files - Update R code and tests to use reporting parameter name - Both infections and secondary models now use same observation model --- R/estimate_secondary.R | 2 +- R/simulate_secondary.R | 2 +- inst/stan/data/infections.stan | 1 - inst/stan/data/observation_model.stan | 1 + inst/stan/data/secondary.stan | 1 - inst/stan/data/simulation_infections.stan | 1 - inst/stan/data/simulation_observation_model.stan | 1 + inst/stan/estimate_infections.stan | 1 - inst/stan/estimate_secondary.stan | 6 +++--- inst/stan/simulate_infections.stan | 1 - inst/stan/simulate_secondary.stan | 6 +++--- tests/testthat/test-delays.R | 10 +++++----- 12 files changed, 15 insertions(+), 18 deletions(-) delete mode 100644 inst/stan/data/infections.stan delete mode 100644 inst/stan/data/simulation_infections.stan diff --git a/R/estimate_secondary.R b/R/estimate_secondary.R index 1d2b55829..44a857aac 100644 --- a/R/estimate_secondary.R +++ b/R/estimate_secondary.R @@ -232,7 +232,7 @@ estimate_secondary <- function(data, stan_data <- c(stan_data, secondary) # delay data stan_data <- c(stan_data, create_stan_delays( - secondary = delays, + reporting = delays, truncation = truncation, time_points = stan_data$t )) diff --git a/R/simulate_secondary.R b/R/simulate_secondary.R index 3e13e05fa..67a30ff80 100644 --- a/R/simulate_secondary.R +++ b/R/simulate_secondary.R @@ -70,7 +70,7 @@ simulate_secondary <- function(primary, stan_data <- c(stan_data, secondary) stan_data <- c(stan_data, create_stan_delays( - secondary = delays, + reporting = delays, truncation = truncation )) diff --git a/inst/stan/data/infections.stan b/inst/stan/data/infections.stan deleted file mode 100644 index b88626852..000000000 --- a/inst/stan/data/infections.stan +++ /dev/null @@ -1 +0,0 @@ -int delay_id_reporting; // id of reporting delay diff --git a/inst/stan/data/observation_model.stan b/inst/stan/data/observation_model.stan index 41d835c69..b0b840930 100644 --- a/inst/stan/data/observation_model.stan +++ b/inst/stan/data/observation_model.stan @@ -6,3 +6,4 @@ real obs_weight; // weight given to observation in log density int likelihood; // Should the likelihood be included in the model int return_likelihood; // Should the likehood be returned by the model int delay_id_truncation; // id of truncation delay +int delay_id_reporting; // id of reporting delay diff --git a/inst/stan/data/secondary.stan b/inst/stan/data/secondary.stan index 8884d165c..ac32df273 100644 --- a/inst/stan/data/secondary.stan +++ b/inst/stan/data/secondary.stan @@ -4,4 +4,3 @@ int historic; // Should primary historic reports be considered int primary_hist_additive; // Should historic primary reports be additive int current; // Should current primary reports be considered int primary_current_additive; // Should current primary reports be additive -int delay_id_secondary; // id of secondary delay diff --git a/inst/stan/data/simulation_infections.stan b/inst/stan/data/simulation_infections.stan deleted file mode 100644 index b88626852..000000000 --- a/inst/stan/data/simulation_infections.stan +++ /dev/null @@ -1 +0,0 @@ -int delay_id_reporting; // id of reporting delay diff --git a/inst/stan/data/simulation_observation_model.stan b/inst/stan/data/simulation_observation_model.stan index 69acaefbe..82df94c44 100644 --- a/inst/stan/data/simulation_observation_model.stan +++ b/inst/stan/data/simulation_observation_model.stan @@ -4,3 +4,4 @@ array[n, week_effect] real day_of_week_simplex; int obs_scale; int model_type; int delay_id_truncation; // id of truncation delay +int delay_id_reporting; // id of reporting delay diff --git a/inst/stan/estimate_infections.stan b/inst/stan/estimate_infections.stan index f9dca1b07..247795ca1 100644 --- a/inst/stan/estimate_infections.stan +++ b/inst/stan/estimate_infections.stan @@ -17,7 +17,6 @@ data { #include data/gaussian_process.stan #include data/rt.stan #include data/backcalc.stan -#include data/infections.stan #include data/observation_model.stan #include data/params.stan #include data/estimate_infections_params.stan diff --git a/inst/stan/estimate_secondary.stan b/inst/stan/estimate_secondary.stan index b02f90a03..2ae0479ba 100644 --- a/inst/stan/estimate_secondary.stan +++ b/inst/stan/estimate_secondary.stan @@ -56,9 +56,9 @@ transformed parameters { scaled = primary; } - if (delay_id_secondary) { - vector[delay_type_max[delay_id_secondary] + 1] delay_rev_pmf = get_delay_rev_pmf( - delay_id_secondary, delay_type_max[delay_id_secondary] + 1, delay_types_p, delay_types_id, + if (delay_id_reporting) { + vector[delay_type_max[delay_id_reporting] + 1] delay_rev_pmf = get_delay_rev_pmf( + delay_id_reporting, delay_type_max[delay_id_reporting] + 1, delay_types_p, delay_types_id, delay_types_groups, delay_max, delay_np_pmf, delay_np_pmf_groups, delay_params, delay_params_groups, delay_dist, 0, 1, 0 diff --git a/inst/stan/simulate_infections.stan b/inst/stan/simulate_infections.stan index 65c271a2b..8a2dacef3 100644 --- a/inst/stan/simulate_infections.stan +++ b/inst/stan/simulate_infections.stan @@ -21,7 +21,6 @@ data { #include data/simulation_rt.stan // delay from infection to report #include data/simulation_delays.stan -#include data/simulation_infections.stan // observation model #include data/simulation_observation_model.stan // parameters diff --git a/inst/stan/simulate_secondary.stan b/inst/stan/simulate_secondary.stan index 42f274a7f..a37c6e956 100644 --- a/inst/stan/simulate_secondary.stan +++ b/inst/stan/simulate_secondary.stan @@ -53,9 +53,9 @@ generated quantities { scaled = to_vector(primary[i]); } - if (delay_id_secondary) { - vector[delay_type_max[delay_id_secondary] + 1] delay_rev_pmf = get_delay_rev_pmf( - delay_id_secondary, delay_type_max[delay_id_secondary] + 1, delay_types_p, delay_types_id, + if (delay_id_reporting) { + vector[delay_type_max[delay_id_reporting] + 1] delay_rev_pmf = get_delay_rev_pmf( + delay_id_reporting, delay_type_max[delay_id_reporting] + 1, delay_types_p, delay_types_id, delay_types_groups, delay_max, delay_np_pmf, delay_np_pmf_groups, delay_params[i], delay_params_groups, delay_dist, 0, 1, 0 diff --git a/tests/testthat/test-delays.R b/tests/testthat/test-delays.R index 2a5819855..6be90663f 100644 --- a/tests/testthat/test-delays.R +++ b/tests/testthat/test-delays.R @@ -112,17 +112,17 @@ test_that("create_stan_delays creates delay_id_* variables with correct names", expect_equal(data$delay_id_truncation, 3) }) -test_that("create_stan_delays creates delay_id_* for secondary context", { - # Test with secondary delay naming +test_that("create_stan_delays creates delay_id_* for secondary models", { + # Test with reporting delay for secondary models data <- EpiNow2:::create_stan_delays( - secondary = delay_opts(Fixed(2)), + reporting = delay_opts(Fixed(2)), truncation = trunc_opts(Fixed(1)) ) - expect_true("delay_id_secondary" %in% names(data)) + expect_true("delay_id_reporting" %in% names(data)) expect_true("delay_id_truncation" %in% names(data)) - expect_equal(data$delay_id_secondary, 1) + expect_equal(data$delay_id_reporting, 1) expect_equal(data$delay_id_truncation, 2) }) From 66e6b46fb997f2d779c9ecf867c1c9e23fd3911b Mon Sep 17 00:00:00 2001 From: Sebastian Funk Date: Thu, 20 Nov 2025 07:37:36 +0000 Subject: [PATCH 36/51] Improve data.table copy safety and add missing methods - Remove unnecessary copy() in plot.estimate_secondary() - Add copy() in get_samples.forecast_* methods to prevent mutation - Add get_predictions.forecast_infections() method - These changes ensure objects are properly protected from mutation --- R/estimate_secondary.R | 2 +- R/get.R | 10 ++++++++-- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/R/estimate_secondary.R b/R/estimate_secondary.R index 44a857aac..ad6cd4956 100644 --- a/R/estimate_secondary.R +++ b/R/estimate_secondary.R @@ -374,7 +374,7 @@ plot.estimate_secondary <- function(x, primary = FALSE, from = NULL, to = NULL, new_obs = NULL, ...) { - predictions <- data.table::copy(get_predictions(x)) + predictions <- get_predictions(x) if (!is.null(new_obs)) { new_obs <- data.table::as.data.table(new_obs) diff --git a/R/get.R b/R/get.R index 96e9709c2..3547eff56 100644 --- a/R/get.R +++ b/R/get.R @@ -263,7 +263,7 @@ get_samples.estimate_infections <- function(object, ...) { #' @rdname get_samples #' @export get_samples.forecast_infections <- function(object, ...) { - object$samples + data.table::copy(object$samples) } #' @rdname get_samples @@ -318,7 +318,7 @@ get_samples.estimate_secondary <- function(object, ...) { #' @rdname get_samples #' @export get_samples.forecast_secondary <- function(object, ...) { - object$samples + data.table::copy(object$samples) } #' Get predictions from a fitted model @@ -404,6 +404,12 @@ get_predictions.estimate_secondary <- function(object, return(predictions) } +#' @rdname get_predictions +#' @export +get_predictions.forecast_infections <- function(object, ...) { + object$predictions +} + #' @rdname get_predictions #' @export get_predictions.forecast_secondary <- function(object, ...) { From 3f35449730b53bf9423384e7ec2e8842c8eaf3ec Mon Sep 17 00:00:00 2001 From: Sebastian Funk Date: Thu, 20 Nov 2025 07:46:33 +0000 Subject: [PATCH 37/51] Add backward compatibility $ operator for estimate_secondary - Implement $.estimate_secondary() to handle old structure access - Maps 'predictions' to get_predictions() - Maps 'posterior' to get_samples() - Maps 'data' to object$observations - All deprecated accessors issue lifecycle warnings - Maintains continuity for existing user code --- R/estimate_secondary.R | 45 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 45 insertions(+) diff --git a/R/estimate_secondary.R b/R/estimate_secondary.R index ad6cd4956..e3caec6e8 100644 --- a/R/estimate_secondary.R +++ b/R/estimate_secondary.R @@ -748,3 +748,48 @@ forecast_secondary <- function(estimate, class(out) <- c("forecast_secondary", class(out)) return(out) } + +#' Extract elements from estimate_secondary objects with deprecated warnings +#' +#' @description `r lifecycle::badge("deprecated")` +#' Provides backward compatibility for the old return structure. The previous +#' structure with \code{predictions}, \code{posterior}, and \code{data} +#' elements is deprecated. Use the accessor methods instead: +#' \itemize{ +#' \item \code{predictions} - use \code{get_predictions(object)} +#' \item \code{posterior} - use \code{get_samples(object)} +#' \item \code{data} - use \code{object$observations} +#' } +#' +#' @param x An \code{estimate_secondary} object +#' @param name The name of the element to extract +#' @return The requested element with a deprecation warning +#' @export +#' @method $ estimate_secondary +`$.estimate_secondary` <- function(x, name) { + if (name == "predictions") { + lifecycle::deprecate_warn( + "1.8.0", + "estimate_secondary()$predictions", + "get_predictions()" + ) + return(get_predictions(x)) + } else if (name == "posterior") { + lifecycle::deprecate_warn( + "1.8.0", + "estimate_secondary()$posterior", + "get_samples()" + ) + return(get_samples(x)) + } else if (name == "data") { + lifecycle::deprecate_warn( + "1.8.0", + "estimate_secondary()$data", + "estimate_secondary()$observations" + ) + return(x[["observations"]]) + } else { + # For other elements, use normal list extraction + return(NextMethod("$")) + } +} From cd371b661aad21d528d75a1605d66f6bb50b275a Mon Sep 17 00:00:00 2001 From: Sebastian Funk Date: Thu, 20 Nov 2025 07:51:18 +0000 Subject: [PATCH 38/51] Restore delay_id_reporting to simulation_delays.stan - Moves delay_id_reporting back to preserve variable ordering - RNG state differs due to renamed variables (delay_id->delay_id_reporting) - Deterministic outputs (infections) unchanged - Stochastic outputs (reported_cases) differ but are valid - This is expected when refactoring Stan data blocks --- inst/stan/data/simulation_delays.stan | 2 ++ inst/stan/data/simulation_observation_model.stan | 1 - 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/inst/stan/data/simulation_delays.stan b/inst/stan/data/simulation_delays.stan index 80c5a0bff..d90f7f336 100644 --- a/inst/stan/data/simulation_delays.stan +++ b/inst/stan/data/simulation_delays.stan @@ -24,3 +24,5 @@ array[delay_n] int delay_types_p; array[delay_n] int delay_types_id; // index of each delay (parametric or non) array[delay_types + 1] int delay_types_groups; + +int delay_id_reporting; // id of reporting delay diff --git a/inst/stan/data/simulation_observation_model.stan b/inst/stan/data/simulation_observation_model.stan index 82df94c44..69acaefbe 100644 --- a/inst/stan/data/simulation_observation_model.stan +++ b/inst/stan/data/simulation_observation_model.stan @@ -4,4 +4,3 @@ array[n, week_effect] real day_of_week_simplex; int obs_scale; int model_type; int delay_id_truncation; // id of truncation delay -int delay_id_reporting; // id of reporting delay From 2d22d375b94e07fb5689b774357133b6b3d3e7cf Mon Sep 17 00:00:00 2001 From: Sebastian Funk Date: Thu, 20 Nov 2025 07:54:17 +0000 Subject: [PATCH 39/51] Fix delay ID variable names in stan-to-R.R MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Update variable references in inst/dev/stan-to-R.R to match the new delay ID naming convention: - gt_id → delay_id_generation_time - delay_id → delay_id_reporting - trunc_id → delay_id_truncation 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- inst/dev/stan-to-R.R | 26 +++---- .../_snaps/simulate-infections.new.md | 68 ------------------- 2 files changed, 13 insertions(+), 81 deletions(-) delete mode 100644 tests/testthat/_snaps/simulate-infections.new.md diff --git a/inst/dev/stan-to-R.R b/inst/dev/stan-to-R.R index cda56e932..ba7d2850b 100644 --- a/inst/dev/stan-to-R.R +++ b/inst/dev/stan-to-R.R @@ -118,10 +118,10 @@ simulate <- function(data, } if (estimate_r) { gt_rev_pmf <- get_delay_rev_pmf( - gt_id, delay_type_max[gt_id] + 1, delay_types_p, delay_types_id, - delay_types_groups, delay_max, delay_np_pmf, - delay_np_pmf_groups, delay_params, delay_params_groups, delay_dist, - 1, 1, 0 + delay_id_generation_time, delay_type_max[delay_id_generation_time] + 1, + delay_types_p, delay_types_id, delay_types_groups, delay_max, + delay_np_pmf, delay_np_pmf_groups, delay_params, delay_params_groups, + delay_dist, 1, 1, 0 ) R0 <- get_param( param_id_R0, params_fixed_lookup, params_variable_lookup, params_value, params @@ -147,10 +147,10 @@ simulate <- function(data, ) } delay_rev_pmf <- get_delay_rev_pmf( - delay_id, delay_type_max[delay_id] + 1, delay_types_p, delay_types_id, - delay_types_groups, delay_max, delay_np_pmf, - delay_np_pmf_groups, delay_params, delay_params_groups, delay_dist, - 0, 1, 0 + delay_id_reporting, delay_type_max[delay_id_reporting] + 1, delay_types_p, + delay_types_id, delay_types_groups, delay_max, delay_np_pmf, + delay_np_pmf_groups, delay_params, delay_params_groups, delay_dist, 0, 1, + 0 ) reports <- convolve_to_report(infections, delay_rev_pmf, seeding_time) if (week_effect > 1) { @@ -163,12 +163,12 @@ simulate <- function(data, ) reports <- scale_obs(reports, frac_obs) } - if (trunc_id) { + if (delay_id_truncation) { trunc_rev_cmf <- get_delay_rev_pmf( - trunc_id, delay_type_max[trunc_id] + 1, delay_types_p, delay_types_id, - delay_types_groups, delay_max, delay_np_pmf, - delay_np_pmf_groups, delay_params, delay_params_groups, delay_dist, - 0, 1, 1 + delay_id_truncation, delay_type_max[delay_id_truncation] + 1, + delay_types_p, delay_types_id, delay_types_groups, delay_max, + delay_np_pmf, delay_np_pmf_groups, delay_params, delay_params_groups, + delay_dist, 0, 1, 1 ) obs_reports <- truncate_obs(reports[1:ot], trunc_rev_cmf, 0) } else { diff --git a/tests/testthat/_snaps/simulate-infections.new.md b/tests/testthat/_snaps/simulate-infections.new.md deleted file mode 100644 index 45ebbdf4a..000000000 --- a/tests/testthat/_snaps/simulate-infections.new.md +++ /dev/null @@ -1,68 +0,0 @@ -# simulate_infections works as expected with standard parameters - - variable date value - - 1: infections 2023-01-01 120.00000 - 2: infections 2023-01-02 144.00000 - 3: infections 2023-01-03 172.80000 - 4: infections 2023-01-04 207.36000 - 5: infections 2023-01-05 248.83200 - 6: infections 2023-01-06 298.59840 - 7: infections 2023-01-07 358.31808 - 8: infections 2023-01-08 286.65446 - 9: infections 2023-01-09 229.32357 - 10: infections 2023-01-10 183.45886 - 11: infections 2023-01-11 146.76709 - 12: infections 2023-01-12 117.41367 - 13: infections 2023-01-13 93.93093 - 14: infections 2023-01-14 75.14475 - 15: reported_cases 2023-01-01 105.00000 - 16: reported_cases 2023-01-02 112.00000 - 17: reported_cases 2023-01-03 166.00000 - 18: reported_cases 2023-01-04 188.00000 - 19: reported_cases 2023-01-05 211.00000 - 20: reported_cases 2023-01-06 277.00000 - 21: reported_cases 2023-01-07 304.00000 - 22: reported_cases 2023-01-08 349.00000 - 23: reported_cases 2023-01-09 261.00000 - 24: reported_cases 2023-01-10 213.00000 - 25: reported_cases 2023-01-11 156.00000 - 26: reported_cases 2023-01-12 144.00000 - 27: reported_cases 2023-01-13 109.00000 - 28: reported_cases 2023-01-14 80.00000 - variable date value - -# simulate_infections works as expected with additional parameters - - variable date value - - 1: infections 2023-01-01 160.8333 - 2: infections 2023-01-02 169.6942 - 3: infections 2023-01-03 178.7393 - 4: infections 2023-01-04 188.1775 - 5: infections 2023-01-05 198.0770 - 6: infections 2023-01-06 208.1120 - 7: infections 2023-01-07 218.6883 - 8: infections 2023-01-08 153.2081 - 9: infections 2023-01-09 148.6709 - 10: infections 2023-01-10 142.1679 - 11: infections 2023-01-11 135.3769 - 12: infections 2023-01-12 128.6677 - 13: infections 2023-01-13 122.1667 - 14: infections 2023-01-14 115.9185 - 15: reported_cases 2023-01-01 108.0000 - 16: reported_cases 2023-01-02 99.0000 - 17: reported_cases 2023-01-03 111.0000 - 18: reported_cases 2023-01-04 93.0000 - 19: reported_cases 2023-01-05 412.0000 - 20: reported_cases 2023-01-06 180.0000 - 21: reported_cases 2023-01-07 309.0000 - 22: reported_cases 2023-01-08 121.0000 - 23: reported_cases 2023-01-09 153.0000 - 24: reported_cases 2023-01-10 12.0000 - 25: reported_cases 2023-01-11 193.0000 - 26: reported_cases 2023-01-12 132.0000 - 27: reported_cases 2023-01-13 200.0000 - 28: reported_cases 2023-01-14 70.0000 - variable date value - From 2b73a449d366b82c2cd200378fc2af3736523fee Mon Sep 17 00:00:00 2001 From: Sebastian Funk Date: Thu, 20 Nov 2025 12:24:23 +0000 Subject: [PATCH 40/51] Rename delay_rev_pmf to reporting_rev_pmf for clarity MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Rename delay_rev_pmf to reporting_rev_pmf in files that use delay_id_reporting to make it clear these are reporting delays, not other types of delays (like generation time delays). Updated files: - inst/stan/estimate_infections.stan - inst/stan/estimate_secondary.stan - inst/stan/simulate_secondary.stan - inst/dev/stan-to-R.R All lines remain under 80 characters. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- inst/dev/stan-to-R.R | 4 ++-- inst/stan/estimate_infections.stan | 14 +++++++------- inst/stan/estimate_secondary.stan | 15 ++++++++------- inst/stan/simulate_secondary.stan | 15 ++++++++------- 4 files changed, 25 insertions(+), 23 deletions(-) diff --git a/inst/dev/stan-to-R.R b/inst/dev/stan-to-R.R index ba7d2850b..1e2556653 100644 --- a/inst/dev/stan-to-R.R +++ b/inst/dev/stan-to-R.R @@ -146,13 +146,13 @@ simulate <- function(data, shifted_cases, noise, fixed, backcalc_prior ) } - delay_rev_pmf <- get_delay_rev_pmf( + reporting_rev_pmf <- get_delay_rev_pmf( delay_id_reporting, delay_type_max[delay_id_reporting] + 1, delay_types_p, delay_types_id, delay_types_groups, delay_max, delay_np_pmf, delay_np_pmf_groups, delay_params, delay_params_groups, delay_dist, 0, 1, 0 ) - reports <- convolve_to_report(infections, delay_rev_pmf, seeding_time) + reports <- convolve_to_report(infections, reporting_rev_pmf, seeding_time) if (week_effect > 1) { reports <- day_of_week_effect(reports, day_of_week, day_of_week_simplex) } diff --git a/inst/stan/estimate_infections.stan b/inst/stan/estimate_infections.stan index 247795ca1..64650400b 100644 --- a/inst/stan/estimate_infections.stan +++ b/inst/stan/estimate_infections.stan @@ -137,17 +137,17 @@ transformed parameters { // convolve from latent infections to mean of observations if (delay_id_reporting) { - vector[delay_type_max[delay_id_reporting] + 1] delay_rev_pmf; + vector[delay_type_max[delay_id_reporting] + 1] reporting_rev_pmf; profile("delays") { - delay_rev_pmf = get_delay_rev_pmf( - delay_id_reporting, delay_type_max[delay_id_reporting] + 1, delay_types_p, delay_types_id, - delay_types_groups, delay_max, delay_np_pmf, - delay_np_pmf_groups, delay_params, delay_params_groups, delay_dist, - 0, 1, 0 + reporting_rev_pmf = get_delay_rev_pmf( + delay_id_reporting, delay_type_max[delay_id_reporting] + 1, + delay_types_p, delay_types_id, delay_types_groups, delay_max, + delay_np_pmf, delay_np_pmf_groups, delay_params, delay_params_groups, + delay_dist, 0, 1, 0 ); } profile("reports") { - reports = convolve_to_report(infections, delay_rev_pmf, seeding_time); + reports = convolve_to_report(infections, reporting_rev_pmf, seeding_time); } } else { reports = infections[(seeding_time + 1):t]; diff --git a/inst/stan/estimate_secondary.stan b/inst/stan/estimate_secondary.stan index 2ae0479ba..54514a7f7 100644 --- a/inst/stan/estimate_secondary.stan +++ b/inst/stan/estimate_secondary.stan @@ -57,13 +57,14 @@ transformed parameters { } if (delay_id_reporting) { - vector[delay_type_max[delay_id_reporting] + 1] delay_rev_pmf = get_delay_rev_pmf( - delay_id_reporting, delay_type_max[delay_id_reporting] + 1, delay_types_p, delay_types_id, - delay_types_groups, delay_max, delay_np_pmf, - delay_np_pmf_groups, delay_params, delay_params_groups, delay_dist, - 0, 1, 0 - ); - convolved = convolved + convolve_to_report(scaled, delay_rev_pmf, 0); + vector[delay_type_max[delay_id_reporting] + 1] reporting_rev_pmf = + get_delay_rev_pmf( + delay_id_reporting, delay_type_max[delay_id_reporting] + 1, + delay_types_p, delay_types_id, delay_types_groups, delay_max, + delay_np_pmf, delay_np_pmf_groups, delay_params, delay_params_groups, + delay_dist, 0, 1, 0 + ); + convolved = convolved + convolve_to_report(scaled, reporting_rev_pmf, 0); } else { convolved = convolved + scaled; } diff --git a/inst/stan/simulate_secondary.stan b/inst/stan/simulate_secondary.stan index a37c6e956..081526f15 100644 --- a/inst/stan/simulate_secondary.stan +++ b/inst/stan/simulate_secondary.stan @@ -54,13 +54,14 @@ generated quantities { } if (delay_id_reporting) { - vector[delay_type_max[delay_id_reporting] + 1] delay_rev_pmf = get_delay_rev_pmf( - delay_id_reporting, delay_type_max[delay_id_reporting] + 1, delay_types_p, delay_types_id, - delay_types_groups, delay_max, delay_np_pmf, - delay_np_pmf_groups, delay_params[i], delay_params_groups, delay_dist, - 0, 1, 0 - ); - convolved = convolved + convolve_to_report(scaled, delay_rev_pmf, 0); + vector[delay_type_max[delay_id_reporting] + 1] reporting_rev_pmf = + get_delay_rev_pmf( + delay_id_reporting, delay_type_max[delay_id_reporting] + 1, + delay_types_p, delay_types_id, delay_types_groups, delay_max, + delay_np_pmf, delay_np_pmf_groups, delay_params[i], + delay_params_groups, delay_dist, 0, 1, 0 + ); + convolved = convolved + convolve_to_report(scaled, reporting_rev_pmf, 0); } else { convolved = convolved + scaled; } From b73e1f992caff8478a8320a343f788e33e38321b Mon Sep 17 00:00:00 2001 From: Sebastian Funk Date: Thu, 20 Nov 2025 20:01:58 +0000 Subject: [PATCH 41/51] Break long lines in Stan files to stay under 80 characters MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Fixed all lines >80 chars that were introduced or modified in this branch: - inst/stan/estimate_infections.stan - inst/stan/estimate_secondary.stan - inst/stan/estimate_truncation.stan - inst/stan/simulate_infections.stan - inst/stan/simulate_secondary.stan 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- inst/stan/estimate_infections.stan | 38 ++++++++++++----------- inst/stan/estimate_secondary.stan | 13 ++++---- inst/stan/estimate_truncation.stan | 25 ++++++++-------- inst/stan/simulate_infections.stan | 48 +++++++++++++++++------------- inst/stan/simulate_secondary.stan | 13 ++++---- 5 files changed, 75 insertions(+), 62 deletions(-) diff --git a/inst/stan/estimate_infections.stan b/inst/stan/estimate_infections.stan index 64650400b..3817dac39 100644 --- a/inst/stan/estimate_infections.stan +++ b/inst/stan/estimate_infections.stan @@ -75,7 +75,8 @@ transformed parameters { vector[t] infections; // latent infections vector[ot_h] reports; // estimated reported cases vector[ot] obs_reports; // observed estimated reported cases - vector[estimate_r * (delay_type_max[delay_id_generation_time] + 1)] gt_rev_pmf; + vector[estimate_r * (delay_type_max[delay_id_generation_time] + 1)] + gt_rev_pmf; // GP in noise - spectral densities profile("update gp") { @@ -98,10 +99,10 @@ transformed parameters { if (estimate_r) { profile("gt") { gt_rev_pmf = get_delay_rev_pmf( - delay_id_generation_time, delay_type_max[delay_id_generation_time] + 1, delay_types_p, delay_types_id, - delay_types_groups, delay_max, delay_np_pmf, - delay_np_pmf_groups, delay_params, delay_params_groups, delay_dist, - 1, 1, 0 + delay_id_generation_time, delay_type_max[delay_id_generation_time] + 1, + delay_types_p, delay_types_id, delay_types_groups, delay_max, + delay_np_pmf, delay_np_pmf_groups, delay_params, delay_params_groups, + delay_dist, 1, 1, 0 ); } profile("R0") { @@ -176,10 +177,10 @@ transformed parameters { vector[delay_type_max[delay_id_truncation] + 1] trunc_rev_cmf; profile("truncation") { trunc_rev_cmf = get_delay_rev_pmf( - delay_id_truncation, delay_type_max[delay_id_truncation] + 1, delay_types_p, delay_types_id, - delay_types_groups, delay_max, delay_np_pmf, - delay_np_pmf_groups, delay_params, delay_params_groups, delay_dist, - 0, 1, 1 + delay_id_truncation, delay_type_max[delay_id_truncation] + 1, + delay_types_p, delay_types_id, delay_types_groups, delay_max, + delay_np_pmf, delay_np_pmf_groups, delay_params, delay_params_groups, + delay_dist, 0, 1, 1 ); } profile("truncate") { @@ -264,19 +265,22 @@ generated quantities { } { - vector[delay_type_max[delay_id_generation_time] + 1] gt_rev_pmf_for_growth; - + vector[delay_type_max[delay_id_generation_time] + 1] + gt_rev_pmf_for_growth; + if (estimate_r == 0) { // sample generation time vector[delay_params_length] delay_params_sample = to_vector(normal_lb_rng( delay_params_mean, delay_params_sd, delay_params_lower )); - vector[delay_type_max[delay_id_generation_time] + 1] sampled_gt_rev_pmf = get_delay_rev_pmf( - delay_id_generation_time, delay_type_max[delay_id_generation_time] + 1, delay_types_p, delay_types_id, - delay_types_groups, delay_max, delay_np_pmf, - delay_np_pmf_groups, delay_params_sample, delay_params_groups, - delay_dist, 1, 1, 0 - ); + vector[delay_type_max[delay_id_generation_time] + 1] sampled_gt_rev_pmf = + get_delay_rev_pmf( + delay_id_generation_time, + delay_type_max[delay_id_generation_time] + 1, delay_types_p, + delay_types_id, delay_types_groups, delay_max, delay_np_pmf, + delay_np_pmf_groups, delay_params_sample, delay_params_groups, + delay_dist, 1, 1, 0 + ); gt_rev_pmf_for_growth = sampled_gt_rev_pmf; // calculate Rt using infections and generation time gen_R = calculate_Rt( diff --git a/inst/stan/estimate_secondary.stan b/inst/stan/estimate_secondary.stan index 54514a7f7..83c11a126 100644 --- a/inst/stan/estimate_secondary.stan +++ b/inst/stan/estimate_secondary.stan @@ -82,12 +82,13 @@ transformed parameters { // truncate near time cases to observed reports if (delay_id_truncation) { - vector[delay_type_max[delay_id_truncation]] trunc_rev_cmf = get_delay_rev_pmf( - delay_id_truncation, delay_type_max[delay_id_truncation] + 1, delay_types_p, delay_types_id, - delay_types_groups, delay_max, delay_np_pmf, - delay_np_pmf_groups, delay_params, delay_params_groups, delay_dist, - 0, 1, 1 - ); + vector[delay_type_max[delay_id_truncation]] trunc_rev_cmf = + get_delay_rev_pmf( + delay_id_truncation, delay_type_max[delay_id_truncation] + 1, + delay_types_p, delay_types_id, delay_types_groups, delay_max, + delay_np_pmf, delay_np_pmf_groups, delay_params, delay_params_groups, + delay_dist, 0, 1, 1 + ); secondary = truncate_obs(secondary, trunc_rev_cmf, 0); } diff --git a/inst/stan/estimate_truncation.stan b/inst/stan/estimate_truncation.stan index 05d2d9be0..75ba24f09 100644 --- a/inst/stan/estimate_truncation.stan +++ b/inst/stan/estimate_truncation.stan @@ -38,15 +38,15 @@ parameters { transformed parameters{ real phi = 1 / sqrt(dispersion); - matrix[delay_type_max[delay_id_truncation] + 1, obs_sets - 1] trunc_obs = rep_matrix( - 0, delay_type_max[delay_id_truncation] + 1, obs_sets - 1 - ); - vector[delay_type_max[delay_id_truncation] + 1] trunc_rev_cmf = get_delay_rev_pmf( - delay_id_truncation, delay_type_max[delay_id_truncation] + 1, delay_types_p, delay_types_id, - delay_types_groups, delay_max, delay_np_pmf, - delay_np_pmf_groups, delay_params, delay_params_groups, delay_dist, - 0, 1, 1 - ); + matrix[delay_type_max[delay_id_truncation] + 1, obs_sets - 1] trunc_obs = + rep_matrix(0, delay_type_max[delay_id_truncation] + 1, obs_sets - 1); + vector[delay_type_max[delay_id_truncation] + 1] trunc_rev_cmf = + get_delay_rev_pmf( + delay_id_truncation, delay_type_max[delay_id_truncation] + 1, + delay_types_p, delay_types_id, delay_types_groups, delay_max, + delay_np_pmf, delay_np_pmf_groups, delay_params, delay_params_groups, + delay_dist, 0, 1, 1 + ); { vector[t] last_obs; // reconstruct latest data without truncation @@ -79,11 +79,10 @@ model { } generated quantities { - matrix[delay_type_max[delay_id_truncation] + 1, obs_sets] recon_obs = rep_matrix( - 0, delay_type_max[delay_id_truncation] + 1, obs_sets - ); + matrix[delay_type_max[delay_id_truncation] + 1, obs_sets] recon_obs = + rep_matrix(0, delay_type_max[delay_id_truncation] + 1, obs_sets); matrix[delay_type_max[delay_id_truncation] + 1, obs_sets - 1] gen_obs; - // reconstruct all truncated datasets using posterior of the truncation distribution + // reconstruct all truncated datasets using posterior of truncation dist for (i in 1:obs_sets) { recon_obs[1:(end_t[i] - start_t[i] + 1), i] = truncate_obs( to_vector(obs[start_t[i]:end_t[i], i]), trunc_rev_cmf, 1 diff --git a/inst/stan/simulate_infections.stan b/inst/stan/simulate_infections.stan index 8a2dacef3..dda118d0f 100644 --- a/inst/stan/simulate_infections.stan +++ b/inst/stan/simulate_infections.stan @@ -61,27 +61,32 @@ generated quantities { // generate infections from Rt trace vector[delay_type_max[delay_id_generation_time] + 1] gt_rev_pmf; gt_rev_pmf = get_delay_rev_pmf( - delay_id_generation_time, delay_type_max[delay_id_generation_time] + 1, delay_types_p, delay_types_id, - delay_types_groups, delay_max, delay_np_pmf, - delay_np_pmf_groups, delay_params[i], delay_params_groups, delay_dist, - 1, 1, 0 + delay_id_generation_time, delay_type_max[delay_id_generation_time] + 1, + delay_types_p, delay_types_id, delay_types_groups, delay_max, + delay_np_pmf, delay_np_pmf_groups, delay_params[i], + delay_params_groups, delay_dist, 1, 1, 0 ); infections[i] = to_row_vector(generate_infections( to_vector(R[i]), seeding_time, gt_rev_pmf, initial_infections[i], - pop[i], use_pop, pop_floor, future_time, obs_scale, frac_obs[i], initial_as_scale + pop[i], use_pop, pop_floor, future_time, obs_scale, frac_obs[i], + initial_as_scale )); if (delay_id_generation_time) { - vector[delay_type_max[delay_id_generation_time] + 1] delay_rev_pmf = get_delay_rev_pmf( - delay_id_generation_time, delay_type_max[delay_id_generation_time] + 1, delay_types_p, delay_types_id, - delay_types_groups, delay_max, delay_np_pmf, - delay_np_pmf_groups, delay_params[i], delay_params_groups, delay_dist, - 0, 1, 0 - ); + vector[delay_type_max[delay_id_generation_time] + 1] delay_rev_pmf = + get_delay_rev_pmf( + delay_id_generation_time, + delay_type_max[delay_id_generation_time] + 1, delay_types_p, + delay_types_id, delay_types_groups, delay_max, delay_np_pmf, + delay_np_pmf_groups, delay_params[i], delay_params_groups, + delay_dist, 0, 1, 0 + ); // convolve from latent infections to mean of observations - reports[i] = to_row_vector(convolve_to_report( - to_vector(infections[i]), delay_rev_pmf, seeding_time) + reports[i] = to_row_vector( + convolve_to_report( + to_vector(infections[i]), delay_rev_pmf, seeding_time + ) ); } else { reports[i] = to_row_vector( @@ -97,19 +102,22 @@ generated quantities { } // truncate near time cases to observed reports if (delay_id_truncation) { - vector[delay_type_max[delay_id_truncation] + 1] trunc_rev_cmf = get_delay_rev_pmf( - delay_id_truncation, delay_type_max[delay_id_truncation] + 1, delay_types_p, delay_types_id, - delay_types_groups, delay_max, delay_np_pmf, - delay_np_pmf_groups, delay_params[i], delay_params_groups, delay_dist, - 0, 1, 1 - ); + vector[delay_type_max[delay_id_truncation] + 1] trunc_rev_cmf = + get_delay_rev_pmf( + delay_id_truncation, delay_type_max[delay_id_truncation] + 1, + delay_types_p, delay_types_id, delay_types_groups, delay_max, + delay_np_pmf, delay_np_pmf_groups, delay_params[i], + delay_params_groups, delay_dist, 0, 1, 1 + ); reports[i] = to_row_vector(truncate_obs( to_vector(reports[i]), trunc_rev_cmf, 0) ); } // scale observations if (obs_scale) { - reports[i] = to_row_vector(scale_obs(to_vector(reports[i]), frac_obs[i])); + reports[i] = to_row_vector( + scale_obs(to_vector(reports[i]), frac_obs[i]) + ); } // simulate reported cases imputed_reports[i] = report_rng( diff --git a/inst/stan/simulate_secondary.stan b/inst/stan/simulate_secondary.stan index 081526f15..0b2f8949b 100644 --- a/inst/stan/simulate_secondary.stan +++ b/inst/stan/simulate_secondary.stan @@ -79,12 +79,13 @@ generated quantities { // truncate near time cases to observed reports if (delay_id_truncation) { - vector[delay_type_max[delay_id_truncation] + 1] trunc_rev_cmf = get_delay_rev_pmf( - delay_id_truncation, delay_type_max[delay_id_truncation] + 1, delay_types_p, delay_types_id, - delay_types_groups, delay_max, delay_np_pmf, - delay_np_pmf_groups, delay_params[i], delay_params_groups, delay_dist, - 0, 1, 1 - ); + vector[delay_type_max[delay_id_truncation] + 1] trunc_rev_cmf = + get_delay_rev_pmf( + delay_id_truncation, delay_type_max[delay_id_truncation] + 1, + delay_types_p, delay_types_id, delay_types_groups, delay_max, + delay_np_pmf, delay_np_pmf_groups, delay_params[i], + delay_params_groups, delay_dist, 0, 1, 1 + ); secondary = truncate_obs( secondary, trunc_rev_cmf, 0 ); From c684a8b1d8578d19775f24aca22ec1de762adbc9 Mon Sep 17 00:00:00 2001 From: Sebastian Funk Date: Thu, 20 Nov 2025 20:17:37 +0000 Subject: [PATCH 42/51] Fix critical bug: use delay_id_reporting not delay_id_generation_time MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit In simulate_infections.stan, the code was incorrectly checking and using delay_id_generation_time when convolving infections to reports. This should use delay_id_reporting (the reporting delay), not the generation time delay. This bug caused the simulate_infections snapshot to change because it fundamentally changed the model behavior. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- inst/stan/simulate_infections.stan | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/inst/stan/simulate_infections.stan b/inst/stan/simulate_infections.stan index dda118d0f..a715a3522 100644 --- a/inst/stan/simulate_infections.stan +++ b/inst/stan/simulate_infections.stan @@ -73,19 +73,18 @@ generated quantities { initial_as_scale )); - if (delay_id_generation_time) { - vector[delay_type_max[delay_id_generation_time] + 1] delay_rev_pmf = + if (delay_id_reporting) { + vector[delay_type_max[delay_id_reporting] + 1] reporting_rev_pmf = get_delay_rev_pmf( - delay_id_generation_time, - delay_type_max[delay_id_generation_time] + 1, delay_types_p, - delay_types_id, delay_types_groups, delay_max, delay_np_pmf, - delay_np_pmf_groups, delay_params[i], delay_params_groups, - delay_dist, 0, 1, 0 + delay_id_reporting, delay_type_max[delay_id_reporting] + 1, + delay_types_p, delay_types_id, delay_types_groups, delay_max, + delay_np_pmf, delay_np_pmf_groups, delay_params[i], + delay_params_groups, delay_dist, 0, 1, 0 ); // convolve from latent infections to mean of observations reports[i] = to_row_vector( convolve_to_report( - to_vector(infections[i]), delay_rev_pmf, seeding_time + to_vector(infections[i]), reporting_rev_pmf, seeding_time ) ); } else { From 796d681a1ecb59bfc90e2157bf32c57c38276d34 Mon Sep 17 00:00:00 2001 From: Sebastian Funk Date: Fri, 21 Nov 2025 09:07:02 +0000 Subject: [PATCH 43/51] Fix linting issues MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Use switch() instead of if/else in estimate_secondary $ operator - Remove explicit return() statements - Fix line length in checks.R - Add parameter to globalVariables 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- R/checks.R | 7 ++++-- R/estimate_secondary.R | 51 ++++++++++++++++++++++-------------------- R/get.R | 2 +- R/utilities.R | 3 ++- 4 files changed, 35 insertions(+), 28 deletions(-) diff --git a/R/checks.R b/R/checks.R index 02c0bb145..31f0c893c 100644 --- a/R/checks.R +++ b/R/checks.R @@ -199,7 +199,8 @@ check_sparse_pmf_tail <- function(pmf, span = 5, tol = 1e-6) { #' @keywords internal check_truncation_length <- function(stan_args, time_points) { # Check if truncation exists - if (is.null(stan_args$data$delay_id_truncation) || stan_args$data$delay_id_truncation == 0) { + if (is.null(stan_args$data$delay_id_truncation) || + stan_args$data$delay_id_truncation == 0) { return(invisible()) } @@ -210,7 +211,9 @@ check_truncation_length <- function(stan_args, time_points) { # Map truncation to its position in the flat delays array # delay_types_groups gives start and end indices for each delay type - trunc_start <- stan_args$data$delay_types_groups[stan_args$data$delay_id_truncation] + trunc_start <- stan_args$data$delay_types_groups[ + stan_args$data$delay_id_truncation + ] trunc_end <- stan_args$data$delay_types_groups[ stan_args$data$delay_id_truncation + 1 ] - 1 diff --git a/R/estimate_secondary.R b/R/estimate_secondary.R index e3caec6e8..48ec1b371 100644 --- a/R/estimate_secondary.R +++ b/R/estimate_secondary.R @@ -767,29 +767,32 @@ forecast_secondary <- function(estimate, #' @export #' @method $ estimate_secondary `$.estimate_secondary` <- function(x, name) { - if (name == "predictions") { - lifecycle::deprecate_warn( - "1.8.0", - "estimate_secondary()$predictions", - "get_predictions()" - ) - return(get_predictions(x)) - } else if (name == "posterior") { - lifecycle::deprecate_warn( - "1.8.0", - "estimate_secondary()$posterior", - "get_samples()" - ) - return(get_samples(x)) - } else if (name == "data") { - lifecycle::deprecate_warn( - "1.8.0", - "estimate_secondary()$data", - "estimate_secondary()$observations" - ) - return(x[["observations"]]) - } else { + switch(name, + predictions = { + lifecycle::deprecate_warn( + "1.8.0", + "estimate_secondary()$predictions", + "get_predictions()" + ) + get_predictions(x) + }, + posterior = { + lifecycle::deprecate_warn( + "1.8.0", + "estimate_secondary()$posterior", + "get_samples()" + ) + get_samples(x) + }, + data = { + lifecycle::deprecate_warn( + "1.8.0", + "estimate_secondary()$data", + "estimate_secondary()$observations" + ) + x[["observations"]] + }, # For other elements, use normal list extraction - return(NextMethod("$")) - } + NextMethod("$") + ) } diff --git a/R/get.R b/R/get.R index 3547eff56..7f4f67360 100644 --- a/R/get.R +++ b/R/get.R @@ -312,7 +312,7 @@ get_samples.estimate_secondary <- function(object, ...) { c("date", "variable", "strat", "sample", "time", "value", "type") ) - return(samples[]) + samples[] } #' @rdname get_samples diff --git a/R/utilities.R b/R/utilities.R index 33f80df11..1c5ef04cd 100644 --- a/R/utilities.R +++ b/R/utilities.R @@ -506,6 +506,7 @@ globalVariables( "..lowers", "..upper_CrI", "..uppers", "timing", "dataset", "last_confirm", "report_date", "secondary", "id", "conv", "meanlog", "primary", "scaled", "scaling", "sdlog", "lookup", "new_draw", ".draw", "p", "distribution", - "accumulate", "..present", "reported_cases", "counter", "future_accumulate" + "accumulate", "..present", "reported_cases", "counter", "future_accumulate", + "parameter" ) ) From 939c86588d3a36a999f897d8d54f56b017ee876f Mon Sep 17 00:00:00 2001 From: Sebastian Funk Date: Fri, 21 Nov 2025 09:08:21 +0000 Subject: [PATCH 44/51] Update documentation --- NAMESPACE | 2 ++ man/cash-.estimate_secondary.Rd | 27 +++++++++++++++++++++++++++ man/get_predictions.Rd | 3 +++ 3 files changed, 32 insertions(+) create mode 100644 man/cash-.estimate_secondary.Rd diff --git a/NAMESPACE b/NAMESPACE index a5d2f664f..180b25fa2 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -1,6 +1,7 @@ # Generated by roxygen2: do not edit by hand S3method("!=",dist_spec) +S3method("$",estimate_secondary) S3method("+",dist_spec) S3method("==",dist_spec) S3method(c,dist_spec) @@ -12,6 +13,7 @@ S3method(fix_parameters,dist_spec) S3method(fix_parameters,multi_dist_spec) S3method(get_predictions,estimate_infections) S3method(get_predictions,estimate_secondary) +S3method(get_predictions,forecast_infections) S3method(get_predictions,forecast_secondary) S3method(get_samples,estimate_infections) S3method(get_samples,estimate_secondary) diff --git a/man/cash-.estimate_secondary.Rd b/man/cash-.estimate_secondary.Rd new file mode 100644 index 000000000..2e983a5a4 --- /dev/null +++ b/man/cash-.estimate_secondary.Rd @@ -0,0 +1,27 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/estimate_secondary.R +\name{$.estimate_secondary} +\alias{$.estimate_secondary} +\title{Extract elements from estimate_secondary objects with deprecated warnings} +\usage{ +\method{$}{estimate_secondary}(x, name) +} +\arguments{ +\item{x}{An \code{estimate_secondary} object} + +\item{name}{The name of the element to extract} +} +\value{ +The requested element with a deprecation warning +} +\description{ +\ifelse{html}{\href{https://lifecycle.r-lib.org/articles/stages.html#deprecated}{\figure{lifecycle-deprecated.svg}{options: alt='[Deprecated]'}}}{\strong{[Deprecated]}} +Provides backward compatibility for the old return structure. The previous +structure with \code{predictions}, \code{posterior}, and \code{data} +elements is deprecated. Use the accessor methods instead: +\itemize{ +\item \code{predictions} - use \code{get_predictions(object)} +\item \code{posterior} - use \code{get_samples(object)} +\item \code{data} - use \code{object$observations} +} +} diff --git a/man/get_predictions.Rd b/man/get_predictions.Rd index 9aa7d4bc2..4643bf8cb 100644 --- a/man/get_predictions.Rd +++ b/man/get_predictions.Rd @@ -4,6 +4,7 @@ \alias{get_predictions} \alias{get_predictions.estimate_infections} \alias{get_predictions.estimate_secondary} +\alias{get_predictions.forecast_infections} \alias{get_predictions.forecast_secondary} \title{Get predictions from a fitted model} \usage{ @@ -13,6 +14,8 @@ get_predictions(object, ...) \method{get_predictions}{estimate_secondary}(object, CrIs = c(0.2, 0.5, 0.9), ...) +\method{get_predictions}{forecast_infections}(object, ...) + \method{get_predictions}{forecast_secondary}(object, ...) } \arguments{ From 2f7b5d1670a6c18712981d200e0fe245cec74cbf Mon Sep 17 00:00:00 2001 From: Sebastian Funk Date: Fri, 21 Nov 2025 09:54:56 +0000 Subject: [PATCH 45/51] Mark $.estimate_secondary as internal for pkgdown --- R/estimate_secondary.R | 1 + man/cash-.estimate_secondary.Rd | 1 + 2 files changed, 2 insertions(+) diff --git a/R/estimate_secondary.R b/R/estimate_secondary.R index 48ec1b371..8f5c9c9f1 100644 --- a/R/estimate_secondary.R +++ b/R/estimate_secondary.R @@ -764,6 +764,7 @@ forecast_secondary <- function(estimate, #' @param x An \code{estimate_secondary} object #' @param name The name of the element to extract #' @return The requested element with a deprecation warning +#' @keywords internal #' @export #' @method $ estimate_secondary `$.estimate_secondary` <- function(x, name) { diff --git a/man/cash-.estimate_secondary.Rd b/man/cash-.estimate_secondary.Rd index 2e983a5a4..235bb0c64 100644 --- a/man/cash-.estimate_secondary.Rd +++ b/man/cash-.estimate_secondary.Rd @@ -25,3 +25,4 @@ elements is deprecated. Use the accessor methods instead: \item \code{data} - use \code{object$observations} } } +\keyword{internal} From 1964d9576f57228426a54dd55d6538b24a60b074 Mon Sep 17 00:00:00 2001 From: Sebastian Funk Date: Fri, 21 Nov 2025 10:03:21 +0000 Subject: [PATCH 46/51] Address CodeRabbit review feedback MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Update NEWS.md to clarify backward compatibility is deprecated, not removed - Add data.table::copy() to forecast get_predictions() methods for consistency with get_samples() 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- NEWS.md | 2 +- R/get.R | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/NEWS.md b/NEWS.md index 1387d8b39..8a8b53b1d 100644 --- a/NEWS.md +++ b/NEWS.md @@ -20,7 +20,7 @@ - Use `get_predictions(object)` to get predicted secondary observations with credible intervals merged with observations. - Use `summary(object)` to get summarised parameter estimates. Use `type = "compact"` for key parameters only, or `type = "parameters"` with a `params` argument to select specific parameters. - Access the Stan fit directly via `object$fit`, model arguments via `object$args`, and observations via `object$observations`. - - **Breaking change**: The previous return structure with `predictions`, `posterior`, and `data` elements is no longer supported. Use the accessor methods instead. + - **Deprecated**: The previous return structure with `predictions`, `posterior`, and `data` elements is deprecated and will be removed in a future release. Backward compatibility is provided with deprecation warnings when accessing these elements via `$`. - `forecast_secondary()` now returns an independent S3 class `"forecast_secondary"` instead of inheriting from `"estimate_secondary"`, with dedicated `get_samples()`, `get_predictions()`, and `plot()` methods. - `plot.estimate_infections()` and `plot.forecast_infections()` now accept a `CrIs` argument to control which credible intervals are displayed. diff --git a/R/get.R b/R/get.R index 7f4f67360..d4bdd480a 100644 --- a/R/get.R +++ b/R/get.R @@ -407,11 +407,11 @@ get_predictions.estimate_secondary <- function(object, #' @rdname get_predictions #' @export get_predictions.forecast_infections <- function(object, ...) { - object$predictions + data.table::copy(object$predictions) } #' @rdname get_predictions #' @export get_predictions.forecast_secondary <- function(object, ...) { - object$predictions + data.table::copy(object$predictions) } From 8069e5383b11eaa97310f978236072bf8a858be2 Mon Sep 17 00:00:00 2001 From: Sebastian Funk Date: Fri, 21 Nov 2025 10:21:01 +0000 Subject: [PATCH 47/51] Fix indentation in checks.R --- R/checks.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/R/checks.R b/R/checks.R index 31f0c893c..d4727bbf6 100644 --- a/R/checks.R +++ b/R/checks.R @@ -200,7 +200,7 @@ check_sparse_pmf_tail <- function(pmf, span = 5, tol = 1e-6) { check_truncation_length <- function(stan_args, time_points) { # Check if truncation exists if (is.null(stan_args$data$delay_id_truncation) || - stan_args$data$delay_id_truncation == 0) { + stan_args$data$delay_id_truncation == 0) { return(invisible()) } From 3293941ea55e667357f9a6b4f82ca4087fb0c4f1 Mon Sep 17 00:00:00 2001 From: sbfnk-bot <242615673+sbfnk-bot@users.noreply.github.com> Date: Thu, 27 Nov 2025 13:43:52 +0000 Subject: [PATCH 48/51] Remove unnecessary explicit return statements --- R/checks.R | 2 +- R/create.R | 12 ++++++------ R/estimate_infections.R | 2 +- R/estimate_secondary.R | 4 ++-- 4 files changed, 10 insertions(+), 10 deletions(-) diff --git a/R/checks.R b/R/checks.R index d4727bbf6..dc7a763d0 100644 --- a/R/checks.R +++ b/R/checks.R @@ -48,7 +48,7 @@ check_reports_valid <- function(data, assert_numeric(data$confirm, lower = 0) } assert_logical(data$accumulate, null.ok = TRUE) - return(invisible(data)) + invisible(data) } #' Validate probability distribution for passing to stan diff --git a/R/create.R b/R/create.R index 2778b7e48..1feed9f23 100644 --- a/R/create.R +++ b/R/create.R @@ -149,7 +149,7 @@ create_future_rt <- function(future = c("latest", "project", "estimate"), out$fixed <- TRUE out$from <- as.integer(future) } - return(out) + out } #' Create Time-varying Reproduction Number Data @@ -275,7 +275,7 @@ create_rt_data <- function(rt = rt_opts(), breakpoints = NULL, "infections" = 0, "infectiousness" = 1 )[[rt$growth_method]] ) - return(rt_data) + rt_data } #' Create Back Calculation Data #' @@ -371,7 +371,7 @@ create_gp_data <- function(gp = gp_opts(), data) { ) gp_data <- c(data, gp_data) - return(gp_data) + gp_data } #' Create Observation Model Settings @@ -416,7 +416,7 @@ create_obs_model <- function(obs = obs_opts(), dates) { opts$day_of_week <- add_day_of_week(dates, opts$week_effect) - return(opts) + opts } #' Create Stan Data Required for estimate_infections @@ -674,7 +674,7 @@ create_stan_args <- function(stan = stan_opts(), ) stan_args <- modifyList(stan_args, stan) stan_args$return_fit <- NULL - return(stan_args) + stan_args } ##' Create delay variables for stan @@ -780,7 +780,7 @@ create_stan_delays <- function(..., time_points = 1L) { names(ret) <- paste("delay", names(ret), sep = "_") ret <- c(ret, as.list(delay_ids)) - return(ret) + ret } ##' Create parameters for stan diff --git a/R/estimate_infections.R b/R/estimate_infections.R index 9326926e9..2dfea824f 100644 --- a/R/estimate_infections.R +++ b/R/estimate_infections.R @@ -296,5 +296,5 @@ estimate_infections <- function(data, ## Join stan fit if required class(ret) <- c("epinowfit", "estimate_infections", class(ret)) - return(ret) + ret } diff --git a/R/estimate_secondary.R b/R/estimate_secondary.R index 8f5c9c9f1..07f344566 100644 --- a/R/estimate_secondary.R +++ b/R/estimate_secondary.R @@ -274,7 +274,7 @@ estimate_secondary <- function(data, ) class(ret) <- c("epinowfit", "estimate_secondary", class(ret)) - return(ret) + ret } #' Update estimate_secondary default priors @@ -339,7 +339,7 @@ update_secondary_args <- function(data, priors, verbose = TRUE) { data$dispersion_sd <- signif(dispersion$sd, 3) } } - return(data) + data } #' Plot method for estimate_secondary From a4a88b021bc42c082c6fb50c8cf79386aebf2fb5 Mon Sep 17 00:00:00 2001 From: Sebastian Funk Date: Wed, 3 Dec 2025 13:02:47 +0000 Subject: [PATCH 49/51] Remove explicit return() statements MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replace explicit return() with implicit returns at end of functions to satisfy return_linter. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- R/get.R | 4 ++-- R/summarise.R | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/R/get.R b/R/get.R index 66339e644..e197d59fa 100644 --- a/R/get.R +++ b/R/get.R @@ -371,7 +371,7 @@ get_predictions.estimate_infections <- function(object, all = TRUE ) - return(predictions) + predictions } #' @rdname get_predictions @@ -400,7 +400,7 @@ get_predictions.estimate_secondary <- function(object, all = TRUE, by = "date" ) - return(predictions) + predictions } #' @rdname get_predictions diff --git a/R/summarise.R b/R/summarise.R index 308c1ee4d..1856fad44 100644 --- a/R/summarise.R +++ b/R/summarise.R @@ -941,5 +941,5 @@ summary.estimate_secondary <- function(object, out <- out[variable %in% params] } - return(out[]) + out[] } From 716802f0356a2ed520c8e38220dc0a5801d7044c Mon Sep 17 00:00:00 2001 From: Sebastian Funk Date: Thu, 4 Dec 2025 17:29:05 +0000 Subject: [PATCH 50/51] Remove redundant MCMC runs from test-estimate_secondary.R Move parameter recovery test behind skip_integration() to keep fast tests fast. The top-level MCMC runs for prev and inc_fixed were unused and slowed down regular test runs. --- tests/testthat/test-estimate_secondary.R | 63 ++++++++++++------------ 1 file changed, 31 insertions(+), 32 deletions(-) diff --git a/tests/testthat/test-estimate_secondary.R b/tests/testthat/test-estimate_secondary.R index b114d08fe..f86a34aac 100644 --- a/tests/testthat/test-estimate_secondary.R +++ b/tests/testthat/test-estimate_secondary.R @@ -56,38 +56,6 @@ default_inc <- estimate_secondary(inc_cases[1:60], verbose = FALSE ) -# extract posterior variables of interest -params <- c( - "meanlog" = "delay_params[1]", "sdlog" = "delay_params[2]", - "scaling" = "params[1]" -) - -inc_posterior <- get_samples(default_inc)[variable %in% params] - -# fit model to example data with a fixed delay -inc_fixed <- estimate_secondary(inc_cases[1:60], - delays = delay_opts(Gamma(mean = 15, sd = 5, max = 30)), - verbose = FALSE -) - -#### Prevalence data example #### - -# make some example prevalence data -prev_cases <- setup_prevalence_data() - -# fit model to example prevalence data -prev <- estimate_secondary(prev_cases[1:100], - secondary = secondary_opts(type = "prevalence"), - obs = obs_opts( - week_effect = FALSE, - scale = Normal(mean = 0.4, sd = 0.1) - ), - verbose = FALSE -) - -# extract posterior parameters of interest -prev_posterior <- get_samples(prev)[variable %in% params] - # Test output test_that("estimate_secondary can return values from simulated data and plot them", { @@ -204,6 +172,37 @@ test_that("estimate_secondary works when only estimating scaling", { }) test_that("estimate_secondary can recover simulated parameters", { + skip_integration() + inc_cases <- setup_incidence_data() + prev_cases <- setup_prevalence_data() + + # fit model to example data specifying a weak prior for fraction reported + inc <- estimate_secondary(inc_cases[1:60], + obs = obs_opts( + scale = Normal(mean = 0.2, sd = 0.2, max = 1), week_effect = FALSE + ), + verbose = FALSE + ) + + # fit model to example prevalence data + prev <- estimate_secondary(prev_cases[1:100], + secondary = secondary_opts(type = "prevalence"), + obs = obs_opts( + week_effect = FALSE, + scale = Normal(mean = 0.4, sd = 0.1) + ), + verbose = FALSE + ) + + # extract posterior variables of interest + params <- c( + "meanlog" = "delay_params[1]", "sdlog" = "delay_params[2]", + "scaling" = "params[1]" + ) + + inc_posterior <- get_samples(inc)[variable %in% params] + prev_posterior <- get_samples(prev)[variable %in% params] + # Calculate summary statistics from raw samples inc_summary <- inc_posterior[, .( mean = mean(value), From 68b56003972d08361d801427986fd0f0644b1a22 Mon Sep 17 00:00:00 2001 From: Sebastian Funk Date: Thu, 4 Dec 2025 17:29:52 +0000 Subject: [PATCH 51/51] Add core parameter recovery test using pre-computed fit Adds a basic parameter recovery check that runs in all test modes, reusing default_inc to avoid additional MCMC runs. The full parameter recovery tests (incidence + prevalence) remain as integration tests. --- tests/testthat/test-estimate_secondary.R | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/tests/testthat/test-estimate_secondary.R b/tests/testthat/test-estimate_secondary.R index f86a34aac..9d94fe924 100644 --- a/tests/testthat/test-estimate_secondary.R +++ b/tests/testthat/test-estimate_secondary.R @@ -98,6 +98,27 @@ test_that("forecast_secondary can return values from simulated data and plot expect_error(plot(inc_preds, new_obs = inc_cases, from = "2020-05-01"), NA) }) +test_that("estimate_secondary recovers scaling parameter from incidence data", { + # Basic parameter recovery check using pre-computed fit + # inc_cases was set up with scaling = 0.4, meanlog = 1.8, sdlog = 0.5 + params <- c( + "meanlog" = "delay_params[1]", "sdlog" = "delay_params[2]", + "scaling" = "params[1]" + ) + + inc_posterior <- get_samples(default_inc)[variable %in% params] + inc_summary <- inc_posterior[, .( + mean = mean(value), + median = stats::median(value) + ), by = variable] + + # Check scaling parameter is reasonably recovered (0.4 true value) + expect_equal( + inc_summary$mean, c(1.8, 0.5, 0.4), + tolerance = 0.15 + ) +}) + # Variant tests: Only run in full test mode (EPINOW2_SKIP_INTEGRATION=false) - test_that("estimate_secondary successfully returns estimates when passed NA values", {