diff --git a/DESCRIPTION b/DESCRIPTION index ce462430..6bec29ee 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -56,5 +56,5 @@ LazyData: false URL: https://mc-stan.org/posterior/, https://discourse.mc-stan.org/ BugReports: https://github.com/stan-dev/posterior/issues Roxygen: list(markdown = TRUE) -RoxygenNote: 7.3.0 +RoxygenNote: 7.3.1 VignetteBuilder: knitr diff --git a/NAMESPACE b/NAMESPACE index 62cf616b..0426a993 100755 --- a/NAMESPACE +++ b/NAMESPACE @@ -248,6 +248,7 @@ S3method(print,draws_list) S3method(print,draws_matrix) S3method(print,draws_rvars) S3method(print,draws_summary) +S3method(print,rollup_summary) S3method(print,rvar) S3method(prod,rvar) S3method(quantile,rvar) @@ -287,6 +288,10 @@ S3method(rhat_basic,default) S3method(rhat_basic,rvar) S3method(rhat_nested,default) S3method(rhat_nested,rvar) +S3method(rollup_summary,data.frame) +S3method(rollup_summary,default) +S3method(rollup_summary,draws) +S3method(rollup_summary,rollup_summary) S3method(sd,default) S3method(sd,rvar) S3method(split_chains,draws) @@ -420,6 +425,7 @@ export(cdf) export(chain_ids) export(default_convergence_measures) export(default_mcse_measures) +export(default_rollups) export(default_summary_measures) export(diag) export(dissent) @@ -485,6 +491,7 @@ export(rfun) export(rhat) export(rhat_basic) export(rhat_nested) +export(rollup_summary) export(rstar) export(rvar) export(rvar_all) diff --git a/NEWS.md b/NEWS.md index 03f5a1a8..e10ff33f 100644 --- a/NEWS.md +++ b/NEWS.md @@ -15,6 +15,8 @@ * For types that support `factor` variables (`draws_df`, `draws_list`, and `draws_rvars`), `extract_variable()` and `extract_variable_matrix()` can now return `factor`s. +* Add `rollup_summary()` function for rolling up summaries of variables with + indices (#43). # posterior 1.5.0 diff --git a/R/as_draws_rvars.R b/R/as_draws_rvars.R index fd7a558a..506a5581 100755 --- a/R/as_draws_rvars.R +++ b/R/as_draws_rvars.R @@ -84,8 +84,7 @@ as_draws_rvars.draws_matrix <- function(x, ...) 
{ # first, pull out the list of indices into a data frame # where each column is an index variable - indices <- as.data.frame(do.call(rbind, split_indices(var$indices)), - stringsAsFactors = FALSE) + indices <- split_indices_to_df(var$indices) unique_indices <- vector("list", length(indices)) .dimnames <- vector("list", length(indices)) names(unique_indices) <- names(indices) diff --git a/R/misc.R b/R/misc.R index c1bce76c..8f07274a 100644 --- a/R/misc.R +++ b/R/misc.R @@ -136,6 +136,7 @@ move_to_start <- function(x, start) { # prettily deparse an expression # @return a single character string deparse_pretty <- function(x, max_chars = NULL, max_wsp = 1L) { + if (rlang::is_quosure(x)) x <- rlang::get_expr(x) out <- collapse(deparse(x)) out <- rm_wsp(out, max_wsp) assert_int(max_chars, null.ok = TRUE) diff --git a/R/rollup_summary.R b/R/rollup_summary.R new file mode 100644 index 00000000..80afa78d --- /dev/null +++ b/R/rollup_summary.R @@ -0,0 +1,257 @@ +#' Roll up `draws_summary` objects by collapsing summaries of non-scalar parameters. +#' +#' Roll up summaries of draws (e.g. as returned by [summarise_draws()]); that +#' is, summarise the summaries. By default, summaries of all variables containing +#' indices (e.g. `"x[1]"`) are rolled up, but the `variable` parameter can be +#' used to roll up specific variables only. +#' +#' @param .x (multiple options) The object containing summaries to roll up. One of: +#' - a [`draws_summary`] object such as produced by [summarise_draws()]. +#' - a `data.frame` with a `"variable"` column giving the names of variables, +#' where all other columns are numeric summaries of those variables. +#' - an object with a [summarise_draws()] method, such as a [`draws`] object, +#' in which case [summarise_draws()] will be called on `.x` and the result +#' will be rolled up. +#' - a [`rollup_summary`] object such as produced by `rollup_summary()`, in +#' which case variables that have not been rolled up yet may be rolled up. 
+#' @param ... (multiple options) arguments where the name of each argument is a +#' summary measure (i.e. column) in `.x` and the value is the rollup function(s) +#' to apply to that summary measure, specified as one of: +#' - a function, or the bare name of a function. +#' - a character vector of function names (optionally named). +#' - a function formula, as accepted by [rlang::as_function()]. +#' - a named list of any of the above. +#' +#' Unnamed arguments in `...` specify default rollup functions to apply to all +#' summary measures that do not have specific rollup functions given in `...`. +#' @param variable (character vector) base names (without indices) of variables +#' to roll up. If `NULL` (the default), all variables with indices in their names +#' (e.g. `"x[1,2]"`) will be rolled up. +#' @param .funs (list) named list where names are summary measures in `.x` +#' and values are the default rollup functions to apply to those summary +#' measures, unless overridden by `...`. As in `...`, unnamed elements of this +#' list give default rollup functions to apply to summary measures that do not +#' have specific rollup functions given in `.funs`. +#' @details +#' If called without specifying additional rollup functions in `...`, +#' `rollup_summary()` will apply the default rollup functions as determined by +#' `.funs` to the columns in `.x` (or, if `.x` is not a `data.frame`, to the +#' result of `summarise_draws(.x)`). +#' +#' The default value of `.funs` provides several default rollup functions +#' that will be applied to specific summary measures, unless this is overridden +#' by entries in `...`. For example, `ess_bulk` has the default +#' rollup function `"min"` instead of `c("min", "max")`, as the minimum +#' effective sample size is likely of more interest than the maximum. +#' `default_rollups()` gives the complete list of default rollup functions.
+#' +#' Calls to `rollup_summary()` can be chained, in which case subsequent +#' rollups will be applied only to variables that have not already been +#' rolled up (i.e. the `"unrolled"` element; see the description of +#' `rollup_summary` objects below). This makes it possible to provide different +#' rollup functions for different variables by calling `rollup_summary()` +#' multiple times with different values of the `variable` parameter. +#' @returns +#' A `rollup_summary` object, which is a named list of [`draws_summary`] objects: +#' - `"unrolled"`: a [`draws_summary`] of the variables that were not rolled up. +#' - `"rolled"`: a [`draws_summary`] of the rolled-up variables. The second +#' column of this data frame, `"dim"`, gives the lengths of the dimensions +#' of each rolled up variable as a comma-separated string. The +#' remaining columns give the rollups of each summary measure; e.g. if `.x` +#' contained a summary measure `"mean"` and it was rolled up using the `"min"` +#' and `"max"` functions (the default), the output will have `"mean_min"` +#' and `"mean_max"` columns.
+#' @examples +#' x <- example_draws() +#' +#' # default summaries show a row for every element in array-like variables +#' summarise_draws(x) +#' +#' # you can roll up summaries of array-like variables by rolling up draws +#' # objects directly; this will apply the default options of summarise_draws() +#' rollup_summary(x) +#' +#' # or summarise draws objects first to pick the desired summary measures +#' # (note that ess_bulk is only rolled up using min by default; see the +#' # .funs parameter) +#' ds <- summarise_draws(x, "mean", "sd", "ess_bulk") +#' rollup_summary(ds) +#' +#' # rollups work on variables of any dimension +#' x <- example_draws(example = "multi_normal") +#' rollup_summary(x) +#' +#' # you can roll up only some variables +#' rollup_summary(x, variable = "Sigma") +#' +#' # you can specify the rollup functions to apply to all summaries by passing +#' # unnamed parameters ... +#' rollup_summary(x, "mean", "min") +#' +#' # ... or use names to specify rollup functions for specific summaries +#' rollup_summary(x, mean = "sd", median = "min") +#' +#' # you can pass parameters to rollup functions using anonymous functions +#' x2 <- draws_rvars(x = c(rvar_rng(rnorm, 5), NA)) +#' rollup_summary(x2, list(min = function(x) min(x, na.rm = TRUE))) +#' +#' # rollups can be chained to provide different rollup functions to +#' # different variables +#' ds <- summarise_draws(x, "mean", "sd") +#' rs <- rollup_summary(ds, variable = "mu", sd = "min") +#' rs <- rollup_summary(rs, variable = "Sigma", sd = "max") +#' rs +#' @export +rollup_summary <- function(.x, ...) { + UseMethod("rollup_summary") +} + +#' @rdname rollup_summary +#' @export +rollup_summary.default <- function(.x, ...) { + rollup_summary(summarise_draws(.x), ...) +} + +#' @rdname rollup_summary +#' @export +rollup_summary.draws <- function(.x, ...) { + rollup_summary(summarise_draws(.x), ...) 
+} + +#' @rdname rollup_summary +#' @export +rollup_summary.data.frame <- function ( + .x, + ..., + variable = NULL, + .funs = default_rollups() +) { + assert_multi_class(.x$variable, c("character", "factor")) + assert_character(variable, null.ok = TRUE) + assert_list(.funs, null.ok = TRUE) + + rollup_funs <- lapply(rlang::enquos0(...), create_function_list) + default_rollup_funs <- lapply(.funs, create_function_list) + + is_unnamed <- rlang::names2(rollup_funs) == "" + if (any(is_unnamed)) { + # user provided unnamed functions in dots, use these for summary measures + # that otherwise don't have a rollup function specified + unspecified_rollup_funs <- do.call(c, rollup_funs[is_unnamed]) + rollup_funs <- rollup_funs[!is_unnamed] + } else { + # use the default unspecified rollup funs + is_unnamed <- rlang::names2(default_rollup_funs) == "" + unspecified_rollup_funs <- do.call(c, default_rollup_funs[is_unnamed]) + default_rollup_funs <- default_rollup_funs[!is_unnamed] + + # apply the measure-specific default rollup functions to any columns not + # overridden by the user + missing_default_funs <- setdiff(names(default_rollup_funs), names(rollup_funs)) + rollup_funs[missing_default_funs] <- default_rollup_funs[missing_default_funs] + } + + # apply the generic default rollup functions to any remaining unspecified columns + rollup_funs[setdiff(names(.x), names(rollup_funs))] <- list(unspecified_rollup_funs) + + # determine the variables to roll up + vars <- split_variable_names(.x$variable) + if (is.null(variable)) { + rollup_rows <- nzchar(vars$indices) + } else { + rollup_rows <- vars$base_name %in% variable + } + variable_col <- which(names(.x) == "variable") + vars <- vars[rollup_rows, ] + + # split the input df by variable base name and roll up the summaries + var_groups <- vctrs::vec_split(cbind(vars, .x[rollup_rows, -variable_col, drop = FALSE]), vars$base_name) + rolled_up_vars <- lapply(var_groups$val, function(x) { + indices <- split_indices_to_df(x$indices) + 
rolled_up_cols <- do.call(cbind, lapply(seq_along(x)[c(-1,-2)], function(col_i) { + col <- x[[col_i]] + col_name <- names(x)[[col_i]] + rolled_up_col <- lapply(rollup_funs[[col_name]], function(f) f(col)) + names(rolled_up_col) <- sprintf("%s_%s", col_name, names(rolled_up_col)) + vctrs::new_data_frame(rolled_up_col, n = 1L) + })) + cbind( + variable = x$base_name[[1]], + dim = paste0(lengths(lapply(indices, unique)), collapse = ","), + rolled_up_cols, + stringsAsFactors = FALSE + ) + }) + + new_rollup_summary( + unrolled = .x[!rollup_rows, , drop = FALSE], + rolled = do.call(rbind, rolled_up_vars) + ) +} + +#' @rdname rollup_summary +#' @export +rollup_summary.rollup_summary <- function (.x, ...) { + out <- rollup_summary(.x$unrolled, ...) + new_rollup_summary( + unrolled = out$unrolled, + rolled = vctrs::vec_rbind(.x$rolled, out$rolled) + ) +} + +new_rollup_summary <- function(unrolled, rolled) { + assert_data_frame(unrolled) + if (!inherits(unrolled, "draws_summary")) class(unrolled) <- class_draws_summary() + assert_data_frame(rolled) + if (!inherits(rolled, "draws_summary")) class(rolled) <- class_draws_summary() + + structure( + list(unrolled = unrolled, rolled = rolled), + class = class_rollup_summary() + ) +} + +class_rollup_summary <- function() { + c("rollup_summary", "list") +} + +#' @export +print.rollup_summary <- function(x, ..., color = TRUE) { + color <- as_one_logical(color) + if (color) { + subtle <- pillar::style_subtle + } else { + subtle <- identity + } + + cat(":\n\n") + if (NROW(x$unrolled) > 0) { + cat("$unrolled", subtle("(variables that have not been rolled up):"), "\n") + print(x$unrolled, ...) + cat("\n") + } + if (NROW(x$rolled) > 0) { + cat("$rolled", subtle("(variables that have been rolled up):"), "\n") + print(x$rolled, ...) 
+ cat("\n") + } + invisible(x) +} + +#' @rdname rollup_summary +#' @export +default_rollups <- function() { + list( + c("min", "max"), + ess_basic = "min", + ess_bulk = "min", + ess_mean = "min", + ess_median = "min", + ess_quantile = "min", + ess_sd = "min", + ess_tail = "min", + rhat = "max", + rhat_basic = "max", + rhat_nested = "max" + ) +} diff --git a/R/summarise_draws.R b/R/summarise_draws.R index 6f13755a..34e579c7 100644 --- a/R/summarise_draws.R +++ b/R/summarise_draws.R @@ -117,41 +117,9 @@ summarise_draws.draws <- function( if (.cores <= 0) { stop_no_call("'.cores' must be a positive integer.") } - funs <- as.list(c(...)) .args <- as.list(.args) - if (length(funs)) { - if (is.null(names(funs))) { - # ensure names are initialized properly - names(funs) <- rep("", length(funs)) - } - calls <- substitute(list(...))[-1] - calls <- ulapply(calls, deparse_pretty) - for (i in seq_along(funs)) { - fname <- NULL - if (is.character(funs[[i]])) { - fname <- as_one_character(funs[[i]]) - } - # label unnamed arguments via their calls - if (!nzchar(names(funs)[i])) { - if (!is.null(fname)) { - names(funs)[i] <- fname - } else { - names(funs)[i] <- calls[i] - } - } - # get functions passed as strings from the right environments - if (!is.null(fname)) { - if (exists(fname, envir = caller_env())) { - env <- caller_env() - } else if (fname %in% getNamespaceExports("posterior")) { - env <- asNamespace("posterior") - } else { - stop_no_call("Cannot find function '", fname, "'.") - } - } - funs[[i]] <- rlang::as_function(funs[[i]], env = env) - } - } else { + funs <- create_function_list(rlang::enquos0(...)) + if (length(funs) == 0) { # default functions funs <- list( mean = base::mean, @@ -196,23 +164,6 @@ summarise_draws.draws <- function( if (checkmate::test_os("windows")) { cl <- parallel::makePSOCKcluster(.cores) on.exit(parallel::stopCluster(cl)) - # exporting all these functions seems to be required to - # pass GitHub actions checks on Windows - 
parallel::clusterExport( - cl, - varlist = package_function_names("posterior"), - envir = as.environment(asNamespace("posterior")) - ) - parallel::clusterExport( - cl, - varlist = package_function_names("checkmate"), - envir = as.environment(asNamespace("checkmate")) - ) - parallel::clusterExport( - cl, - varlist = package_function_names("rlang"), - envir = as.environment(asNamespace("rlang")) - ) summary_list <- parallel::parLapply( cl, X = chunk_list, @@ -327,6 +278,61 @@ empty_draws_summary <- function(dimensions = "variable") { } +#' convert a specification for a list of functions (in various formats) into a +#' named list of functions +#' @param fun_exprs One of: +#' - a function. +#' - a character vector of names of functions that can be found either in `env` +#' or in the \pkg{posterior} namespace. +#' - an unevaluated expression or a quosure that represents a function. +#' - an \pkg{rlang} function formula (a la [rlang::as_function()]). +#' - a list where each element is one of the above. +#' @param env the environment to evaluate expressions in and to go searching for +#' functions specified as strings in.
+#' @returns a named list of the functions given in `fun_exprs` (names made unique) +#' @noRd +create_function_list <- function(fun_exprs, env = caller_env(2)) { + # flatten fun_exprs into two lists: funs, a list of functions/strings/formulas, + # and fun_exprs, a list of bare expressions or quosures + if (!is.list(fun_exprs)) fun_exprs <- list(fun_exprs) + funs <- lapply(fun_exprs, eval_tidy, env = env) + fun_exprs <- rep(fun_exprs, lengths(funs)) + funs <- as.list(do.call(c, funs)) + + if (is.null(names(funs))) { + # ensure names are initialized properly + names(funs) <- rep("", length(funs)) + } + + for (i in seq_along(funs)) { + fname <- NULL + if (is.character(funs[[i]])) { + fname <- as_one_character(funs[[i]]) + } + + # label unnamed arguments via their calls + if (!nzchar(names(funs)[i])) { + if (!is.null(fname)) { + names(funs)[i] <- fname + } else { + names(funs)[i] <- deparse_pretty(fun_exprs[[i]]) + } + } + + # get the environment to find functions passed as strings in + env_i <- env + if (!is.null(fname) && !exists(fname, envir = env_i, mode = "function")) { + # if the function isn't in the calling environment fall back to the package + env_i <- asNamespace("posterior") + } + + funs[[i]] <- rlang::as_function(funs[[i]], env = env_i) + } + + names(funs) <- make.unique(names(funs)) + funs +} + create_summary_list <- function(x, v, funs, .args) { draws <- drop_dims_or_classes(x[, , v], dims = 3, reset_class = FALSE) args <- c(list(draws), .args) diff --git a/R/variable-indices.R b/R/variable-indices.R index 7965044f..20b25853 100644 --- a/R/variable-indices.R +++ b/R/variable-indices.R @@ -114,6 +114,23 @@ split_indices <- function(x) { strsplit(substr(x, 2, nchar(x) - 1), ",", fixed = TRUE) } +#' Given a vector of index strings of all the same length (such as returned by +#' `split_variable_names(x)$indices` for all variables with the same base name), +#' split each index string into a character vector of indices and form them +#' into a data frame.
+#' @param x a character vector of index strings that all have the same number +#' of dimensions, +#' e.g. `c("[1,1]", "[1,2]", "[1,3]")` +#' @returns a data frame with `length(x)` rows and number of columns equal to +#' the number of dimensions in the variables. Throws an error if the number +#' of dimensions are not all equal. +#' @noRd +split_indices_to_df <- function(x) { + indices = split_indices(x) + stopifnot(all(lengths(indices) == lengths(indices[1]))) + as.data.frame(do.call(rbind, indices), stringsAsFactors = FALSE) +} + # manipulating flattened variable names ----------------------------------- diff --git a/man/draws_array.Rd b/man/draws_array.Rd index 359703ca..941cf85e 100755 --- a/man/draws_array.Rd +++ b/man/draws_array.Rd @@ -74,10 +74,10 @@ str(x2) } \seealso{ Other formats: +\code{\link{draws}}, \code{\link{draws_df}()}, \code{\link{draws_list}()}, \code{\link{draws_matrix}()}, -\code{\link{draws_rvars}()}, -\code{\link{draws}} +\code{\link{draws_rvars}()} } \concept{formats} diff --git a/man/draws_df.Rd b/man/draws_df.Rd index a34daffb..bbfc8206 100755 --- a/man/draws_df.Rd +++ b/man/draws_df.Rd @@ -96,10 +96,10 @@ print(xnew) } \seealso{ Other formats: +\code{\link{draws}}, \code{\link{draws_array}()}, \code{\link{draws_list}()}, \code{\link{draws_matrix}()}, -\code{\link{draws_rvars}()}, -\code{\link{draws}} +\code{\link{draws_rvars}()} } \concept{formats} diff --git a/man/draws_list.Rd b/man/draws_list.Rd index f45526a8..b696350e 100755 --- a/man/draws_list.Rd +++ b/man/draws_list.Rd @@ -76,10 +76,10 @@ str(x2) } \seealso{ Other formats: +\code{\link{draws}}, \code{\link{draws_array}()}, \code{\link{draws_df}()}, \code{\link{draws_matrix}()}, -\code{\link{draws_rvars}()}, -\code{\link{draws}} +\code{\link{draws_rvars}()} } \concept{formats} diff --git a/man/draws_matrix.Rd b/man/draws_matrix.Rd index 1b548412..432e3f9b 100755 --- a/man/draws_matrix.Rd +++ b/man/draws_matrix.Rd @@ -74,10 +74,10 @@ str(x2) } \seealso{ Other formats: 
+\code{\link{draws}}, \code{\link{draws_array}()}, \code{\link{draws_df}()}, \code{\link{draws_list}()}, -\code{\link{draws_rvars}()}, -\code{\link{draws}} +\code{\link{draws_rvars}()} } \concept{formats} diff --git a/man/draws_rvars.Rd b/man/draws_rvars.Rd index 0e4b614a..28786629 100755 --- a/man/draws_rvars.Rd +++ b/man/draws_rvars.Rd @@ -77,10 +77,10 @@ str(x2) } \seealso{ Other formats: +\code{\link{draws}}, \code{\link{draws_array}()}, \code{\link{draws_df}()}, \code{\link{draws_list}()}, -\code{\link{draws_matrix}()}, -\code{\link{draws}} +\code{\link{draws_matrix}()} } \concept{formats} diff --git a/man/ess_basic.Rd b/man/ess_basic.Rd index 867076ca..47aeec0e 100755 --- a/man/ess_basic.Rd +++ b/man/ess_basic.Rd @@ -81,9 +81,9 @@ Other diagnostics: \code{\link{mcse_sd}()}, \code{\link{pareto_diags}()}, \code{\link{pareto_khat}()}, +\code{\link{rhat}()}, \code{\link{rhat_basic}()}, \code{\link{rhat_nested}()}, -\code{\link{rhat}()}, \code{\link{rstar}()} } \concept{diagnostics} diff --git a/man/ess_bulk.Rd b/man/ess_bulk.Rd index c1456be3..8baada5c 100755 --- a/man/ess_bulk.Rd +++ b/man/ess_bulk.Rd @@ -74,9 +74,9 @@ Other diagnostics: \code{\link{mcse_sd}()}, \code{\link{pareto_diags}()}, \code{\link{pareto_khat}()}, +\code{\link{rhat}()}, \code{\link{rhat_basic}()}, \code{\link{rhat_nested}()}, -\code{\link{rhat}()}, \code{\link{rstar}()} } \concept{diagnostics} diff --git a/man/ess_quantile.Rd b/man/ess_quantile.Rd index aa85c909..f919ad60 100755 --- a/man/ess_quantile.Rd +++ b/man/ess_quantile.Rd @@ -83,9 +83,9 @@ Other diagnostics: \code{\link{mcse_sd}()}, \code{\link{pareto_diags}()}, \code{\link{pareto_khat}()}, +\code{\link{rhat}()}, \code{\link{rhat_basic}()}, \code{\link{rhat_nested}()}, -\code{\link{rhat}()}, \code{\link{rstar}()} } \concept{diagnostics} diff --git a/man/ess_sd.Rd b/man/ess_sd.Rd index 38475d2a..91278bf8 100755 --- a/man/ess_sd.Rd +++ b/man/ess_sd.Rd @@ -68,9 +68,9 @@ Other diagnostics: \code{\link{mcse_sd}()}, 
\code{\link{pareto_diags}()}, \code{\link{pareto_khat}()}, +\code{\link{rhat}()}, \code{\link{rhat_basic}()}, \code{\link{rhat_nested}()}, -\code{\link{rhat}()}, \code{\link{rstar}()} } \concept{diagnostics} diff --git a/man/ess_tail.Rd b/man/ess_tail.Rd index 8f959718..36b2772c 100755 --- a/man/ess_tail.Rd +++ b/man/ess_tail.Rd @@ -74,9 +74,9 @@ Other diagnostics: \code{\link{mcse_sd}()}, \code{\link{pareto_diags}()}, \code{\link{pareto_khat}()}, +\code{\link{rhat}()}, \code{\link{rhat_basic}()}, \code{\link{rhat_nested}()}, -\code{\link{rhat}()}, \code{\link{rstar}()} } \concept{diagnostics} diff --git a/man/extract_variable_array.Rd b/man/extract_variable_array.Rd index 348a1a74..2c1eae76 100644 --- a/man/extract_variable_array.Rd +++ b/man/extract_variable_array.Rd @@ -45,7 +45,7 @@ str(Sigma) } \seealso{ Other variable extraction methods: -\code{\link{extract_variable_matrix}()}, -\code{\link{extract_variable}()} +\code{\link{extract_variable}()}, +\code{\link{extract_variable_matrix}()} } \concept{variable extraction methods} diff --git a/man/extract_variable_matrix.Rd b/man/extract_variable_matrix.Rd index 1b9c97c1..dedb49c1 100644 --- a/man/extract_variable_matrix.Rd +++ b/man/extract_variable_matrix.Rd @@ -47,7 +47,7 @@ rhat(mu) } \seealso{ Other variable extraction methods: -\code{\link{extract_variable_array}()}, -\code{\link{extract_variable}()} +\code{\link{extract_variable}()}, +\code{\link{extract_variable_array}()} } \concept{variable extraction methods} diff --git a/man/mcse_mean.Rd b/man/mcse_mean.Rd index c75935b1..65828211 100755 --- a/man/mcse_mean.Rd +++ b/man/mcse_mean.Rd @@ -65,9 +65,9 @@ Other diagnostics: \code{\link{mcse_sd}()}, \code{\link{pareto_diags}()}, \code{\link{pareto_khat}()}, +\code{\link{rhat}()}, \code{\link{rhat_basic}()}, \code{\link{rhat_nested}()}, -\code{\link{rhat}()}, \code{\link{rstar}()} } \concept{diagnostics} diff --git a/man/mcse_quantile.Rd b/man/mcse_quantile.Rd index 2d05f626..4651181e 100755 --- 
a/man/mcse_quantile.Rd +++ b/man/mcse_quantile.Rd @@ -80,9 +80,9 @@ Other diagnostics: \code{\link{mcse_sd}()}, \code{\link{pareto_diags}()}, \code{\link{pareto_khat}()}, +\code{\link{rhat}()}, \code{\link{rhat_basic}()}, \code{\link{rhat_nested}()}, -\code{\link{rhat}()}, \code{\link{rstar}()} } \concept{diagnostics} diff --git a/man/mcse_sd.Rd b/man/mcse_sd.Rd index 671ef249..02f86ebd 100755 --- a/man/mcse_sd.Rd +++ b/man/mcse_sd.Rd @@ -70,9 +70,9 @@ Other diagnostics: \code{\link{mcse_quantile}()}, \code{\link{pareto_diags}()}, \code{\link{pareto_khat}()}, +\code{\link{rhat}()}, \code{\link{rhat_basic}()}, \code{\link{rhat_nested}()}, -\code{\link{rhat}()}, \code{\link{rstar}()} } \concept{diagnostics} diff --git a/man/pareto_diags.Rd b/man/pareto_diags.Rd index 46370c49..49247293 100644 --- a/man/pareto_diags.Rd +++ b/man/pareto_diags.Rd @@ -150,9 +150,9 @@ Other diagnostics: \code{\link{mcse_quantile}()}, \code{\link{mcse_sd}()}, \code{\link{pareto_khat}()}, +\code{\link{rhat}()}, \code{\link{rhat_basic}()}, \code{\link{rhat_nested}()}, -\code{\link{rhat}()}, \code{\link{rstar}()} } \concept{diagnostics} diff --git a/man/pareto_khat.Rd b/man/pareto_khat.Rd index a4f91707..6234abf9 100644 --- a/man/pareto_khat.Rd +++ b/man/pareto_khat.Rd @@ -92,9 +92,9 @@ Other diagnostics: \code{\link{mcse_quantile}()}, \code{\link{mcse_sd}()}, \code{\link{pareto_diags}()}, +\code{\link{rhat}()}, \code{\link{rhat_basic}()}, \code{\link{rhat_nested}()}, -\code{\link{rhat}()}, \code{\link{rstar}()} } \concept{diagnostics} diff --git a/man/rhat_basic.Rd b/man/rhat_basic.Rd index 16ffd332..762cce0f 100755 --- a/man/rhat_basic.Rd +++ b/man/rhat_basic.Rd @@ -77,8 +77,8 @@ Other diagnostics: \code{\link{mcse_sd}()}, \code{\link{pareto_diags}()}, \code{\link{pareto_khat}()}, -\code{\link{rhat_nested}()}, \code{\link{rhat}()}, +\code{\link{rhat_nested}()}, \code{\link{rstar}()} } \concept{diagnostics} diff --git a/man/rhat_nested.Rd b/man/rhat_nested.Rd index f2536efd..9f91ad05 100644 
--- a/man/rhat_nested.Rd +++ b/man/rhat_nested.Rd @@ -85,8 +85,8 @@ Other diagnostics: \code{\link{mcse_sd}()}, \code{\link{pareto_diags}()}, \code{\link{pareto_khat}()}, -\code{\link{rhat_basic}()}, \code{\link{rhat}()}, +\code{\link{rhat_basic}()}, \code{\link{rstar}()} } \concept{diagnostics} diff --git a/man/rollup_summary.Rd b/man/rollup_summary.Rd new file mode 100644 index 00000000..e9c0d864 --- /dev/null +++ b/man/rollup_summary.Rd @@ -0,0 +1,139 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/rollup_summary.R +\name{rollup_summary} +\alias{rollup_summary} +\alias{rollup_summary.default} +\alias{rollup_summary.draws} +\alias{rollup_summary.data.frame} +\alias{rollup_summary.rollup_summary} +\alias{default_rollups} +\title{Roll up \code{draws_summary} objects by collapsing summaries of non-scalar parameters.} +\usage{ +rollup_summary(.x, ...) + +\method{rollup_summary}{default}(.x, ...) + +\method{rollup_summary}{draws}(.x, ...) + +\method{rollup_summary}{data.frame}(.x, ..., variable = NULL, .funs = default_rollups()) + +\method{rollup_summary}{rollup_summary}(.x, ...) + +default_rollups() +} +\arguments{ +\item{.x}{(multiple options) The object containing summaries to roll up. One of: +\itemize{ +\item a \code{\link{draws_summary}} object such as produced by \code{\link[=summarise_draws]{summarise_draws()}}. +\item a \code{data.frame} with a \code{"variable"} column giving the names of variables, +where all other columns are numeric summaries of those variables. +\item an object with a \code{\link[=summarise_draws]{summarise_draws()}} method, such as a \code{\link{draws}} object, +in which case \code{\link[=summarise_draws]{summarise_draws()}} will be called on \code{.x} and the result +will be rolled up. +\item a \code{\link{rollup_summary}} object such as produced by \code{rollup_summary()}, in +which case variables that have not been rolled up yet may be rolled up. 
+}} + +\item{...}{(multiple options) arguments where the name of each argument is a +summary measure (i.e. column) in \code{.x} and the value is the rollup functions +to apply to that summary measure, specified as one of: +\itemize{ +\item bare name of a function +\item a character vector of function names (optionally named). +\item a function formula, as accepted by \code{\link[rlang:as_function]{rlang::as_function()}}. +\item a named list of any of the above. +} + +Unnamed arguments in \code{...} specify default rollup functions to apply to all +summary measures that do not have specific rollup functions given in \code{...}.} + +\item{variable}{(character vector) base names (without indices) of variables +to roll up. If \code{NULL} (the default), all variables with indices in their names +(e.g. \code{"x[1,2]"}) will be rolled up.} + +\item{.funs}{(list) named list where names are summary measures in \code{.x} +and values are the default rollup functions to apply to those summary +measures, unless overridden by \code{...}. As in \code{...}, unnamed elements of this +list give default rollup functions to apply to summary measures that do not +have specific rollup functions given in \code{.funs}.} +} +\value{ +A \code{rollup_summary} object, which is a named list of \code{\link{draws_summary}} objects: +\itemize{ +\item \code{"unrolled"}: a \code{\link{draws_summary}} of the variables that were not rolled up. +\item \code{"rolled"}: a \code{\link{draws_summary}} of the rolled-up variables. The second +column of this data frame, \code{"dim"}, gives the lengths of the dimensions +of each rolled up variable as a comma-separated character vector. The +remaining columns give the rollups of each summary measure; e.g. if \code{x} +contained a summary measure \code{"mean"} and it was rolled up using the \code{"min"} +and \code{"max"} functions (the default), the output will have a \code{"mean_min"} +and \code{"mean_max"} column. 
+} +} +\description{ +Roll up summaries of draws (e.g. as returned by \code{\link[=summarise_draws]{summarise_draws()}}); that +is, summarise the summaries. By default, summaries of all variables containing +indices (e.g. \code{"x[1]"}) are rolled up, but the \code{variable} parameter can be +used to roll up specific variables only. +} +\details{ +If called without specifying additional rollup functions in \code{...}, +\code{rollup_summary()} will apply the default rollup functions as determined by +\code{.funs} to the columns in \code{.x} (or, if \code{.x} is not a \code{data.frame}, to the +result of \code{summarise_draws(.x)}). + +The default value of \code{.funs} provides several default rollup functions +that will be applied to specific summary measures, unless this is overridden +by entries in \code{...}. For example, \code{ess_bulk} has the default +rollup function \code{"min"} instead of \code{c("min", "max")}, as the minimum +effective sample size is likely of more interest than the maximum. +\code{default_rollups()} gives the complete list of default rollup functions. + +Calls to \code{rollup_summary()} can be chained, in which case subsequent +rollups will be applied only to variables that have not already been +rolled up (i.e. the \code{"unrolled"} element; see the description of +\code{rollup_summary} objects below). This makes it possible to provide different +rollup functions for different variables by calling \code{rollup_summary()} +multiple times with different values of the \code{variable} parameter. 
+} +\examples{ +x <- example_draws() + +# default summaries show a row for every element in array-like variables +summarise_draws(x) + +# you can roll up summaries of array-like variables by rolling up draws +# objects directly; this will apply the default options of summarise_draws() +rollup_summary(x) + +# or summarise draws objects first to pick the desired summary measures +# (note that ess_bulk is only rolled up using min by default; see the +# .funs parameter) +ds <- summarise_draws(x, "mean", "sd", "ess_bulk") +rollup_summary(ds) + +# rollups work on variables of any dimension +x <- example_draws(example = "multi_normal") +rollup_summary(x) + +# you can roll up only some variables +rollup_summary(x, variable = "Sigma") + +# you can specify the rollup functions to apply to all summaries by passing +# unnamed parameters ... +rollup_summary(x, "mean", "min") + +# ... or use names to specify rollup functions for specific summaries +rollup_summary(x, mean = "sd", median = "min") + +# you can pass parameters to rollup functions using anonymous functions +x2 <- draws_rvars(x = c(rvar_rng(rnorm, 5), NA)) +rollup_summary(x2, list(min = function(x) min(x, na.rm = TRUE))) + +# rollups can be chained to provide different rollup functions to +# different variables +ds <- summarise_draws(x, "mean", "sd") +rs <- rollup_summary(ds, variable = "mu", sd = "min") +rs <- rollup_summary(rs, variable = "Sigma", sd = "max") +rs +} diff --git a/man/rstar.Rd b/man/rstar.Rd index c9479902..f2e74898 100644 --- a/man/rstar.Rd +++ b/man/rstar.Rd @@ -117,8 +117,8 @@ Other diagnostics: \code{\link{mcse_sd}()}, \code{\link{pareto_diags}()}, \code{\link{pareto_khat}()}, +\code{\link{rhat}()}, \code{\link{rhat_basic}()}, -\code{\link{rhat_nested}()}, -\code{\link{rhat}()} +\code{\link{rhat_nested}()} } \concept{diagnostics} diff --git a/tests/testthat/test-rollup_summary.R b/tests/testthat/test-rollup_summary.R new file mode 100644 index 00000000..f1d8b69b --- /dev/null +++ 
b/tests/testthat/test-rollup_summary.R @@ -0,0 +1,110 @@ +test_that("rollup_summary works correctly", { + set.seed(1234) + x_array <- as_draws_array(example_draws(example = "multi_normal")) + x_array <- mutate_variables(x_array, y = rnorm(ndraws(x_array))) + x <- as_draws_df(x_array) + + sum_x <- summarise_draws(x) + rollup <- rollup_summary(sum_x) + expect_equal(rollup, rollup_summary(sum_x)) + expect_equal(rollup, rollup_summary(x_array)) + + sum_x <- summarise_draws(x, "mean", "sd") + + rollup <- rollup_summary(sum_x) + expect_equal(rollup$unrolled, sum_x[sum_x$variable == "y", ]) + expect_equal(rollup$rolled$variable, c("mu", "Sigma")) + expect_equal(rollup$rolled$dim, c("3", "3,3")) + expect_equal(names(rollup$rolled), c("variable", "dim", "mean_min", "mean_max", "sd_min", "sd_max")) + expect_equal(rollup$rolled$mean_max[1], max(sum_x[startsWith(sum_x$variable, "mu"),"mean"])) + + rollup <- rollup_summary(sum_x, variable = "Sigma") + expect_equal(rollup$unrolled, sum_x[!startsWith(sum_x$variable, "Sigma"), ]) + expect_equal(rollup$rolled$variable, c("Sigma")) + expect_equal(rollup$rolled$dim, c("3,3")) + expect_equal(names(rollup$rolled), c("variable", "dim", "mean_min", "mean_max", "sd_min", "sd_max")) + expect_equal(rollup$rolled$mean_min, min(sum_x[startsWith(sum_x$variable, "Sigma"),]$mean)) + + rollup <- rollup_summary(sum_x, "mean", "min") + expect_equal(names(rollup$rolled), c("variable", "dim", "mean_mean", "mean_min", "sd_mean", "sd_min")) + expect_equal(rollup$rolled$mean_mean[1], mean(sum_x[startsWith(sum_x$variable, "mu"),]$mean)) + + rollup <- rollup_summary(sum_x, mean = c("median", "mean"), .funs = list(mean = "stop", sd = "min")) + expect_equal(names(rollup$rolled), c("variable", "dim", "mean_median", "mean_mean", "sd_min")) + expect_equal(rollup$rolled$mean_median[1], median(sum_x[startsWith(sum_x$variable, "mu"),]$mean)) + + x2 <- draws_rvars(x = c(rvar(matrix(1:20, ncol = 2)), NA)) + sum_x2 <- summarise_draws(x2, min, max) + rollup <- 
rollup_summary(sum_x2, list(min = function(x) min(x, na.rm = TRUE)), max) + expect_equal(rollup$rolled$variable, "x") + expect_equal(rollup$rolled$dim, "3") + expect_equal(rollup$rolled$min_min, 1) + expect_equal(rollup$rolled$min_max, NA_real_) + expect_equal(rollup$rolled$max_min, 10) + expect_equal(rollup$rolled$max_max, NA_real_) +}) + +test_that("chaining rollups works", { + set.seed(1234) + x <- example_draws(example = "multi_normal") + x <- mutate_variables(x, y = rnorm(ndraws(x))) + x <- as_draws_df(x) + + sum_x <- summarise_draws(x, "mean", "sd") + + rollup <- rollup_summary( + rollup_summary(sum_x, variable = "mu", sd = "min"), + variable = "Sigma", sd = "max" + ) + expect_equal(rollup$unrolled$variable, "y") + expect_equal(rollup$rolled$variable, c("mu", "Sigma")) + expect_equal(names(rollup$rolled), c("variable", "dim", "mean_min", "mean_max", "sd_min", "sd_max")) + expect_equal(rollup$rolled$sd_min, c(min(sum_x[startsWith(sum_x$variable, "mu"),]$sd), NA_real_)) + expect_equal(rollup$rolled$sd_max, c(NA_real_, max(sum_x[startsWith(sum_x$variable, "Sigma"),]$sd))) +}) + +test_that("rollup on draws-like object works", { + x <- as_draws_array(example_draws()) + expect_equal(rollup_summary(unclass(x)), rollup_summary(x)) +}) + +test_that("rollup on data frames works", { + x <- example_draws() + sum_x <- summarise_draws(x) + df_sum_x <- as.data.frame(sum_x) + df_sum_x$variable <- factor(df_sum_x$variable) + + expect_equal(rollup_summary(df_sum_x)$rolled, rollup_summary(sum_x)$rolled) +}) + +test_that("NULL rollup functions work", { + x <- example_draws() + + expect_equal( + as.data.frame(rollup_summary(x, .funs = NULL)$rolled), + data.frame(variable = "theta", dim = "8", stringsAsFactors = FALSE) + ) +}) + +test_that("unnamed rollups in `...` override measure-specific rollups in .funs", { + x <- example_draws() + ds <- summarise_draws(x, "mean", "rhat", "ess_bulk") + rollup <- rollup_summary(ds, "median") + expect_equal( + names(rollup$rolled), + 
c("variable", "dim", "mean_median", "rhat_median", "ess_bulk_median") + ) +}) + +test_that("printing works", { + x <- rollup_summary(example_draws()) + + for (color in c(TRUE, FALSE)) { + out <- capture.output(print(x, color = color)) + expect_match(out, "<rollup_summary>", fixed = TRUE, all = FALSE) + expect_match(out, "$unrolled", fixed = TRUE, all = FALSE) + expect_match(out, "variable +mean +median", all = FALSE) + expect_match(out, "$rolled", fixed = TRUE, all = FALSE) + expect_match(out, "variable +dim +mean_min +mean_max", all = FALSE) + } +}) diff --git a/tests/testthat/test-summarise_draws.R b/tests/testthat/test-summarise_draws.R index 80ed971e..476dd61c 100644 --- a/tests/testthat/test-summarise_draws.R +++ b/tests/testthat/test-summarise_draws.R @@ -15,6 +15,12 @@ test_that("summarise_draws works correctly", { sum_x <- summarise_draws(x, ~quantile(.x, probs = c(0.4, 0.6))) expect_true(all(c("40%", "60%") %in% names(sum_x))) + sum_x <- summarise_draws(x, c("mcse_mean", "mean"), median) + expect_true(all(c("mcse_mean", "mean", "median") %in% names(sum_x))) + + sum_x <- summarise_draws(x, f = c("mcse_mean", "mean")) + expect_true(all(c("f1", "f2") %in% names(sum_x))) + x[1, 1] <- NA sum_x <- summarise_draws(x) expect_true(is.na(sum_x[1, "q5"])) @@ -187,3 +193,20 @@ test_that("draws summaries can be converted to data frames", { expect_equal(as.data.frame(summarise_draws(draws_matrix, mean, quantile2)), ref) }) + +test_that("string summary functions in the posterior namespace can be found", { + expect_equal( + # execute in an environment where only summarise_draws() and example_draws() + # are available, but not ess_bulk(), so that summarise_draws() is explicitly + # forced to look in the posterior namespace for ess_bulk() + evalq( + summarise_draws(example_draws(), "ess_bulk"), + envir = list( + summarise_draws = summarise_draws, + example_draws = example_draws + ), + enclos = emptyenv() + ), + summarise_draws(example_draws(), ess_bulk) + ) +}) diff --git
a/touchstone/script.R b/touchstone/script.R index 39aae730..f8ee8164 100644 --- a/touchstone/script.R +++ b/touchstone/script.R @@ -38,7 +38,7 @@ for (dest_type in draws_types) { as_draws_dest(x) } }, - n = 50 + n = 20 ) } @@ -54,7 +54,7 @@ for (n_variables in c(10, 100)) { "summarise_draws_{n_variables}_variables" := { posterior::summarise_draws(x) }, - n = 50 + n = 20 ) }