Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
51 commits
Select commit Hold shift + click to select a range
141dcd1
Standardise S3 interface for estimate_secondary (#1142)
claude Nov 10, 2025
b7761f6
Add get_predictions() for estimate_infections to complete S3 interfac…
claude Nov 10, 2025
4eef357
Fix summary.estimate_secondary() to return summary statistics instead…
claude Nov 10, 2025
9c4a6d3
Add type and params arguments to summary.estimate_secondary() for con…
claude Nov 10, 2025
b6ae503
Rename summary type from 'snapshot' to 'estimates' for estimate_secon…
claude Nov 10, 2025
cee9290
Rename summary type from 'estimates' to 'compact' for estimate_secondary
claude Nov 11, 2025
4e8fe0f
Export S3 methods for estimate_secondary accessors and summary
claude Nov 11, 2025
c26cc1d
Add @method tag for summary.estimate_secondary
claude Nov 11, 2025
3433201
Fix linting issues: line length and unnecessary nesting
claude Nov 13, 2025
7d59d4f
Fix indentation linting issues in get_predictions functions
claude Nov 13, 2025
85bec8b
Update documentation
epiforecasts-workflows[bot] Nov 13, 2025
e611234
Add backward compatibility for forecast_secondary in get_predictions
claude Nov 13, 2025
0b92ea8
Give forecast_secondary its own S3 class and methods
claude Nov 13, 2025
f4bab07
Update documentation
epiforecasts-workflows[bot] Nov 13, 2025
dcd983b
Add plot method for forecast_secondary objects
claude Nov 14, 2025
f36f0d0
Update documentation
epiforecasts-workflows[bot] Nov 14, 2025
b0c3f03
Pass CrIs parameter through in get_samples.estimate_secondary
claude Nov 14, 2025
7b4036f
Update documentation
epiforecasts-workflows[bot] Nov 14, 2025
f33c208
Clarify pull request workflow in CLAUDE.md
claude Nov 14, 2025
9b13ce3
Add NEWS item for estimate_secondary S3 standardisation
claude Nov 14, 2025
983a271
Fix get_samples to return raw posterior samples
claude Nov 14, 2025
5bd96c8
Update documentation
epiforecasts-workflows[bot] Nov 14, 2025
d1452e7
Add extract_array_parameter helper for consistent param extraction
claude Nov 14, 2025
c6eea24
Simplify plot.forecast_secondary to call plot.estimate_secondary
claude Nov 14, 2025
b8fbfcc
Update documentation
epiforecasts-workflows[bot] Nov 19, 2025
a106aee
Refactor estimate_secondary to use modern extraction pattern
sbfnk Nov 19, 2025
134691c
Unify parameter extraction with extract_parameters() and extract_dela…
sbfnk Nov 19, 2025
f6e5c59
Fix function titles to use sentence case per CLAUDE.md
sbfnk Nov 19, 2025
5271cf1
Implement delay lookup system in extract_delays
sbfnk Nov 19, 2025
5b29932
Update documentation
epiforecasts-workflows[bot] Nov 19, 2025
180fe5e
Implement delay_id_* naming system with semantic delay names
sbfnk Nov 19, 2025
29dedb0
Add tests for delay_id_* naming system
sbfnk Nov 19, 2025
be7e7f0
Fix delay_id naming in checks and tests
sbfnk Nov 19, 2025
ca3d211
Separate delay_id_reporting from observation_model.stan
sbfnk Nov 19, 2025
a91a2bb
Use shared delay_id_reporting for both infections and secondary models
sbfnk Nov 20, 2025
66e6b46
Improve data.table copy safety and add missing methods
sbfnk Nov 20, 2025
3f35449
Add backward compatibility $ operator for estimate_secondary
sbfnk Nov 20, 2025
cd371b6
Restore delay_id_reporting to simulation_delays.stan
sbfnk Nov 20, 2025
2d22d37
Fix delay ID variable names in stan-to-R.R
sbfnk Nov 20, 2025
2b73a44
Rename delay_rev_pmf to reporting_rev_pmf for clarity
sbfnk Nov 20, 2025
b73e1f9
Break long lines in Stan files to stay under 80 characters
sbfnk Nov 20, 2025
c684a8b
Fix critical bug: use delay_id_reporting not delay_id_generation_time
sbfnk Nov 20, 2025
796d681
Fix linting issues
sbfnk Nov 21, 2025
939c865
Update documentation
sbfnk Nov 21, 2025
2f7b5d1
Mark $.estimate_secondary as internal for pkgdown
sbfnk Nov 21, 2025
1964d95
Address CodeRabbit review feedback
sbfnk Nov 21, 2025
8069e53
Fix indentation in checks.R
sbfnk Nov 21, 2025
afc3353
Merge branch 'main' into feature/delay-parameter-lookup
sbfnk Nov 27, 2025
3293941
Remove unnecessary explicit return statements
sbfnk-bot Nov 27, 2025
7c93f19
Merge branch 'main' into feature/delay-parameter-lookup
sbfnk Dec 3, 2025
a4a88b0
Remove explicit return() statements
sbfnk Dec 3, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions CLAUDE.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,29 @@

## Pull Requests

### Workflow Overview
When working on a task, follow this sequence:

1. **During development**: Make commits with descriptive messages (NO issue numbers)
2. **Before opening PR**: Add a NEWS.md item for user-facing changes
3. **When opening PR**: Link the issue in PR description using "This PR closes #XXX"
4. **In PR template**: Complete all applicable checklist items

### Key Distinctions
- **Commit messages**: Describe what the code does, never reference issues
- Good: "Standardise return structure for estimate_secondary"
- Bad: "Standardise return structure for estimate_secondary (#1142)"
- **PR descriptions**: Link to issues here using the template
- Format: "This PR closes #1142."
- **NEWS items**: Describe user-facing changes, never reference issues or PRs
- Good: "Added S3 methods for estimate_secondary objects"
- Bad: "Added S3 methods for estimate_secondary (#1142)"

### PR Template Requirements
- Follow the pull request template in `.github/PULL_REQUEST_TEMPLATE.md`
- Complete the checklist items in the template
- Link to the related issue in the PR description (not in commit messages)
- Ensure you've added a NEWS.md item before submitting

## Testing

Expand Down
10 changes: 10 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Generated by roxygen2: do not edit by hand

S3method("!=",dist_spec)
S3method("$",estimate_secondary)
S3method("+",dist_spec)
S3method("==",dist_spec)
S3method(c,dist_spec)
Expand All @@ -10,8 +11,14 @@ S3method(discretise,dist_spec)
S3method(discretise,multi_dist_spec)
S3method(fix_parameters,dist_spec)
S3method(fix_parameters,multi_dist_spec)
S3method(get_predictions,estimate_infections)
S3method(get_predictions,estimate_secondary)
S3method(get_predictions,forecast_infections)
S3method(get_predictions,forecast_secondary)
S3method(get_samples,estimate_infections)
S3method(get_samples,estimate_secondary)
S3method(get_samples,forecast_infections)
S3method(get_samples,forecast_secondary)
S3method(is_constrained,dist_spec)
S3method(is_constrained,multi_dist_spec)
S3method(max,dist_spec)
Expand All @@ -24,13 +31,15 @@ S3method(plot,estimate_infections)
S3method(plot,estimate_secondary)
S3method(plot,estimate_truncation)
S3method(plot,forecast_infections)
S3method(plot,forecast_secondary)
S3method(print,dist_spec)
S3method(print,epinowfit)
S3method(sd,default)
S3method(sd,dist_spec)
S3method(sd,multi_dist_spec)
S3method(summary,epinow)
S3method(summary,estimate_infections)
S3method(summary,estimate_secondary)
S3method(summary,forecast_infections)
export(Fixed)
export(Gamma)
Expand Down Expand Up @@ -79,6 +88,7 @@ export(generation_time_opts)
export(get_distribution)
export(get_parameters)
export(get_pmf)
export(get_predictions)
export(get_regional_results)
export(get_samples)
export(gp_opts)
Expand Down
7 changes: 7 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,13 @@
- **Deprecated**: `summary(object, type = "samples")` now issues a deprecation warning. Use `get_samples(object)` instead.
- **Deprecated**: Internal function `extract_parameter_samples()` renamed to `format_simulation_output()` for clarity.
- `forecast_infections()` now returns an independent S3 class `"forecast_infections"` instead of inheriting from `"estimate_infections"`. This clarifies the distinction between fitted models (which contain a Stan fit for diagnostics) and forecast simulations (which contain pre-computed samples). Dedicated `summary()`, `plot()`, and `get_samples()` methods are provided.
- `estimate_secondary()` now returns an S3 object of class `c("epinowfit", "estimate_secondary", "list")` with elements `fit`, `args`, and `observations`, matching the structure of `estimate_infections()`.
- Use `get_samples(object)` to extract formatted posterior samples for delay and scaling parameters.
- Use `get_predictions(object)` to get predicted secondary observations with credible intervals merged with observations.
- Use `summary(object)` to get summarised parameter estimates. Use `type = "compact"` for key parameters only, or `type = "parameters"` with a `params` argument to select specific parameters.
- Access the Stan fit directly via `object$fit`, model arguments via `object$args`, and observations via `object$observations`.
- **Deprecated**: The previous return structure with `predictions`, `posterior`, and `data` elements is deprecated and will be removed in a future release. Backward compatibility is provided with deprecation warnings when accessing these elements via `$`.
- `forecast_secondary()` now returns an independent S3 class `"forecast_secondary"` instead of inheriting from `"estimate_secondary"`, with dedicated `get_samples()`, `get_predictions()`, and `plot()` methods.
- `plot.estimate_infections()` and `plot.forecast_infections()` now accept a `CrIs` argument to control which credible intervals are displayed.

## Model changes
Expand Down
9 changes: 6 additions & 3 deletions R/checks.R
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,8 @@ check_sparse_pmf_tail <- function(pmf, span = 5, tol = 1e-6) {
#' @keywords internal
check_truncation_length <- function(stan_args, time_points) {
# Check if truncation exists
if (is.null(stan_args$data$trunc_id) || stan_args$data$trunc_id == 0) {
if (is.null(stan_args$data$delay_id_truncation) ||
stan_args$data$delay_id_truncation == 0) {
return(invisible())
}

Expand All @@ -210,9 +211,11 @@ check_truncation_length <- function(stan_args, time_points) {

# Map truncation to its position in the flat delays array
# delay_types_groups gives start and end indices for each delay type
trunc_start <- stan_args$data$delay_types_groups[stan_args$data$trunc_id]
trunc_start <- stan_args$data$delay_types_groups[
stan_args$data$delay_id_truncation
]
trunc_end <- stan_args$data$delay_types_groups[
stan_args$data$trunc_id + 1
stan_args$data$delay_id_truncation + 1
] - 1

# Get which truncation delays are non-parametric
Expand Down
17 changes: 11 additions & 6 deletions R/create.R
Original file line number Diff line number Diff line change
Expand Up @@ -370,7 +370,8 @@ create_gp_data <- function(gp = gp_opts(), data) {
w0 = gp$w0
)

c(data, gp_data)
gp_data <- c(data, gp_data)
gp_data
}

#' Create Observation Model Settings
Expand Down Expand Up @@ -686,6 +687,8 @@ create_stan_args <- function(stan = stan_opts(),
##' @keywords internal
create_stan_delays <- function(..., time_points = 1L) {
delays <- list(...)
delay_names <- names(delays)

## discretise
delays <- map(delays, discretise, strict = FALSE)
delays <- map(delays, collapse)
Expand All @@ -694,10 +697,12 @@ create_stan_delays <- function(..., time_points = 1L) {
max_delay <- unname(as.numeric(flatten(map(bounded_delays, max))))
## number of different non-empty types
type_n <- vapply(delays, ndist, integer(1))
## assign ID values to each type
ids <- rep(0L, length(type_n))
ids[type_n > 0] <- seq_len(sum(type_n > 0))
names(ids) <- paste(names(type_n), "id", sep = "_")

## Create delay_id_* variables pointing to delay_types_groups index
## Similar to param_id_* in create_stan_params()
delay_ids <- rep(0L, length(type_n))
delay_ids[type_n > 0] <- seq_len(sum(type_n > 0))
names(delay_ids) <- paste("delay_id", delay_names, sep = "_")

## create "flat version" of delays, i.e. a list of all the delays (including
## elements of composite delays)
Expand Down Expand Up @@ -773,7 +778,7 @@ create_stan_delays <- function(..., time_points = 1L) {
ret$dist <- array(match(distributions, c("lognormal", "gamma")) - 1L)

names(ret) <- paste("delay", names(ret), sep = "_")
ret <- c(ret, ids)
ret <- c(ret, as.list(delay_ids))

ret
}
Expand Down
11 changes: 6 additions & 5 deletions R/estimate_infections.R
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,9 @@
#' @export
#' @return An `<estimate_infections>` object which is a list of outputs
#' including: the stan object (`fit`), arguments used to fit the model
#' (`args`), and the observed data (`observations`). The estimates included in
#' the fit can be accessed using the `summary()` function.
#' (`args`), and the observed data (`observations`). Use `summary()` to access
#' estimates, `get_samples()` to extract posterior samples, and
#' `get_predictions()` to access predicted reported cases.
#'
#' @seealso [epinow()] [regional_epinow()] [forecast_infections()]
#' [estimate_truncation()]
Expand Down Expand Up @@ -264,9 +265,9 @@ estimate_infections <- function(data,
)

stan_data <- c(stan_data, create_stan_delays(
gt = generation_time,
delay = delays,
trunc = truncation,
generation_time = generation_time,
reporting = delays,
truncation = truncation,
time_points = stan_data$t - stan_data$seeding_time - stan_data$horizon
))

Expand Down
115 changes: 87 additions & 28 deletions R/estimate_secondary.R
Original file line number Diff line number Diff line change
Expand Up @@ -56,11 +56,11 @@
#' @param verbose Logical, should model fitting progress be returned. Defaults
#' to [interactive()].
#'
#' @return A list containing: `predictions` (a `<data.frame>` ordered by date
#' with the primary, and secondary observations, and a summary of the model
#' estimated secondary observations), `posterior` which contains a summary of
#' the entire model posterior, `data` (a list of data used to fit the
#' model), and `fit` (the `stanfit` object).
#' @return An `<estimate_secondary>` object which is a list of outputs
#' including: the stan object (`fit`), arguments used to fit the model
#' (`args`), and the observed data (`observations`). Use `summary()` to access
#' estimates, `get_samples()` to extract posterior samples, and
#' `get_predictions()` to access predictions.
#' @export
#' @inheritParams estimate_infections
#' @inheritParams update_secondary_args
Expand Down Expand Up @@ -232,8 +232,8 @@ estimate_secondary <- function(data,
stan_data <- c(stan_data, secondary)
# delay data
stan_data <- c(stan_data, create_stan_delays(
delay = delays,
trunc = truncation,
reporting = delays,
truncation = truncation,
time_points = stan_data$t
))

Expand Down Expand Up @@ -266,22 +266,15 @@ estimate_secondary <- function(data,

fit <- fit_model(stan_, id = "estimate_secondary")

out <- list()
out$predictions <- extract_stan_param(fit, "sim_secondary", CrIs = CrIs)
out$predictions <- out$predictions[, lapply(.SD, round, 1)]
out$predictions <- out$predictions[, date := reports[(burn_in + 1):.N]$date]
out$predictions <- data.table::merge.data.table(
reports, out$predictions,
all = TRUE, by = "date"
)
out$posterior <- extract_stan_param(
fit,
CrIs = CrIs
# Create standardized S3 return structure
ret <- list(
fit = fit,
args = stan_data,
observations = reports
)
out$data <- stan_data
out$fit <- fit
class(out) <- c("estimate_secondary", class(out))
out

class(ret) <- c("epinowfit", "estimate_secondary", class(ret))
ret
}

#' Update estimate_secondary default priors
Expand Down Expand Up @@ -381,7 +374,7 @@ plot.estimate_secondary <- function(x, primary = FALSE,
from = NULL, to = NULL,
new_obs = NULL,
...) {
predictions <- data.table::copy(x$predictions)
predictions <- get_predictions(x)

if (!is.null(new_obs)) {
new_obs <- data.table::as.data.table(new_obs)
Expand Down Expand Up @@ -428,6 +421,23 @@ plot.estimate_secondary <- function(x, primary = FALSE,
ggplot2::theme(axis.text.x = ggplot2::element_text(angle = 90))
}

#' Plot method for forecast_secondary objects
#'
#' @description `r lifecycle::badge("stable")`
#' Plot method for forecast secondary observations.
#'
#' @inheritParams plot.estimate_secondary
#' @method plot forecast_secondary
#' @export
plot.forecast_secondary <- function(x, primary = FALSE,
from = NULL, to = NULL,
new_obs = NULL,
...) {
plot.estimate_secondary(
x, primary = primary, from = from, to = to, new_obs = new_obs, ...
)
}

#' Convolve and scale a time series
#'
#' This applies a lognormal convolution with given, potentially time-varying
Expand Down Expand Up @@ -623,7 +633,7 @@ forecast_secondary <- function(estimate,
if (inherits(primary, "estimate_infections")) {
primary_samples <- get_samples(primary)
primary <- primary_samples[variable == primary_variable]
primary <- primary[date > max(estimate$predictions$date, na.rm = TRUE)]
primary <- primary[date > max(get_predictions(estimate)$date, na.rm = TRUE)]
primary <- primary[, .(date, sample, value)]
if (!is.null(samples)) {
primary <- primary[sample(.N, samples, replace = TRUE)]
Expand All @@ -641,10 +651,10 @@ forecast_secondary <- function(estimate,
include = FALSE
)
# extract data from stanfit
stan_data <- estimate$data
stan_data <- estimate$args

# combined primary from data and input primary
primary_fit <- estimate$predictions[
primary_fit <- get_predictions(estimate)[
,
.(date, value = primary, sample = list(unique(updated_primary$sample)))
]
Expand Down Expand Up @@ -721,7 +731,7 @@ forecast_secondary <- function(estimate,
# link previous prediction observations with forecast observations
forecast_obs <- data.table::rbindlist(
list(
estimate$predictions[, .(date, primary, secondary)],
get_predictions(estimate)[, .(date, primary, secondary)],
data.table::copy(primary)[, .(primary = median(value)), by = "date"]
),
use.names = TRUE, fill = TRUE
Expand All @@ -735,6 +745,55 @@ forecast_secondary <- function(estimate,
data.table::setcolorder(
out$predictions, c("date", "primary", "secondary", "mean", "sd")
)
class(out) <- c("estimate_secondary", class(out))
class(out) <- c("forecast_secondary", class(out))
out
}

#' Extract elements from estimate_secondary objects with deprecated warnings
#'
#' @description `r lifecycle::badge("deprecated")`
#' Provides backward compatibility for the old return structure. The previous
#' structure with \code{predictions}, \code{posterior}, and \code{data}
#' elements is deprecated. Use the accessor methods instead:
#' \itemize{
#' \item \code{predictions} - use \code{get_predictions(object)}
#' \item \code{posterior} - use \code{get_samples(object)}
#' \item \code{data} - use \code{object$observations}
#' }
#'
#' @param x An \code{estimate_secondary} object
#' @param name The name of the element to extract
#' @return The requested element with a deprecation warning
#' @keywords internal
#' @export
#' @method $ estimate_secondary
`$.estimate_secondary` <- function(x, name) {
switch(name,
predictions = {
lifecycle::deprecate_warn(
"1.8.0",
"estimate_secondary()$predictions",
"get_predictions()"
)
get_predictions(x)
},
posterior = {
lifecycle::deprecate_warn(
"1.8.0",
"estimate_secondary()$posterior",
"get_samples()"
)
get_samples(x)
},
data = {
lifecycle::deprecate_warn(
"1.8.0",
"estimate_secondary()$data",
"estimate_secondary()$observations"
)
x[["observations"]]
},
# For other elements, use normal list extraction
NextMethod("$")
)
}
2 changes: 1 addition & 1 deletion R/estimate_truncation.R
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ estimate_truncation <- function(data,
)

stan_data <- c(stan_data, create_stan_delays(
trunc = truncation,
truncation = truncation,
time_points = stan_data$t
))

Expand Down
Loading
Loading