diff --git a/.github/workflows/touchstone-receive.yaml b/.github/workflows/touchstone-receive.yaml index 50c5f08b1..963300e53 100644 --- a/.github/workflows/touchstone-receive.yaml +++ b/.github/workflows/touchstone-receive.yaml @@ -53,7 +53,6 @@ jobs: cache-version: 1 extra-packages: any::cmdstanr extra-repositories: https://production.r-multiverse.org/2025-09-15 - touchstone_ref: '@289cfc7' benchmarking_repo: ${{ matrix.config.benchmarking_repo }} benchmarking_ref: ${{ matrix.config.benchmarking_ref }} benchmarking_path: ${{ matrix.config.benchmarking_path }} diff --git a/CLAUDE.md b/CLAUDE.md index 22734a31e..a75a98f75 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -26,9 +26,29 @@ ## Pull Requests +### Workflow Overview +When working on a task, follow this sequence: + +1. **During development**: Make commits with descriptive messages (NO issue numbers) +2. **Before opening PR**: Add a NEWS.md item for user-facing changes +3. **When opening PR**: Link the issue in PR description using "This PR closes #XXX" +4. **In PR template**: Complete all applicable checklist items + +### Key Distinctions +- **Commit messages**: Describe what the code does, never reference issues + - Good: "Standardise return structure for estimate_secondary" + - Bad: "Standardise return structure for estimate_secondary (#1142)" +- **PR descriptions**: Link to issues here using the template + - Format: "This PR closes #1142." 
+- **NEWS items**: Describe user-facing changes, never reference issues or PRs + - Good: "Added S3 methods for estimate_secondary objects" + - Bad: "Added S3 methods for estimate_secondary (#1142)" + +### PR Template Requirements - Follow the pull request template in `.github/PULL_REQUEST_TEMPLATE.md` - Complete the checklist items in the template - Link to the related issue in the PR description (not in commit messages) +- Ensure you've added a NEWS.md item before submitting ## Testing diff --git a/NAMESPACE b/NAMESPACE index f63c92fa9..760d1aab6 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -1,6 +1,7 @@ # Generated by roxygen2: do not edit by hand S3method("!=",dist_spec) +S3method("$",estimate_secondary) S3method("+",dist_spec) S3method("==",dist_spec) S3method(c,dist_spec) @@ -10,8 +11,14 @@ S3method(discretise,dist_spec) S3method(discretise,multi_dist_spec) S3method(fix_parameters,dist_spec) S3method(fix_parameters,multi_dist_spec) +S3method(get_predictions,estimate_infections) +S3method(get_predictions,estimate_secondary) +S3method(get_predictions,forecast_infections) +S3method(get_predictions,forecast_secondary) S3method(get_samples,estimate_infections) +S3method(get_samples,estimate_secondary) S3method(get_samples,forecast_infections) +S3method(get_samples,forecast_secondary) S3method(is_constrained,dist_spec) S3method(is_constrained,multi_dist_spec) S3method(max,dist_spec) @@ -24,6 +31,7 @@ S3method(plot,estimate_infections) S3method(plot,estimate_secondary) S3method(plot,estimate_truncation) S3method(plot,forecast_infections) +S3method(plot,forecast_secondary) S3method(print,dist_spec) S3method(print,epinowfit) S3method(sd,default) @@ -31,6 +39,7 @@ S3method(sd,dist_spec) S3method(sd,multi_dist_spec) S3method(summary,epinow) S3method(summary,estimate_infections) +S3method(summary,estimate_secondary) S3method(summary,forecast_infections) export(Fixed) export(Gamma) @@ -79,6 +88,7 @@ export(generation_time_opts) export(get_distribution) export(get_parameters) 
export(get_pmf) +export(get_predictions) export(get_regional_results) export(get_samples) export(gp_opts) diff --git a/NEWS.md b/NEWS.md index 16c0d3c5e..bf8119f93 100644 --- a/NEWS.md +++ b/NEWS.md @@ -17,6 +17,13 @@ - **Deprecated**: `summary(object, type = "samples")` now issues a deprecation warning. Use `get_samples(object)` instead. - **Deprecated**: Internal function `extract_parameter_samples()` renamed to `format_simulation_output()` for clarity. - `forecast_infections()` now returns an independent S3 class `"forecast_infections"` instead of inheriting from `"estimate_infections"`. This clarifies the distinction between fitted models (which contain a Stan fit for diagnostics) and forecast simulations (which contain pre-computed samples). Dedicated `summary()`, `plot()`, and `get_samples()` methods are provided. +- `estimate_secondary()` now returns an S3 object of class `c("epinowfit", "estimate_secondary", "list")` with elements `fit`, `args`, and `observations`, matching the structure of `estimate_infections()`. + - Use `get_samples(object)` to extract formatted posterior samples for delay and scaling parameters. + - Use `get_predictions(object)` to get predicted secondary observations with credible intervals merged with observations. + - Use `summary(object)` to get summarised parameter estimates. Use `type = "compact"` for key parameters only, or `type = "parameters"` with a `params` argument to select specific parameters. + - Access the Stan fit directly via `object$fit`, model arguments via `object$args`, and observations via `object$observations`. + - **Deprecated**: The previous return structure with `predictions`, `posterior`, and `data` elements is deprecated and will be removed in a future release. Backward compatibility is provided with deprecation warnings when accessing these elements via `$`. 
+- `forecast_secondary()` now returns an independent S3 class `"forecast_secondary"` instead of inheriting from `"estimate_secondary"`, with dedicated `get_samples()`, `get_predictions()`, and `plot()` methods. - `plot.estimate_infections()` and `plot.forecast_infections()` now accept a `CrIs` argument to control which credible intervals are displayed. ## Model changes diff --git a/R/checks.R b/R/checks.R index 618a30b3f..dc7a763d0 100644 --- a/R/checks.R +++ b/R/checks.R @@ -199,7 +199,8 @@ check_sparse_pmf_tail <- function(pmf, span = 5, tol = 1e-6) { #' @keywords internal check_truncation_length <- function(stan_args, time_points) { # Check if truncation exists - if (is.null(stan_args$data$trunc_id) || stan_args$data$trunc_id == 0) { + if (is.null(stan_args$data$delay_id_truncation) || + stan_args$data$delay_id_truncation == 0) { return(invisible()) } @@ -210,9 +211,11 @@ check_truncation_length <- function(stan_args, time_points) { # Map truncation to its position in the flat delays array # delay_types_groups gives start and end indices for each delay type - trunc_start <- stan_args$data$delay_types_groups[stan_args$data$trunc_id] + trunc_start <- stan_args$data$delay_types_groups[ + stan_args$data$delay_id_truncation + ] trunc_end <- stan_args$data$delay_types_groups[ - stan_args$data$trunc_id + 1 + stan_args$data$delay_id_truncation + 1 ] - 1 # Get which truncation delays are non-parametric diff --git a/R/create.R b/R/create.R index 991ef36d7..1feed9f23 100644 --- a/R/create.R +++ b/R/create.R @@ -370,7 +370,8 @@ create_gp_data <- function(gp = gp_opts(), data) { w0 = gp$w0 ) - c(data, gp_data) + gp_data <- c(data, gp_data) + gp_data } #' Create Observation Model Settings @@ -686,6 +687,8 @@ create_stan_args <- function(stan = stan_opts(), ##' @keywords internal create_stan_delays <- function(..., time_points = 1L) { delays <- list(...) 
+ delay_names <- names(delays) + ## discretise delays <- map(delays, discretise, strict = FALSE) delays <- map(delays, collapse) @@ -694,10 +697,12 @@ create_stan_delays <- function(..., time_points = 1L) { max_delay <- unname(as.numeric(flatten(map(bounded_delays, max)))) ## number of different non-empty types type_n <- vapply(delays, ndist, integer(1)) - ## assign ID values to each type - ids <- rep(0L, length(type_n)) - ids[type_n > 0] <- seq_len(sum(type_n > 0)) - names(ids) <- paste(names(type_n), "id", sep = "_") + + ## Create delay_id_* variables pointing to delay_types_groups index + ## Similar to param_id_* in create_stan_params() + delay_ids <- rep(0L, length(type_n)) + delay_ids[type_n > 0] <- seq_len(sum(type_n > 0)) + names(delay_ids) <- paste("delay_id", delay_names, sep = "_") ## create "flat version" of delays, i.e. a list of all the delays (including ## elements of composite delays) @@ -773,7 +778,7 @@ create_stan_delays <- function(..., time_points = 1L) { ret$dist <- array(match(distributions, c("lognormal", "gamma")) - 1L) names(ret) <- paste("delay", names(ret), sep = "_") - ret <- c(ret, ids) + ret <- c(ret, as.list(delay_ids)) ret } diff --git a/R/estimate_infections.R b/R/estimate_infections.R index c6c2261f3..2dfea824f 100644 --- a/R/estimate_infections.R +++ b/R/estimate_infections.R @@ -73,8 +73,9 @@ #' @export #' @return An `` object which is a list of outputs #' including: the stan object (`fit`), arguments used to fit the model -#' (`args`), and the observed data (`observations`). The estimates included in -#' the fit can be accessed using the `summary()` function. +#' (`args`), and the observed data (`observations`). Use `summary()` to access +#' estimates, `get_samples()` to extract posterior samples, and +#' `get_predictions()` to access predicted reported cases. 
#' #' @seealso [epinow()] [regional_epinow()] [forecast_infections()] #' [estimate_truncation()] @@ -264,9 +265,9 @@ estimate_infections <- function(data, ) stan_data <- c(stan_data, create_stan_delays( - gt = generation_time, - delay = delays, - trunc = truncation, + generation_time = generation_time, + reporting = delays, + truncation = truncation, time_points = stan_data$t - stan_data$seeding_time - stan_data$horizon )) diff --git a/R/estimate_secondary.R b/R/estimate_secondary.R index b11da4ab2..21371a943 100644 --- a/R/estimate_secondary.R +++ b/R/estimate_secondary.R @@ -56,11 +56,11 @@ #' @param verbose Logical, should model fitting progress be returned. Defaults #' to [interactive()]. #' -#' @return A list containing: `predictions` (a `` ordered by date -#' with the primary, and secondary observations, and a summary of the model -#' estimated secondary observations), `posterior` which contains a summary of -#' the entire model posterior, `data` (a list of data used to fit the -#' model), and `fit` (the `stanfit` object). +#' @return An `<estimate_secondary>` object which is a list of outputs +#' including: the stan object (`fit`), arguments used to fit the model +#' (`args`), and the observed data (`observations`). Use `summary()` to access +#' estimates, `get_samples()` to extract posterior samples, and +#' `get_predictions()` to access predictions. 
#' @export #' @inheritParams estimate_infections #' @inheritParams update_secondary_args @@ -232,8 +232,8 @@ estimate_secondary <- function(data, stan_data <- c(stan_data, secondary) # delay data stan_data <- c(stan_data, create_stan_delays( - delay = delays, - trunc = truncation, + reporting = delays, + truncation = truncation, time_points = stan_data$t )) @@ -266,22 +266,15 @@ estimate_secondary <- function(data, fit <- fit_model(stan_, id = "estimate_secondary") - out <- list() - out$predictions <- extract_stan_param(fit, "sim_secondary", CrIs = CrIs) - out$predictions <- out$predictions[, lapply(.SD, round, 1)] - out$predictions <- out$predictions[, date := reports[(burn_in + 1):.N]$date] - out$predictions <- data.table::merge.data.table( - reports, out$predictions, - all = TRUE, by = "date" - ) - out$posterior <- extract_stan_param( - fit, - CrIs = CrIs + # Create standardized S3 return structure + ret <- list( + fit = fit, + args = stan_data, + observations = reports ) - out$data <- stan_data - out$fit <- fit - class(out) <- c("estimate_secondary", class(out)) - out + + class(ret) <- c("epinowfit", "estimate_secondary", class(ret)) + ret } #' Update estimate_secondary default priors @@ -381,7 +374,7 @@ plot.estimate_secondary <- function(x, primary = FALSE, from = NULL, to = NULL, new_obs = NULL, ...) { - predictions <- data.table::copy(x$predictions) + predictions <- get_predictions(x) if (!is.null(new_obs)) { new_obs <- data.table::as.data.table(new_obs) @@ -428,6 +421,23 @@ plot.estimate_secondary <- function(x, primary = FALSE, ggplot2::theme(axis.text.x = ggplot2::element_text(angle = 90)) } +#' Plot method for forecast_secondary objects +#' +#' @description `r lifecycle::badge("stable")` +#' Plot method for forecast secondary observations. +#' +#' @inheritParams plot.estimate_secondary +#' @method plot forecast_secondary +#' @export +plot.forecast_secondary <- function(x, primary = FALSE, + from = NULL, to = NULL, + new_obs = NULL, + ...) 
{ + plot.estimate_secondary( + x, primary = primary, from = from, to = to, new_obs = new_obs, ... + ) +} + #' Convolve and scale a time series #' #' This applies a lognormal convolution with given, potentially time-varying @@ -623,7 +633,7 @@ forecast_secondary <- function(estimate, if (inherits(primary, "estimate_infections")) { primary_samples <- get_samples(primary) primary <- primary_samples[variable == primary_variable] - primary <- primary[date > max(estimate$predictions$date, na.rm = TRUE)] + primary <- primary[date > max(get_predictions(estimate)$date, na.rm = TRUE)] primary <- primary[, .(date, sample, value)] if (!is.null(samples)) { primary <- primary[sample(.N, samples, replace = TRUE)] @@ -641,10 +651,10 @@ forecast_secondary <- function(estimate, include = FALSE ) # extract data from stanfit - stan_data <- estimate$data + stan_data <- estimate$args # combined primary from data and input primary - primary_fit <- estimate$predictions[ + primary_fit <- get_predictions(estimate)[ , .(date, value = primary, sample = list(unique(updated_primary$sample))) ] @@ -721,7 +731,7 @@ forecast_secondary <- function(estimate, # link previous prediction observations with forecast observations forecast_obs <- data.table::rbindlist( list( - estimate$predictions[, .(date, primary, secondary)], + get_predictions(estimate)[, .(date, primary, secondary)], data.table::copy(primary)[, .(primary = median(value)), by = "date"] ), use.names = TRUE, fill = TRUE @@ -735,6 +745,55 @@ forecast_secondary <- function(estimate, data.table::setcolorder( out$predictions, c("date", "primary", "secondary", "mean", "sd") ) - class(out) <- c("estimate_secondary", class(out)) + class(out) <- c("forecast_secondary", class(out)) out } + +#' Extract elements from estimate_secondary objects with deprecated warnings +#' +#' @description `r lifecycle::badge("deprecated")` +#' Provides backward compatibility for the old return structure. 
The previous +#' structure with \code{predictions}, \code{posterior}, and \code{data} +#' elements is deprecated. Use the accessor methods instead: +#' \itemize{ +#' \item \code{predictions} - use \code{get_predictions(object)} +#' \item \code{posterior} - use \code{get_samples(object)} +#' \item \code{data} - use \code{object$args} +#' } +#' +#' @param x An \code{estimate_secondary} object +#' @param name The name of the element to extract +#' @return The requested element with a deprecation warning +#' @keywords internal +#' @export +#' @method $ estimate_secondary +`$.estimate_secondary` <- function(x, name) { + switch(name, + predictions = { + lifecycle::deprecate_warn( + "1.8.0", + "estimate_secondary()$predictions", + "get_predictions()" + ) + get_predictions(x) + }, + posterior = { + lifecycle::deprecate_warn( + "1.8.0", + "estimate_secondary()$posterior", + "get_samples()" + ) + get_samples(x) + }, + # The old `data` element held the Stan data list, which is now `args` + data = { + lifecycle::deprecate_warn( + "1.8.0", + "estimate_secondary()$data", + "estimate_secondary()$args" + ) + x[["args"]] + }, + # For other elements, use normal list extraction + NextMethod("$") + ) +} diff --git a/R/estimate_truncation.R b/R/estimate_truncation.R index 8ccd16206..2114fd2c6 100644 --- a/R/estimate_truncation.R +++ b/R/estimate_truncation.R @@ -184,7 +184,7 @@ estimate_truncation <- function(data, ) stan_data <- c(stan_data, create_stan_delays( - trunc = truncation, + truncation = truncation, time_points = stan_data$t )) diff --git a/R/extract.R b/R/extract.R index ffb40a516..78a594c50 100644 --- a/R/extract.R +++ b/R/extract.R @@ -1,4 +1,4 @@ -#' Extract Samples for a Latent State from a Stan model +#' Extract samples for a latent state from a Stan model #' #' @description `r lifecycle::badge("stable")` #' Extracts a time-varying latent state from a list of stan output and returns @@ -51,31 +51,126 @@ extract_latent_state <- function(param, samples, dates) { } -#' Extract Samples from a Parameter with a Single 
Dimension +#' Extract samples from all parameters #' -#' @param param Character string indicating the parameter to extract #' @param samples Extracted stan model (using [rstan::extract()]) -#' @return A `<data.table>` containing the parameter name, sample id and sample -#' value, or NULL if the parameter doesn't exist in the samples +#' @return A `<data.table>` containing the parameter name, sample id and sample +#' value, or NULL if parameters don't exist in the samples #' @keywords internal -extract_parameter <- function(param, samples) { - id_name <- paste("param_id", param, sep = "_") +extract_parameters <- function(samples) { + # Check if params exist + if (!("params" %in% names(samples))) { + return(NULL) + } + + # Extract all parameters + param_array <- samples[["params"]] + n_cols <- ncol(param_array) + + # Build reverse lookup: column index -> parameter name + param_names <- rep(NA_character_, n_cols) + + # Check all param_id_* variables to build the mapping + id_vars <- grep("^param_id_", names(samples), value = TRUE) + for (id_var in id_vars) { + param_name <- sub("^param_id_", "", id_var) + id <- samples[[id_var]] + if (!is.na(id) && id > 0) { + lookup_idx <- samples[["params_variable_lookup"]][id] + if (!is.na(lookup_idx) && lookup_idx > 0 && lookup_idx <= n_cols) { + param_names[lookup_idx] <- param_name + } + } + } - # Return NULL if parameter ID doesn't exist - if (!(id_name %in% names(samples))) { + # Extract all columns + samples_list <- lapply(seq_len(n_cols), function(i) { + # Use named parameter if available, otherwise use indexed name + par_name <- if (!is.na(param_names[i])) { + param_names[i] + } else { + paste0("params[", i, "]") + } + + data.table::data.table( + parameter = par_name, + sample = seq_along(param_array[, i]), + value = param_array[, i] + ) + }) + + data.table::rbindlist(samples_list) +} + +#' Extract samples from all delay parameters +#' +#' Extracts samples from all delay parameters using the delay ID lookup system. 
+#' Similar to extract_parameters(), this extracts all delay distribution +#' parameters and uses the delay_id_* variables (e.g., delay_id_reporting, +#' delay_id_truncation) to assign meaningful names. +#' +#' @param samples Extracted stan model (using [rstan::extract()]) +#' @return A `<data.table>` with columns: parameter, sample, value, or NULL if +#' delay parameters don't exist in the samples +#' @keywords internal +extract_delays <- function(samples) { + # Check if delay_params exist + if (!("delay_params" %in% names(samples))) { return(NULL) } - id <- samples[[id_name]] + # Extract all delay parameters + delay_params <- samples[["delay_params"]] + n_cols <- ncol(delay_params) - lookup <- samples[["params_variable_lookup"]][id] - data.table::data.table( - parameter = param, - sample = seq_along(samples[["params"]][, lookup]), - value = samples[["params"]][, lookup] - ) + # Build reverse lookup: column index -> delay name + delay_names <- rep(NA_character_, n_cols) + + # Check all delay_id_* variables to build the mapping + id_vars <- grep("^delay_id_", names(samples), value = TRUE) + if (length(id_vars) > 0 && "delay_types_groups" %in% names(samples)) { + delay_types_groups <- samples[["delay_types_groups"]] + + for (id_var in id_vars) { + delay_name <- sub("^delay_id_", "", id_var) + id <- samples[[id_var]][1] # Take first value (same across samples) + + # Check if this delay exists (ID > 0) + if (!is.na(id) && id > 0 && id < length(delay_types_groups)) { + start_idx <- delay_types_groups[id] + end_idx <- delay_types_groups[id + 1] - 1 + + # NOTE(review): check_truncation_length() describes delay_types_groups as + # indexing the flat delays array; confirm these indices also correspond + # to delay_params columns before relying on the names assigned below + # Mark columns for this delay + for (i in seq_along(start_idx:end_idx)) { + col_idx <- start_idx + i - 1 + if (col_idx <= n_cols) { + delay_names[col_idx] <- paste0(delay_name, "[", i, "]") + } + } + } + } + } + + # Extract all columns + samples_list <- lapply(seq_len(n_cols), function(i) { + # Use named delay if available, otherwise use indexed name + par_name <- if (!is.na(delay_names[i])) { + delay_names[i] + } else { + paste0("delay_params[", i, "]") + } + 
+ data.table::data.table( + parameter = par_name, + sample = seq_along(delay_params[, i]), + value = delay_params[, i] + ) + }) + + data.table::rbindlist(samples_list) } + #' Extract all samples from a stan fit #' #' If the `object` argument is a `` object, it simply returns the @@ -151,7 +246,7 @@ extract_samples <- function(stan_fit, pars = NULL, include = TRUE) { samples } -#' Extract Parameter Samples from a Stan Model +#' Extract parameter samples from a Stan model #' #' @description `r lifecycle::badge("deprecated")` #' This function has been deprecated. Use [format_simulation_output()] for @@ -180,7 +275,7 @@ extract_parameter_samples <- function(stan_fit, data, reported_dates, ) } -#' Extract a Parameter Summary from a Stan Object +#' Extract a parameter summary from a Stan object #' #' @description `r lifecycle::badge("stable")` #' Extracts summarised parameter posteriors from a `stanfit` object using diff --git a/R/format.R b/R/format.R index 90d9a1ca8..22cc9d269 100644 --- a/R/format.R +++ b/R/format.R @@ -191,20 +191,22 @@ format_simulation_output <- function(stan_fit, data, reported_dates, ] } # Auto-detect and extract all static parameters from params matrix - # Find all parameter IDs (names starting with "param_id_") - param_id_names <- names(samples)[startsWith(names(samples), "param_id_")] - param_names <- sub("^param_id_", "", param_id_names) + all_params <- extract_parameters(samples) + if (!is.null(all_params)) { + # Get unique parameter names + param_names <- unique(all_params$parameter) - for (param in param_names) { - result <- extract_parameter(param, samples) - if (!is.null(result)) { - # Use standard naming conventions - param_name <- switch(param, - "dispersion" = "reporting_overdispersion", - "frac_obs" = "fraction_observed", - param # default: use param name as-is - ) - out[[param_name]] <- result + for (param in param_names) { + result <- all_params[parameter == param] + if (nrow(result) > 0) { + # Use standard naming conventions + 
param_name <- switch(param, + "dispersion" = "reporting_overdispersion", + "frac_obs" = "fraction_observed", + param # default: use param name as-is + ) + out[[param_name]] <- result + } } } out @@ -279,34 +281,11 @@ format_samples_with_dates <- function(raw_samples, args, observations) { # Delay parameters if (args$delay_params_length > 0) { - delay_params <- extract_latent_state( - "delay_params", raw_samples, seq_len(args$delay_params_length) - ) - if (!is.null(delay_params)) { - out$delay_params <- delay_params[, strat := as.character(time)][ - , time := NULL - ][, date := NULL] - } + out$delay_params <- extract_delays(raw_samples) } - # Auto-detect and extract all static parameters from params matrix - param_id_names <- names(raw_samples)[ - startsWith(names(raw_samples), "param_id_") - ] - param_names <- sub("^param_id_", "", param_id_names) - - for (param in param_names) { - result <- extract_parameter(param, raw_samples) - if (!is.null(result)) { - # Use standard naming conventions - param_name <- switch(param, - "dispersion" = "reporting_overdispersion", - "frac_obs" = "fraction_observed", - param # default: use param name as-is - ) - out[[param_name]] <- result - } - } + # Params matrix + out$params <- extract_parameters(raw_samples) # Combine all parameters into single data.table combined <- data.table::rbindlist(out, fill = TRUE, idcol = "variable") diff --git a/R/get.R b/R/get.R index 44ba4a8ce..e197d59fa 100644 --- a/R/get.R +++ b/R/get.R @@ -262,5 +262,155 @@ get_samples.estimate_infections <- function(object, ...) { #' @rdname get_samples #' @export get_samples.forecast_infections <- function(object, ...) { - object$samples + data.table::copy(object$samples) +} + +#' @rdname get_samples +#' @export +get_samples.estimate_secondary <- function(object, ...) 
{ + # Extract raw posterior samples from the fit + raw_samples <- extract_samples(object$fit) + + # Extract parameters (delays and params) + samples_list <- list( + extract_delays(raw_samples), + extract_parameters(raw_samples) + ) + + # Extract time-varying generated quantities + # Get dates for time-indexed parameters (post burn-in) + burn_in <- object$args$burn_in + dates <- object$observations[(burn_in + 1):.N]$date + + # Extract sim_secondary (generated quantity, post burn-in) if available + sim_secondary_samples <- extract_latent_state( + "sim_secondary", raw_samples, dates + ) + if (!is.null(sim_secondary_samples)) { + samples_list <- c(samples_list, list(sim_secondary_samples)) + } + + # Combine all samples + samples <- data.table::rbindlist(samples_list, fill = TRUE) + + # Rename 'parameter' column to 'variable' for consistency if needed + if ("parameter" %in% names(samples)) { + data.table::setnames(samples, "parameter", "variable") + } + + # Add placeholder columns for consistency with estimate_infections format + # Only add if not already present + if (!"date" %in% names(samples)) samples[, date := as.Date(NA)] + if (!"strat" %in% names(samples)) samples[, strat := NA_character_] + if (!"time" %in% names(samples)) samples[, time := NA_integer_] + if (!"type" %in% names(samples)) samples[, type := NA_character_] + + # Reorder columns + data.table::setcolorder( + samples, + c("date", "variable", "strat", "sample", "time", "value", "type") + ) + + samples[] +} + +#' @rdname get_samples +#' @export +get_samples.forecast_secondary <- function(object, ...) { + data.table::copy(object$samples) +} + +#' Get predictions from a fitted model +#' +#' @description `r lifecycle::badge("stable")` +#' Extracts predictions from a fitted model, combining observations with model +#' estimates. For `estimate_infections()` returns predicted reported cases, for +#' `estimate_secondary()` returns predicted secondary observations. 
+#' +#' @param object A fitted model object (e.g., from `estimate_infections()` or +#' `estimate_secondary()`) +#' @param CrIs Numeric vector of credible intervals to return. Defaults to +#' c(0.2, 0.5, 0.9). +#' @param ... Additional arguments (currently unused) +#' +#' @return A `data.table` with columns including date, observations, and summary +#' statistics (mean, sd, credible intervals) for the model predictions. +#' +#' @export +#' @examples +#' \dontrun{ +#' # After fitting a model +#' predictions <- get_predictions(fit) +#' } +get_predictions <- function(object, ...) { + UseMethod("get_predictions") +} + +#' @rdname get_predictions +#' @export +get_predictions.estimate_infections <- function(object, + CrIs = c(0.2, 0.5, 0.9), + ...) { + # Get samples for reported cases + samples <- get_samples(object) + reported_samples <- samples[variable == "reported_cases"] + + # Calculate summary measures + predictions <- calc_summary_measures( + reported_samples, + summarise_by = "date", + order_by = "date", + CrIs = CrIs + ) + + # Merge with observations + predictions <- data.table::merge.data.table( + object$observations[, .(date, confirm)], + predictions, + by = "date", + all = TRUE + ) + + predictions +} + +#' @rdname get_predictions +#' @export +get_predictions.estimate_secondary <- function(object, + CrIs = c(0.2, 0.5, 0.9), + ...) 
{ + # Get samples for simulated secondary observations + samples <- get_samples(object) + sim_secondary_samples <- samples[variable == "sim_secondary"] + + # Calculate summary measures + predictions <- calc_summary_measures( + sim_secondary_samples, + summarise_by = "date", + order_by = "date", + CrIs = CrIs + ) + + # Round predictions + predictions <- predictions[, lapply(.SD, round, 1)] + + # Merge with observations + predictions <- data.table::merge.data.table( + object$observations, predictions, + all = TRUE, by = "date" + ) + + predictions +} + +#' @rdname get_predictions +#' @export +get_predictions.forecast_infections <- function(object, ...) { + data.table::copy(object$predictions) +} + +#' @rdname get_predictions +#' @export +get_predictions.forecast_secondary <- function(object, ...) { + data.table::copy(object$predictions) } diff --git a/R/simulate_infections.R b/R/simulate_infections.R index 306637d87..8f2208bef 100644 --- a/R/simulate_infections.R +++ b/R/simulate_infections.R @@ -135,9 +135,9 @@ simulate_infections <- function(R, ) stan_data <- c(stan_data, create_stan_delays( - gt = generation_time, - delay = delays, - trunc = truncation + generation_time = generation_time, + reporting = delays, + truncation = truncation )) if (length(stan_data$delay_params_sd) > 0 && diff --git a/R/simulate_secondary.R b/R/simulate_secondary.R index a436ff18d..3dfd246cf 100644 --- a/R/simulate_secondary.R +++ b/R/simulate_secondary.R @@ -70,8 +70,8 @@ simulate_secondary <- function(primary, stan_data <- c(stan_data, secondary) stan_data <- c(stan_data, create_stan_delays( - delay = delays, - trunc = truncation + reporting = delays, + truncation = truncation )) if (length(stan_data$delay_params_sd) > 0 && diff --git a/R/summarise.R b/R/summarise.R index 20b1c47aa..1856fad44 100644 --- a/R/summarise.R +++ b/R/summarise.R @@ -885,3 +885,61 @@ summary.forecast_infections <- function(object, print.epinowfit <- function(x, ...) 
{ print(summary(x)) } + +#' Summarise results from estimate_secondary +#' +#' @description `r lifecycle::badge("stable")` +#' Returns a summary of the fitted secondary model including posterior +#' parameter estimates with credible intervals. +#' +#' @param object A fitted model object from `estimate_secondary()` +#' @param type Character string indicating the type of summary to return. +#' Options are "compact" (default, key parameters only) or "parameters" +#' (all parameters or filtered set). +#' @param params Character vector of parameter names to include. Only used +#' when `type = "parameters"`. If NULL (default), returns all parameters. +#' @param CrIs Numeric vector of credible intervals to return. Defaults to +#' c(0.2, 0.5, 0.9). +#' @param ... Additional arguments (currently unused) +#' +#' @return A `<data.table>` with summary statistics (mean, sd, median, +#' credible intervals) for model parameters. When `type = "compact"`, +#' returns only key parameters (delay distribution parameters and scaling +#' factors). When `type = "parameters"`, returns all or filtered parameters. +#' @importFrom rlang arg_match +#' @method summary estimate_secondary +#' @export +summary.estimate_secondary <- function(object, + type = c("compact", "parameters"), + params = NULL, + CrIs = c(0.2, 0.5, 0.9), ...) 
{ + type <- arg_match(type) + + # Get all posterior samples + samples <- get_samples(object) + + # Filter to non-time-varying parameters (delay_params and params) + # Time-varying parameters like secondary and sim_secondary have dates + param_samples <- samples[is.na(date)] + + # Calculate summary statistics + out <- calc_summary_measures( + param_samples, + summarise_by = "variable", + order_by = "variable", + CrIs = CrIs + ) + + if (type == "compact") { + # Return only key parameters for a compact summary + # Typical parameters: delay_params (distribution parameters), + # params (scaling factors) + key_vars <- c("delay_params", "params", "frac_obs") + out <- out[grepl(paste(key_vars, collapse = "|"), variable)] + } else if (type == "parameters" && !is.null(params)) { + # Optional filtering by parameter name + out <- out[variable %in% params] + } + + out[] +} diff --git a/R/utilities.R b/R/utilities.R index 53817fd2e..4363396fe 100644 --- a/R/utilities.R +++ b/R/utilities.R @@ -506,6 +506,7 @@ globalVariables( "..lowers", "..upper_CrI", "..uppers", "timing", "dataset", "last_confirm", "report_date", "secondary", "id", "conv", "meanlog", "primary", "scaled", "scaling", "sdlog", "lookup", "new_draw", ".draw", "p", "distribution", - "accumulate", "..present", "reported_cases", "counter", "future_accumulate" + "accumulate", "..present", "reported_cases", "counter", "future_accumulate", + "parameter" ) ) diff --git a/inst/dev/stan-to-R.R b/inst/dev/stan-to-R.R index cda56e932..1e2556653 100644 --- a/inst/dev/stan-to-R.R +++ b/inst/dev/stan-to-R.R @@ -118,10 +118,10 @@ simulate <- function(data, } if (estimate_r) { gt_rev_pmf <- get_delay_rev_pmf( - gt_id, delay_type_max[gt_id] + 1, delay_types_p, delay_types_id, - delay_types_groups, delay_max, delay_np_pmf, - delay_np_pmf_groups, delay_params, delay_params_groups, delay_dist, - 1, 1, 0 + delay_id_generation_time, delay_type_max[delay_id_generation_time] + 1, + delay_types_p, delay_types_id, delay_types_groups, 
delay_max, + delay_np_pmf, delay_np_pmf_groups, delay_params, delay_params_groups, + delay_dist, 1, 1, 0 ) R0 <- get_param( param_id_R0, params_fixed_lookup, params_variable_lookup, params_value, params @@ -146,13 +146,13 @@ simulate <- function(data, shifted_cases, noise, fixed, backcalc_prior ) } - delay_rev_pmf <- get_delay_rev_pmf( - delay_id, delay_type_max[delay_id] + 1, delay_types_p, delay_types_id, - delay_types_groups, delay_max, delay_np_pmf, - delay_np_pmf_groups, delay_params, delay_params_groups, delay_dist, - 0, 1, 0 + reporting_rev_pmf <- get_delay_rev_pmf( + delay_id_reporting, delay_type_max[delay_id_reporting] + 1, delay_types_p, + delay_types_id, delay_types_groups, delay_max, delay_np_pmf, + delay_np_pmf_groups, delay_params, delay_params_groups, delay_dist, 0, 1, + 0 ) - reports <- convolve_to_report(infections, delay_rev_pmf, seeding_time) + reports <- convolve_to_report(infections, reporting_rev_pmf, seeding_time) if (week_effect > 1) { reports <- day_of_week_effect(reports, day_of_week, day_of_week_simplex) } @@ -163,12 +163,12 @@ simulate <- function(data, ) reports <- scale_obs(reports, frac_obs) } - if (trunc_id) { + if (delay_id_truncation) { trunc_rev_cmf <- get_delay_rev_pmf( - trunc_id, delay_type_max[trunc_id] + 1, delay_types_p, delay_types_id, - delay_types_groups, delay_max, delay_np_pmf, - delay_np_pmf_groups, delay_params, delay_params_groups, delay_dist, - 0, 1, 1 + delay_id_truncation, delay_type_max[delay_id_truncation] + 1, + delay_types_p, delay_types_id, delay_types_groups, delay_max, + delay_np_pmf, delay_np_pmf_groups, delay_params, delay_params_groups, + delay_dist, 0, 1, 1 ) obs_reports <- truncate_obs(reports[1:ot], trunc_rev_cmf, 0) } else { diff --git a/inst/stan/data/observation_model.stan b/inst/stan/data/observation_model.stan index 715d4db52..b0b840930 100644 --- a/inst/stan/data/observation_model.stan +++ b/inst/stan/data/observation_model.stan @@ -5,5 +5,5 @@ int obs_scale; // logical controlling scaling of 
observations real obs_weight; // weight given to observation in log density int likelihood; // Should the likelihood be included in the model int return_likelihood; // Should the likehood be returned by the model -int trunc_id; // id of truncation -int delay_id; // id of delay +int delay_id_truncation; // id of truncation delay +int delay_id_reporting; // id of reporting delay diff --git a/inst/stan/data/rt.stan b/inst/stan/data/rt.stan index 3fb1fbad5..79d24a6f3 100644 --- a/inst/stan/data/rt.stan +++ b/inst/stan/data/rt.stan @@ -5,5 +5,5 @@ int future_fixed; // is underlying future Rt assumed to be fixed int fixed_from; // Reference date for when Rt estimation should be fixed int use_pop; // use population size (0 = no; 1 = forecasts; 2 = all) real pop_floor; // Minimum susceptible population (numerical stability floor) -int gt_id; // id of generation time +int delay_id_generation_time; // id of generation time int growth_method; // method to compute growth rate (0 = infections, 1 = infectiousness) diff --git a/inst/stan/data/simulation_delays.stan b/inst/stan/data/simulation_delays.stan index c6fec4eb5..d90f7f336 100644 --- a/inst/stan/data/simulation_delays.stan +++ b/inst/stan/data/simulation_delays.stan @@ -25,4 +25,4 @@ array[delay_n] int delay_types_id; // index of each delay (parametric or non) array[delay_types + 1] int delay_types_groups; -int delay_id; // id of generation time +int delay_id_reporting; // id of reporting delay diff --git a/inst/stan/data/simulation_observation_model.stan b/inst/stan/data/simulation_observation_model.stan index db40e95d3..69acaefbe 100644 --- a/inst/stan/data/simulation_observation_model.stan +++ b/inst/stan/data/simulation_observation_model.stan @@ -3,4 +3,4 @@ int week_effect; // should a day of the week effect be estimated array[n, week_effect] real day_of_week_simplex; int obs_scale; int model_type; -int trunc_id; // id of truncation +int delay_id_truncation; // id of truncation delay diff --git 
a/inst/stan/data/simulation_rt.stan b/inst/stan/data/simulation_rt.stan index bab9fc6ef..e7d3f757e 100644 --- a/inst/stan/data/simulation_rt.stan +++ b/inst/stan/data/simulation_rt.stan @@ -5,7 +5,7 @@ matrix[n, t - seeding_time] R; // reproduction number int use_pop; // use population size (0 = no; 1 = forecasts; 2 = all) real pop_floor; // Minimum susceptible population (numerical stability floor) -int gt_id; // id of generation time +int delay_id_generation_time; // id of generation time int growth_method; // method to compute growth rate (0 = infections, 1 = infectiousness) diff --git a/inst/stan/estimate_infections.stan b/inst/stan/estimate_infections.stan index e651e6c3c..3817dac39 100644 --- a/inst/stan/estimate_infections.stan +++ b/inst/stan/estimate_infections.stan @@ -75,7 +75,8 @@ transformed parameters { vector[t] infections; // latent infections vector[ot_h] reports; // estimated reported cases vector[ot] obs_reports; // observed estimated reported cases - vector[estimate_r * (delay_type_max[gt_id] + 1)] gt_rev_pmf; + vector[estimate_r * (delay_type_max[delay_id_generation_time] + 1)] + gt_rev_pmf; // GP in noise - spectral densities profile("update gp") { @@ -98,10 +99,10 @@ transformed parameters { if (estimate_r) { profile("gt") { gt_rev_pmf = get_delay_rev_pmf( - gt_id, delay_type_max[gt_id] + 1, delay_types_p, delay_types_id, - delay_types_groups, delay_max, delay_np_pmf, - delay_np_pmf_groups, delay_params, delay_params_groups, delay_dist, - 1, 1, 0 + delay_id_generation_time, delay_type_max[delay_id_generation_time] + 1, + delay_types_p, delay_types_id, delay_types_groups, delay_max, + delay_np_pmf, delay_np_pmf_groups, delay_params, delay_params_groups, + delay_dist, 1, 1, 0 ); } profile("R0") { @@ -136,18 +137,18 @@ transformed parameters { } // convolve from latent infections to mean of observations - if (delay_id) { - vector[delay_type_max[delay_id] + 1] delay_rev_pmf; + if (delay_id_reporting) { + vector[delay_type_max[delay_id_reporting] 
+ 1] reporting_rev_pmf; profile("delays") { - delay_rev_pmf = get_delay_rev_pmf( - delay_id, delay_type_max[delay_id] + 1, delay_types_p, delay_types_id, - delay_types_groups, delay_max, delay_np_pmf, - delay_np_pmf_groups, delay_params, delay_params_groups, delay_dist, - 0, 1, 0 + reporting_rev_pmf = get_delay_rev_pmf( + delay_id_reporting, delay_type_max[delay_id_reporting] + 1, + delay_types_p, delay_types_id, delay_types_groups, delay_max, + delay_np_pmf, delay_np_pmf_groups, delay_params, delay_params_groups, + delay_dist, 0, 1, 0 ); } profile("reports") { - reports = convolve_to_report(infections, delay_rev_pmf, seeding_time); + reports = convolve_to_report(infections, reporting_rev_pmf, seeding_time); } } else { reports = infections[(seeding_time + 1):t]; @@ -172,14 +173,14 @@ transformed parameters { } // truncate near time cases to observed reports - if (trunc_id) { - vector[delay_type_max[trunc_id] + 1] trunc_rev_cmf; + if (delay_id_truncation) { + vector[delay_type_max[delay_id_truncation] + 1] trunc_rev_cmf; profile("truncation") { trunc_rev_cmf = get_delay_rev_pmf( - trunc_id, delay_type_max[trunc_id] + 1, delay_types_p, delay_types_id, - delay_types_groups, delay_max, delay_np_pmf, - delay_np_pmf_groups, delay_params, delay_params_groups, delay_dist, - 0, 1, 1 + delay_id_truncation, delay_type_max[delay_id_truncation] + 1, + delay_types_p, delay_types_id, delay_types_groups, delay_max, + delay_np_pmf, delay_np_pmf_groups, delay_params, delay_params_groups, + delay_dist, 0, 1, 1 ); } profile("truncate") { @@ -264,19 +265,22 @@ generated quantities { } { - vector[delay_type_max[gt_id] + 1] gt_rev_pmf_for_growth; - + vector[delay_type_max[delay_id_generation_time] + 1] + gt_rev_pmf_for_growth; + if (estimate_r == 0) { // sample generation time vector[delay_params_length] delay_params_sample = to_vector(normal_lb_rng( delay_params_mean, delay_params_sd, delay_params_lower )); - vector[delay_type_max[gt_id] + 1] sampled_gt_rev_pmf = get_delay_rev_pmf( - 
gt_id, delay_type_max[gt_id] + 1, delay_types_p, delay_types_id, - delay_types_groups, delay_max, delay_np_pmf, - delay_np_pmf_groups, delay_params_sample, delay_params_groups, - delay_dist, 1, 1, 0 - ); + vector[delay_type_max[delay_id_generation_time] + 1] sampled_gt_rev_pmf = + get_delay_rev_pmf( + delay_id_generation_time, + delay_type_max[delay_id_generation_time] + 1, delay_types_p, + delay_types_id, delay_types_groups, delay_max, delay_np_pmf, + delay_np_pmf_groups, delay_params_sample, delay_params_groups, + delay_dist, 1, 1, 0 + ); gt_rev_pmf_for_growth = sampled_gt_rev_pmf; // calculate Rt using infections and generation time gen_R = calculate_Rt( diff --git a/inst/stan/estimate_secondary.stan b/inst/stan/estimate_secondary.stan index 3e260e96b..83c11a126 100644 --- a/inst/stan/estimate_secondary.stan +++ b/inst/stan/estimate_secondary.stan @@ -56,14 +56,15 @@ transformed parameters { scaled = primary; } - if (delay_id) { - vector[delay_type_max[delay_id] + 1] delay_rev_pmf = get_delay_rev_pmf( - delay_id, delay_type_max[delay_id] + 1, delay_types_p, delay_types_id, - delay_types_groups, delay_max, delay_np_pmf, - delay_np_pmf_groups, delay_params, delay_params_groups, delay_dist, - 0, 1, 0 - ); - convolved = convolved + convolve_to_report(scaled, delay_rev_pmf, 0); + if (delay_id_reporting) { + vector[delay_type_max[delay_id_reporting] + 1] reporting_rev_pmf = + get_delay_rev_pmf( + delay_id_reporting, delay_type_max[delay_id_reporting] + 1, + delay_types_p, delay_types_id, delay_types_groups, delay_max, + delay_np_pmf, delay_np_pmf_groups, delay_params, delay_params_groups, + delay_dist, 0, 1, 0 + ); + convolved = convolved + convolve_to_report(scaled, reporting_rev_pmf, 0); } else { convolved = convolved + scaled; } @@ -80,13 +81,14 @@ transformed parameters { } // truncate near time cases to observed reports - if (trunc_id) { - vector[delay_type_max[trunc_id]] trunc_rev_cmf = get_delay_rev_pmf( - trunc_id, delay_type_max[trunc_id] + 1, delay_types_p, 
delay_types_id, - delay_types_groups, delay_max, delay_np_pmf, - delay_np_pmf_groups, delay_params, delay_params_groups, delay_dist, - 0, 1, 1 - ); + if (delay_id_truncation) { + vector[delay_type_max[delay_id_truncation]] trunc_rev_cmf = + get_delay_rev_pmf( + delay_id_truncation, delay_type_max[delay_id_truncation] + 1, + delay_types_p, delay_types_id, delay_types_groups, delay_max, + delay_np_pmf, delay_np_pmf_groups, delay_params, delay_params_groups, + delay_dist, 0, 1, 1 + ); secondary = truncate_obs(secondary, trunc_rev_cmf, 0); } diff --git a/inst/stan/estimate_truncation.stan b/inst/stan/estimate_truncation.stan index c62365582..75ba24f09 100644 --- a/inst/stan/estimate_truncation.stan +++ b/inst/stan/estimate_truncation.stan @@ -14,7 +14,7 @@ data { } transformed data{ - int trunc_id = 1; + int delay_id_truncation = 1; array[obs_sets] int end_t; array[obs_sets] int start_t; @@ -26,7 +26,7 @@ transformed data{ for (i in 1:obs_sets) { end_t[i] = t - obs_dist[i]; - start_t[i] = max(1, end_t[i] - delay_type_max[trunc_id]); + start_t[i] = max(1, end_t[i] - delay_type_max[delay_id_truncation]); } } @@ -38,15 +38,15 @@ parameters { transformed parameters{ real phi = 1 / sqrt(dispersion); - matrix[delay_type_max[trunc_id] + 1, obs_sets - 1] trunc_obs = rep_matrix( - 0, delay_type_max[trunc_id] + 1, obs_sets - 1 - ); - vector[delay_type_max[trunc_id] + 1] trunc_rev_cmf = get_delay_rev_pmf( - trunc_id, delay_type_max[trunc_id] + 1, delay_types_p, delay_types_id, - delay_types_groups, delay_max, delay_np_pmf, - delay_np_pmf_groups, delay_params, delay_params_groups, delay_dist, - 0, 1, 1 - ); + matrix[delay_type_max[delay_id_truncation] + 1, obs_sets - 1] trunc_obs = + rep_matrix(0, delay_type_max[delay_id_truncation] + 1, obs_sets - 1); + vector[delay_type_max[delay_id_truncation] + 1] trunc_rev_cmf = + get_delay_rev_pmf( + delay_id_truncation, delay_type_max[delay_id_truncation] + 1, + delay_types_p, delay_types_id, delay_types_groups, delay_max, + delay_np_pmf, 
delay_np_pmf_groups, delay_params, delay_params_groups, + delay_dist, 0, 1, 1 + ); { vector[t] last_obs; // reconstruct latest data without truncation @@ -79,11 +79,10 @@ model { } generated quantities { - matrix[delay_type_max[trunc_id] + 1, obs_sets] recon_obs = rep_matrix( - 0, delay_type_max[trunc_id] + 1, obs_sets - ); - matrix[delay_type_max[trunc_id] + 1, obs_sets - 1] gen_obs; - // reconstruct all truncated datasets using posterior of the truncation distribution + matrix[delay_type_max[delay_id_truncation] + 1, obs_sets] recon_obs = + rep_matrix(0, delay_type_max[delay_id_truncation] + 1, obs_sets); + matrix[delay_type_max[delay_id_truncation] + 1, obs_sets - 1] gen_obs; + // reconstruct all truncated datasets using posterior of truncation dist for (i in 1:obs_sets) { recon_obs[1:(end_t[i] - start_t[i] + 1), i] = truncate_obs( to_vector(obs[start_t[i]:end_t[i], i]), trunc_rev_cmf, 1 @@ -91,7 +90,7 @@ generated quantities { } // generate observations for comparing for (i in 1:(obs_sets - 1)) { - for (j in 1:(delay_type_max[trunc_id] + 1)) { + for (j in 1:(delay_type_max[delay_id_truncation] + 1)) { if (trunc_obs[j, i] == 0) { gen_obs[j, i] = 0; } else { diff --git a/inst/stan/simulate_infections.stan b/inst/stan/simulate_infections.stan index d4c7531ef..a715a3522 100644 --- a/inst/stan/simulate_infections.stan +++ b/inst/stan/simulate_infections.stan @@ -59,29 +59,33 @@ generated quantities { for (i in 1:n) { // generate infections from Rt trace - vector[delay_type_max[gt_id] + 1] gt_rev_pmf; + vector[delay_type_max[delay_id_generation_time] + 1] gt_rev_pmf; gt_rev_pmf = get_delay_rev_pmf( - gt_id, delay_type_max[gt_id] + 1, delay_types_p, delay_types_id, - delay_types_groups, delay_max, delay_np_pmf, - delay_np_pmf_groups, delay_params[i], delay_params_groups, delay_dist, - 1, 1, 0 + delay_id_generation_time, delay_type_max[delay_id_generation_time] + 1, + delay_types_p, delay_types_id, delay_types_groups, delay_max, + delay_np_pmf, delay_np_pmf_groups, 
delay_params[i], + delay_params_groups, delay_dist, 1, 1, 0 ); infections[i] = to_row_vector(generate_infections( to_vector(R[i]), seeding_time, gt_rev_pmf, initial_infections[i], - pop[i], use_pop, pop_floor, future_time, obs_scale, frac_obs[i], initial_as_scale + pop[i], use_pop, pop_floor, future_time, obs_scale, frac_obs[i], + initial_as_scale )); - if (delay_id) { - vector[delay_type_max[delay_id] + 1] delay_rev_pmf = get_delay_rev_pmf( - delay_id, delay_type_max[delay_id] + 1, delay_types_p, delay_types_id, - delay_types_groups, delay_max, delay_np_pmf, - delay_np_pmf_groups, delay_params[i], delay_params_groups, delay_dist, - 0, 1, 0 - ); + if (delay_id_reporting) { + vector[delay_type_max[delay_id_reporting] + 1] reporting_rev_pmf = + get_delay_rev_pmf( + delay_id_reporting, delay_type_max[delay_id_reporting] + 1, + delay_types_p, delay_types_id, delay_types_groups, delay_max, + delay_np_pmf, delay_np_pmf_groups, delay_params[i], + delay_params_groups, delay_dist, 0, 1, 0 + ); // convolve from latent infections to mean of observations - reports[i] = to_row_vector(convolve_to_report( - to_vector(infections[i]), delay_rev_pmf, seeding_time) + reports[i] = to_row_vector( + convolve_to_report( + to_vector(infections[i]), reporting_rev_pmf, seeding_time + ) ); } else { reports[i] = to_row_vector( @@ -96,20 +100,23 @@ generated quantities { to_vector(day_of_week_simplex[i]))); } // truncate near time cases to observed reports - if (trunc_id) { - vector[delay_type_max[trunc_id] + 1] trunc_rev_cmf = get_delay_rev_pmf( - trunc_id, delay_type_max[trunc_id] + 1, delay_types_p, delay_types_id, - delay_types_groups, delay_max, delay_np_pmf, - delay_np_pmf_groups, delay_params[i], delay_params_groups, delay_dist, - 0, 1, 1 - ); + if (delay_id_truncation) { + vector[delay_type_max[delay_id_truncation] + 1] trunc_rev_cmf = + get_delay_rev_pmf( + delay_id_truncation, delay_type_max[delay_id_truncation] + 1, + delay_types_p, delay_types_id, delay_types_groups, delay_max, + 
delay_np_pmf, delay_np_pmf_groups, delay_params[i], + delay_params_groups, delay_dist, 0, 1, 1 + ); reports[i] = to_row_vector(truncate_obs( to_vector(reports[i]), trunc_rev_cmf, 0) ); } // scale observations if (obs_scale) { - reports[i] = to_row_vector(scale_obs(to_vector(reports[i]), frac_obs[i])); + reports[i] = to_row_vector( + scale_obs(to_vector(reports[i]), frac_obs[i]) + ); } // simulate reported cases imputed_reports[i] = report_rng( diff --git a/inst/stan/simulate_secondary.stan b/inst/stan/simulate_secondary.stan index b5169e0ed..0b2f8949b 100644 --- a/inst/stan/simulate_secondary.stan +++ b/inst/stan/simulate_secondary.stan @@ -53,14 +53,15 @@ generated quantities { scaled = to_vector(primary[i]); } - if (delay_id) { - vector[delay_type_max[delay_id] + 1] delay_rev_pmf = get_delay_rev_pmf( - delay_id, delay_type_max[delay_id] + 1, delay_types_p, delay_types_id, - delay_types_groups, delay_max, delay_np_pmf, - delay_np_pmf_groups, delay_params[i], delay_params_groups, delay_dist, - 0, 1, 0 - ); - convolved = convolved + convolve_to_report(scaled, delay_rev_pmf, 0); + if (delay_id_reporting) { + vector[delay_type_max[delay_id_reporting] + 1] reporting_rev_pmf = + get_delay_rev_pmf( + delay_id_reporting, delay_type_max[delay_id_reporting] + 1, + delay_types_p, delay_types_id, delay_types_groups, delay_max, + delay_np_pmf, delay_np_pmf_groups, delay_params[i], + delay_params_groups, delay_dist, 0, 1, 0 + ); + convolved = convolved + convolve_to_report(scaled, reporting_rev_pmf, 0); } else { convolved = convolved + scaled; } @@ -77,13 +78,14 @@ generated quantities { } // truncate near time cases to observed reports - if (trunc_id) { - vector[delay_type_max[trunc_id] + 1] trunc_rev_cmf = get_delay_rev_pmf( - trunc_id, delay_type_max[trunc_id] + 1, delay_types_p, delay_types_id, - delay_types_groups, delay_max, delay_np_pmf, - delay_np_pmf_groups, delay_params[i], delay_params_groups, delay_dist, - 0, 1, 1 - ); + if (delay_id_truncation) { + 
vector[delay_type_max[delay_id_truncation] + 1] trunc_rev_cmf = + get_delay_rev_pmf( + delay_id_truncation, delay_type_max[delay_id_truncation] + 1, + delay_types_p, delay_types_id, delay_types_groups, delay_max, + delay_np_pmf, delay_np_pmf_groups, delay_params[i], + delay_params_groups, delay_dist, 0, 1, 1 + ); secondary = truncate_obs( secondary, trunc_rev_cmf, 0 ); diff --git a/man/cash-.estimate_secondary.Rd b/man/cash-.estimate_secondary.Rd new file mode 100644 index 000000000..235bb0c64 --- /dev/null +++ b/man/cash-.estimate_secondary.Rd @@ -0,0 +1,28 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/estimate_secondary.R +\name{$.estimate_secondary} +\alias{$.estimate_secondary} +\title{Extract elements from estimate_secondary objects with deprecated warnings} +\usage{ +\method{$}{estimate_secondary}(x, name) +} +\arguments{ +\item{x}{An \code{estimate_secondary} object} + +\item{name}{The name of the element to extract} +} +\value{ +The requested element with a deprecation warning +} +\description{ +\ifelse{html}{\href{https://lifecycle.r-lib.org/articles/stages.html#deprecated}{\figure{lifecycle-deprecated.svg}{options: alt='[Deprecated]'}}}{\strong{[Deprecated]}} +Provides backward compatibility for the old return structure. The previous +structure with \code{predictions}, \code{posterior}, and \code{data} +elements is deprecated. 
Use the accessor methods instead: +\itemize{ +\item \code{predictions} - use \code{get_predictions(object)} +\item \code{posterior} - use \code{get_samples(object)} +\item \code{data} - use \code{object$observations} +} +} +\keyword{internal} diff --git a/man/estimate_infections.Rd b/man/estimate_infections.Rd index c16820ab7..c15390315 100644 --- a/man/estimate_infections.Rd +++ b/man/estimate_infections.Rd @@ -106,8 +106,9 @@ horizon} \value{ An \verb{} object which is a list of outputs including: the stan object (\code{fit}), arguments used to fit the model -(\code{args}), and the observed data (\code{observations}). The estimates included in -the fit can be accessed using the \code{summary()} function. +(\code{args}), and the observed data (\code{observations}). Use \code{summary()} to access +estimates, \code{get_samples()} to extract posterior samples, and +\code{get_predictions()} to access predicted reported cases. } \description{ \ifelse{html}{\href{https://lifecycle.r-lib.org/articles/stages.html#maturing}{\figure{lifecycle-maturing.svg}{options: alt='[Maturing]'}}}{\strong{[Maturing]}} diff --git a/man/estimate_secondary.Rd b/man/estimate_secondary.Rd index d003b4d04..4788d2f27 100644 --- a/man/estimate_secondary.Rd +++ b/man/estimate_secondary.Rd @@ -91,11 +91,11 @@ number of cases based on the 7-day average. If the average is above this threshold then the zero is replaced using \code{fill}.} } \value{ -A list containing: \code{predictions} (a \verb{} ordered by date -with the primary, and secondary observations, and a summary of the model -estimated secondary observations), \code{posterior} which contains a summary of -the entire model posterior, \code{data} (a list of data used to fit the -model), and \code{fit} (the \code{stanfit} object). +An \verb{} object which is a list of outputs +including: the stan object (\code{fit}), arguments used to fit the model +(\code{args}), and the observed data (\code{observations}). 
Use \code{summary()} to access +estimates, \code{get_samples()} to extract posterior samples, and +\code{get_predictions()} to access predictions. } \description{ \ifelse{html}{\href{https://lifecycle.r-lib.org/articles/stages.html#stable}{\figure{lifecycle-stable.svg}{options: alt='[Stable]'}}}{\strong{[Stable]}} diff --git a/man/extract_delays.Rd b/man/extract_delays.Rd new file mode 100644 index 000000000..7256b217c --- /dev/null +++ b/man/extract_delays.Rd @@ -0,0 +1,22 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/extract.R +\name{extract_delays} +\alias{extract_delays} +\title{Extract samples from all delay parameters} +\usage{ +extract_delays(samples) +} +\arguments{ +\item{samples}{Extracted stan model (using \code{\link[rstan:stanfit-method-extract]{rstan::extract()}})} +} +\value{ +A \verb{<data.table>} with columns: parameter, sample, value, or NULL if +delay parameters don't exist in the samples +} +\description{ +Extracts samples from all delay parameters using the delay ID lookup system. +Similar to extract_parameters(), this extracts all delay distribution +parameters and uses the delay_id_* variables (e.g., delay_id_reporting, +delay_id_truncation) to assign meaningful names. 
+} +\keyword{internal} diff --git a/man/extract_latent_state.Rd b/man/extract_latent_state.Rd index 89aaff260..eb38b5b52 100644 --- a/man/extract_latent_state.Rd +++ b/man/extract_latent_state.Rd @@ -2,7 +2,7 @@ % Please edit documentation in R/extract.R \name{extract_latent_state} \alias{extract_latent_state} -\title{Extract Samples for a Latent State from a Stan model} +\title{Extract samples for a latent state from a Stan model} \usage{ extract_latent_state(param, samples, dates) } diff --git a/man/extract_parameter.Rd b/man/extract_parameter.Rd deleted file mode 100644 index 4c9b00a52..000000000 --- a/man/extract_parameter.Rd +++ /dev/null @@ -1,21 +0,0 @@ -% Generated by roxygen2: do not edit by hand -% Please edit documentation in R/extract.R -\name{extract_parameter} -\alias{extract_parameter} -\title{Extract Samples from a Parameter with a Single Dimension} -\usage{ -extract_parameter(param, samples) -} -\arguments{ -\item{param}{Character string indicating the parameter to extract} - -\item{samples}{Extracted stan model (using \code{\link[rstan:stanfit-method-extract]{rstan::extract()}})} -} -\value{ -A \verb{} containing the parameter name, sample id and sample -value, or NULL if the parameter doesn't exist in the samples -} -\description{ -Extract Samples from a Parameter with a Single Dimension -} -\keyword{internal} diff --git a/man/extract_parameter_samples.Rd b/man/extract_parameter_samples.Rd index cea50c37f..8b35932d9 100644 --- a/man/extract_parameter_samples.Rd +++ b/man/extract_parameter_samples.Rd @@ -2,7 +2,7 @@ % Please edit documentation in R/extract.R \name{extract_parameter_samples} \alias{extract_parameter_samples} -\title{Extract Parameter Samples from a Stan Model} +\title{Extract parameter samples from a Stan model} \usage{ extract_parameter_samples( stan_fit, diff --git a/man/extract_parameters.Rd b/man/extract_parameters.Rd new file mode 100644 index 000000000..b87c02b0c --- /dev/null +++ b/man/extract_parameters.Rd @@ -0,0 +1,19 @@ 
+% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/extract.R +\name{extract_parameters} +\alias{extract_parameters} +\title{Extract samples from all parameters} +\usage{ +extract_parameters(samples) +} +\arguments{ +\item{samples}{Extracted stan model (using \code{\link[rstan:stanfit-method-extract]{rstan::extract()}})} +} +\value{ +A \verb{} containing the parameter name, sample id and sample +value, or NULL if parameters don't exist in the samples +} +\description{ +Extract samples from all parameters +} +\keyword{internal} diff --git a/man/extract_stan_param.Rd b/man/extract_stan_param.Rd index 9cf9c7b9c..e4f132b9c 100644 --- a/man/extract_stan_param.Rd +++ b/man/extract_stan_param.Rd @@ -2,7 +2,7 @@ % Please edit documentation in R/extract.R \name{extract_stan_param} \alias{extract_stan_param} -\title{Extract a Parameter Summary from a Stan Object} +\title{Extract a parameter summary from a Stan object} \usage{ extract_stan_param( fit, diff --git a/man/get_predictions.Rd b/man/get_predictions.Rd new file mode 100644 index 000000000..4643bf8cb --- /dev/null +++ b/man/get_predictions.Rd @@ -0,0 +1,45 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/get.R +\name{get_predictions} +\alias{get_predictions} +\alias{get_predictions.estimate_infections} +\alias{get_predictions.estimate_secondary} +\alias{get_predictions.forecast_infections} +\alias{get_predictions.forecast_secondary} +\title{Get predictions from a fitted model} +\usage{ +get_predictions(object, ...) + +\method{get_predictions}{estimate_infections}(object, CrIs = c(0.2, 0.5, 0.9), ...) + +\method{get_predictions}{estimate_secondary}(object, CrIs = c(0.2, 0.5, 0.9), ...) + +\method{get_predictions}{forecast_infections}(object, ...) + +\method{get_predictions}{forecast_secondary}(object, ...) 
+} +\arguments{ +\item{object}{A fitted model object (e.g., from \code{estimate_infections()} or +\code{estimate_secondary()})} + +\item{...}{Additional arguments (currently unused)} + +\item{CrIs}{Numeric vector of credible intervals to return. Defaults to +c(0.2, 0.5, 0.9).} +} +\value{ +A \code{data.table} with columns including date, observations, and summary +statistics (mean, sd, credible intervals) for the model predictions. +} +\description{ +\ifelse{html}{\href{https://lifecycle.r-lib.org/articles/stages.html#stable}{\figure{lifecycle-stable.svg}{options: alt='[Stable]'}}}{\strong{[Stable]}} +Extracts predictions from a fitted model, combining observations with model +estimates. For \code{estimate_infections()} returns predicted reported cases, for +\code{estimate_secondary()} returns predicted secondary observations. +} +\examples{ +\dontrun{ +# After fitting a model +predictions <- get_predictions(fit) +} +} diff --git a/man/get_samples.Rd b/man/get_samples.Rd index d4ed3b890..e4c9d7736 100644 --- a/man/get_samples.Rd +++ b/man/get_samples.Rd @@ -4,6 +4,8 @@ \alias{get_samples} \alias{get_samples.estimate_infections} \alias{get_samples.forecast_infections} +\alias{get_samples.estimate_secondary} +\alias{get_samples.forecast_secondary} \title{Get posterior samples from a fitted model} \usage{ get_samples(object, ...) @@ -11,6 +13,10 @@ get_samples(object, ...) \method{get_samples}{estimate_infections}(object, ...) \method{get_samples}{forecast_infections}(object, ...) + +\method{get_samples}{estimate_secondary}(object, ...) + +\method{get_samples}{forecast_secondary}(object, ...) 
} \arguments{ \item{object}{A fitted model object (e.g., from \code{estimate_infections()})} diff --git a/man/plot.forecast_secondary.Rd b/man/plot.forecast_secondary.Rd new file mode 100644 index 000000000..7fdec4e94 --- /dev/null +++ b/man/plot.forecast_secondary.Rd @@ -0,0 +1,28 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/estimate_secondary.R +\name{plot.forecast_secondary} +\alias{plot.forecast_secondary} +\title{Plot method for forecast_secondary objects} +\usage{ +\method{plot}{forecast_secondary}(x, primary = FALSE, from = NULL, to = NULL, new_obs = NULL, ...) +} +\arguments{ +\item{x}{A list of output as produced by \code{estimate_secondary}} + +\item{primary}{Logical, defaults to \code{FALSE}. Should \code{primary} reports also +be plot?} + +\item{from}{Date object indicating when to plot from.} + +\item{to}{Date object indicating when to plot up to.} + +\item{new_obs}{A \verb{} containing the columns \code{date} and \code{secondary} +which replace the secondary observations stored in the \code{estimate_secondary} +output.} + +\item{...}{Pass additional arguments to plot function. Not currently in use.} +} +\description{ +\ifelse{html}{\href{https://lifecycle.r-lib.org/articles/stages.html#stable}{\figure{lifecycle-stable.svg}{options: alt='[Stable]'}}}{\strong{[Stable]}} +Plot method for forecast secondary observations. +} diff --git a/man/summary.estimate_secondary.Rd b/man/summary.estimate_secondary.Rd new file mode 100644 index 000000000..0ea9db8a6 --- /dev/null +++ b/man/summary.estimate_secondary.Rd @@ -0,0 +1,40 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/summarise.R +\name{summary.estimate_secondary} +\alias{summary.estimate_secondary} +\title{Summarise results from estimate_secondary} +\usage{ +\method{summary}{estimate_secondary}( + object, + type = c("compact", "parameters"), + params = NULL, + CrIs = c(0.2, 0.5, 0.9), + ... 
+) +} +\arguments{ +\item{object}{A fitted model object from \code{estimate_secondary()}} + +\item{type}{Character string indicating the type of summary to return. +Options are "compact" (default, key parameters only) or "parameters" +(all parameters or filtered set).} + +\item{params}{Character vector of parameter names to include. Only used +when \code{type = "parameters"}. If NULL (default), returns all parameters.} + +\item{CrIs}{Numeric vector of credible intervals to return. Defaults to +c(0.2, 0.5, 0.9).} + +\item{...}{Additional arguments (currently unused)} +} +\value{ +A \verb{<data.table>} with summary statistics (mean, sd, median, +credible intervals) for model parameters. When \code{type = "compact"}, +returns only key parameters (delay distribution parameters and scaling +factors). When \code{type = "parameters"}, returns all or filtered parameters. +} +\description{ +\ifelse{html}{\href{https://lifecycle.r-lib.org/articles/stages.html#stable}{\figure{lifecycle-stable.svg}{options: alt='[Stable]'}}}{\strong{[Stable]}} +Returns a summary of the fitted secondary model including posterior +parameter estimates with credible intervals. 
+} diff --git a/tests/testthat/test-checks.R b/tests/testthat/test-checks.R index 15c897d1d..414650198 100644 --- a/tests/testthat/test-checks.R +++ b/tests/testthat/test-checks.R @@ -207,7 +207,7 @@ test_that("check_truncation_length warns when truncation PMF is longer than time # Test with long truncation PMF (length 15) and time_points = 10 stan_args <- list( data = list( - trunc_id = 1, + delay_id_truncation = 1, delay_n_np = 1, delay_types_groups = array(c(1, 2)), # Truncation maps to position 1 delay_types_p = array(c(0)), # Truncation is nonparametric @@ -228,7 +228,7 @@ test_that("check_truncation_length works with truncation from create_stan_delays short_trunc <- trunc_opts(dist = LogNormal(mean = 1, sd = 0.5, max = 5)) stan_args_short <- list( data = create_stan_delays( - trunc = short_trunc, + truncation = short_trunc, time_points = 10 ) ) @@ -240,7 +240,7 @@ test_that("check_truncation_length works with truncation from create_stan_delays long_trunc <- trunc_opts(dist = LogNormal(mean = 2, sd = 0.5, max = 20)) stan_args_long <- list( data = create_stan_delays( - trunc = long_trunc, + truncation = long_trunc, time_points = 10 ) ) @@ -261,9 +261,9 @@ test_that("check_truncation_length works when truncation is combined with other stan_args <- list( data = create_stan_delays( - gt = gt, - delay = delays, - trunc = long_trunc, + generation_time = gt, + reporting = delays, + truncation = long_trunc, time_points = 10 ) ) @@ -283,7 +283,7 @@ test_that("check_truncation_length correctly indexes when parametric delays prec # to index into np_pmf_lengths, which only contains nonparametric delays stan_args <- list( data = list( - trunc_id = 3, + delay_id_truncation = 3, delay_n_np = 1, delay_types_groups = array(c(1, 2, 3, 4)), # Three delays total delay_types_p = array(c(1, 1, 0)), # First two parametric, third nonparametric diff --git a/tests/testthat/test-delays.R b/tests/testthat/test-delays.R index 68327aca0..6be90663f 100644 --- a/tests/testthat/test-delays.R 
+++ b/tests/testthat/test-delays.R @@ -93,3 +93,87 @@ test_that("distributions incompatible with stan models are caught", { Normal(2, 2, max = 10) ), "lognormal") }) + +test_that("create_stan_delays creates delay_id_* variables with correct names", { + # Test with all delay types (infection context) + data <- EpiNow2:::create_stan_delays( + generation_time = gt_opts(Fixed(1)), + reporting = delay_opts(Fixed(2)), + truncation = trunc_opts(Fixed(1)) + ) + + expect_true("delay_id_generation_time" %in% names(data)) + expect_true("delay_id_reporting" %in% names(data)) + expect_true("delay_id_truncation" %in% names(data)) + + # IDs should be sequential for non-empty delays + expect_equal(data$delay_id_generation_time, 1) + expect_equal(data$delay_id_reporting, 2) + expect_equal(data$delay_id_truncation, 3) +}) + +test_that("create_stan_delays creates delay_id_* for secondary models", { + # Test with reporting delay for secondary models + data <- EpiNow2:::create_stan_delays( + reporting = delay_opts(Fixed(2)), + truncation = trunc_opts(Fixed(1)) + ) + + expect_true("delay_id_reporting" %in% names(data)) + expect_true("delay_id_truncation" %in% names(data)) + + expect_equal(data$delay_id_reporting, 1) + expect_equal(data$delay_id_truncation, 2) +}) + +test_that("create_stan_delays sets ID to 0 for missing delays", { + # Test with only one delay type + data <- EpiNow2:::create_stan_delays( + generation_time = gt_opts(Fixed(1)) + ) + + expect_equal(data$delay_id_generation_time, 1) + # No reporting or truncation delays provided + expect_false("delay_id_reporting" %in% names(data)) + expect_false("delay_id_truncation" %in% names(data)) +}) + +test_that("extract_delays works with delay_id_* naming", { + # Create mock samples with delay_id_* variables + samples <- list( + delay_params = matrix(c(1.5, 2.0, 1.8, 2.2), nrow = 2, ncol = 2), + delay_id_generation_time = c(1, 1), # ID = 1 + delay_id_reporting = c(0, 0), # ID = 0 (not used) + delay_types_groups = c(1, 3) # Group 1: 
cols 1-2 + ) + + result <- EpiNow2:::extract_delays(samples) + + expect_true(!is.null(result)) + expect_true("parameter" %in% names(result)) + expect_true("sample" %in% names(result)) + expect_true("value" %in% names(result)) + + # Check that generation_time parameters are named correctly + expect_true(any(grepl("generation_time\\[1\\]", result$parameter))) + expect_true(any(grepl("generation_time\\[2\\]", result$parameter))) +}) + +test_that("extract_delays returns NULL when delay_params don't exist", { + samples <- list(some_other_param = 1:10) + result <- EpiNow2:::extract_delays(samples) + expect_null(result) +}) + +test_that("extract_delays handles delays with no ID lookup gracefully", { + # Samples without delay_id_* variables + samples <- list( + delay_params = matrix(c(1.5, 2.0), nrow = 2, ncol = 1) + ) + + result <- EpiNow2:::extract_delays(samples) + + expect_true(!is.null(result)) + # Should fall back to indexed naming + expect_true(any(grepl("delay_params\\[", result$parameter))) +}) diff --git a/tests/testthat/test-estimate_infections.R b/tests/testthat/test-estimate_infections.R index 76d8b0571..204f4a10c 100644 --- a/tests/testthat/test-estimate_infections.R +++ b/tests/testthat/test-estimate_infections.R @@ -29,6 +29,14 @@ test_estimate_infections <- function(...) 
{ expect_true(nrow(get_samples(out)) > 0) expect_true(nrow(summary(out, type = "parameters")) > 0) expect_true(nrow(out$observations) > 0) + + # Test get_predictions accessor + predictions <- get_predictions(out) + expect_true(nrow(predictions) > 0) + expect_true("date" %in% names(predictions)) + expect_true("confirm" %in% names(predictions)) + expect_true("mean" %in% names(predictions)) + invisible(out) } diff --git a/tests/testthat/test-estimate_secondary.R b/tests/testthat/test-estimate_secondary.R index e921dfdf7..9d94fe924 100644 --- a/tests/testthat/test-estimate_secondary.R +++ b/tests/testthat/test-estimate_secondary.R @@ -56,20 +56,30 @@ default_inc <- estimate_secondary(inc_cases[1:60], verbose = FALSE ) +# Test output test_that("estimate_secondary can return values from simulated data and plot them", { # Reuse pre-computed fit inc <- default_inc - expect_equal(names(inc), c("predictions", "posterior", "data", "fit")) + expect_equal(names(inc), c("fit", "args", "observations")) + expect_s3_class(inc, "estimate_secondary") + expect_s3_class(inc, "epinowfit") + + # Test accessor methods + predictions <- get_predictions(inc) expect_equal( - names(inc$predictions), + names(predictions), c( - "date", "primary", "secondary", "accumulate", "mean", "se_mean", "sd", - "lower_90", "lower_50", "lower_20", "median", "upper_20", "upper_50", "upper_90" + "date", "primary", "secondary", "accumulate", "median", "mean", "sd", + "lower_90", "lower_50", "lower_20", "upper_20", "upper_50", "upper_90" ) ) - expect_true(is.list(inc$data)) + + posterior <- get_samples(inc) + expect_true(is.data.frame(posterior)) + + expect_true(is.list(inc$args)) # validation plot of observations vs estimates expect_error(plot(inc, primary = TRUE), NA) }) @@ -88,6 +98,27 @@ test_that("forecast_secondary can return values from simulated data and plot expect_error(plot(inc_preds, new_obs = inc_cases, from = "2020-05-01"), NA) }) +test_that("estimate_secondary recovers scaling parameter from 
incidence data", { + # Basic parameter recovery check using pre-computed fit + # inc_cases was set up with scaling = 0.4, meanlog = 1.8, sdlog = 0.5 + params <- c( + "meanlog" = "delay_params[1]", "sdlog" = "delay_params[2]", + "scaling" = "params[1]" + ) + + inc_posterior <- get_samples(default_inc)[variable %in% params] + inc_summary <- inc_posterior[, .( + mean = mean(value), + median = stats::median(value) + ), by = variable] + + # Check scaling parameter is reasonably recovered (0.4 true value) + expect_equal( + inc_summary$mean, c(1.8, 0.5, 0.4), + tolerance = 0.15 + ) +}) + # Variant tests: Only run in full test mode (EPINOW2_SKIP_INTEGRATION=false) - test_that("estimate_secondary successfully returns estimates when passed NA values", { @@ -114,8 +145,8 @@ test_that("estimate_secondary successfully returns estimates when passed NA valu obs = obs_opts(scale = Normal(mean = 0.2, sd = 0.2), week_effect = FALSE), verbose = FALSE ) - expect_true(is.list(inc_na$data)) - expect_true(is.list(prev_na$data)) + expect_true(is.list(inc_na$args)) + expect_true(is.list(prev_na$args)) }) test_that("estimate_secondary successfully returns estimates when accumulating to weekly", { @@ -146,7 +177,7 @@ test_that("estimate_secondary successfully returns estimates when accumulating t scale = Normal(mean = 0.4, sd = 0.05), week_effect = FALSE ), verbose = FALSE ) - expect_true(is.list(inc_weekly$data)) + expect_true(is.list(inc_weekly$args)) }) test_that("estimate_secondary works when only estimating scaling", { @@ -158,7 +189,7 @@ test_that("estimate_secondary works when only estimating scaling", { delay = delay_opts(), verbose = FALSE ) - expect_equal(names(inc), c("predictions", "posterior", "data", "fit")) + expect_equal(names(inc), c("fit", "args", "observations")) }) test_that("estimate_secondary can recover simulated parameters", { @@ -190,23 +221,33 @@ test_that("estimate_secondary can recover simulated parameters", { "scaling" = "params[1]" ) - inc_posterior <- 
inc$posterior[variable %in% params] - prev_posterior <- prev$posterior[variable %in% params] + inc_posterior <- get_samples(inc)[variable %in% params] + prev_posterior <- get_samples(prev)[variable %in% params] + + # Calculate summary statistics from raw samples + inc_summary <- inc_posterior[, .( + mean = mean(value), + median = stats::median(value) + ), by = variable] + prev_summary <- prev_posterior[, .( + mean = mean(value), + median = stats::median(value) + ), by = variable] expect_equal( - inc_posterior[, mean], c(1.8, 0.5, 0.4), + inc_summary$mean, c(1.8, 0.5, 0.4), tolerance = 0.1 ) expect_equal( - inc_posterior[, median], c(1.8, 0.5, 0.4), + inc_summary$median, c(1.8, 0.5, 0.4), tolerance = 0.1 ) expect_equal( - prev_posterior[, mean], c(1.6, 0.8, 0.3), + prev_summary$mean, c(1.6, 0.8, 0.3), tolerance = 0.2 ) expect_equal( - prev_posterior[, median], c(1.6, 0.8, 0.3), + prev_summary$median, c(1.6, 0.8, 0.3), tolerance = 0.2 ) }) @@ -229,13 +270,20 @@ test_that("estimate_secondary can recover simulated parameters with the verbose = FALSE, stan = stan_opts(backend = "cmdstanr") ) ))) - inc_posterior_cmdstanr <- inc_cmdstanr$posterior[variable %in% params] + inc_posterior_cmdstanr <- get_samples(inc_cmdstanr)[variable %in% params] + + # Calculate summary statistics from raw samples + inc_summary_cmdstanr <- inc_posterior_cmdstanr[, .( + mean = mean(value), + median = stats::median(value) + ), by = variable] + expect_equal( - inc_posterior_cmdstanr[, mean], c(1.8, 0.5, 0.4), + inc_summary_cmdstanr$mean, c(1.8, 0.5, 0.4), tolerance = 0.1 ) expect_equal( - inc_posterior_cmdstanr[, median], c(1.8, 0.5, 0.4), + inc_summary_cmdstanr$median, c(1.8, 0.5, 0.4), tolerance = 0.1 ) }) @@ -316,8 +364,8 @@ test_that("estimate_secondary works with filter_leading_zeros set", { verbose = FALSE )) expect_s3_class(out, "estimate_secondary") - expect_named(out, c("predictions", "posterior", "data", "fit")) - expect_equal(out$predictions$primary, modified_data$primary) + 
expect_named(out, c("fit", "args", "observations")) + expect_equal(get_predictions(out)$primary, modified_data$primary) }) test_that("estimate_secondary works with zero_threshold set", { @@ -340,5 +388,5 @@ test_that("estimate_secondary works with zero_threshold set", { verbose = FALSE ) expect_s3_class(out, "estimate_secondary") - expect_named(out, c("predictions", "posterior", "data", "fit")) + expect_named(out, c("fit", "args", "observations")) })