Roll up summaries of nonscalar variables #152
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,135 @@ | ||
#' Parse the indices from a vector of variables (extracted e.g. from a | ||
#' `draws_summary` object or a `draws` object) | ||
#' | ||
#' @name variable_indices | ||
#' @param x a character vector of variables | ||
#' | ||
#' @return | ||
#' A list with index information for each unique variable name V in `x`. Top-level list names are | ||
#' the variable names. Each element contains: | ||
#' $ndim the number of dimensions of V. Returns 0 for scalars with no brackets | ||
#' but 1 for `y[1]` even if `y` has no other entries in `x`. | ||
#' | ||
#' $dimensions a vector of the actual dimensions of V, as determined by the number of unique | ||
#' elements at each index position. Set to `NA` if ndim is zero. | ||
#' | ||
#' $implied_dimensions a vector of the implied dimensions of V, where any position in V that | ||
#' contains exclusively integers is filled in to include all integers from the lesser of one | ||
#' and its minimum up to its maximum. Set to `NA` if ndim is zero. | ||
#' | ||
#' $index_names a list of length corresponding to ndim, where each element is the | ||
#' unique levels of the corresponding index if the index is parsed as factor, and NULL otherwise. | ||
#' Set to `NULL` if ndim is zero. | ||
#' | ||
#' $indices if ndim is zero, returns 1. | ||
#' if ndim is 1 or greater, returns a dataframe of every implied combination of indices | ||
#' | ||
#' $position the position of each combination of indices from $indices in the argument `x` | ||
#' | ||
#' @details | ||
#' Assumes that variable indexing uses square brackets in the variable names | ||
#' | ||
NULL | ||
|
||
parse_variable_indices <- function(x){ | ||
vars_indices <- strsplit(x, "(\\[|\\])") | ||
vars <- sapply(vars_indices, `[[`, 1) | ||
var_names <- unique(vars) | ||
# Check that no variables contain unpaired or non-terminal square brackets | ||
if_indexed <- lengths(vars_indices) > 1 | ||
if_indexed2 <- grepl("\\[.*\\]$", x) | ||
bracket_problems <- grepl("\\].|\\[.*\\[|\\[,|,\\]|,,", x) | ||
if (any(bracket_problems) || !identical(if_indexed, if_indexed2)) { | ||
stop_no_call(paste("Some variable names contain unpaired square brackets,", | ||
"missing indices, or are multi-indexed.")) | ||
} | ||
# Get ndim. Variables with no brackets are given as ndim zero. | ||
# Variables with brackets are given ndim 1, even if they contain just one element. | ||
ndim_elementwise <- sapply(vars_indices, function(x){ | ||
if (length(x) == 2) { | ||
out <- length(strsplit(x[2], ",")[[1]]) # number of commas plus one | ||
Review comment: Why not generate the |
||
} else { | ||
out <- 0 | ||
} | ||
out | ||
}) | ||
ndim <- sapply(var_names, function(x){ | ||
out <- unique(ndim_elementwise[vars == x]) | ||
if (length(out) != 1) { | ||
stop_no_call(paste0("Inconsistent indexing found for variable ", x, ".")) | ||
} | ||
out | ||
}) | ||
variable_indices_info <- lapply(var_names, function (x) { | ||
Review comment: You have three instances of *apply over the same structure (
Reply: I could well be missing a clever trick here, but I think this and the previous comment are not quite straightforward. Of these three apply statements, the first applies a function over
Reply: Ah, I think I misunderstood what this was doing. I think this still could be simplified so that the parsing of variable indices only has to be done once (which should be faster and also have lower maintenance cost, since we wouldn't have to modify that parsing in two places for future features). If you move this line:
indices <- sapply(vars_indices[var_i], `[[`, 2)
above its containing if/else block, then I think you could also factor out the vectorized strsplit from the line after it to get a list of indices, upon which you could then use lengths etc. to get elementwise ndims and check for consistency. Then index parsing only happens once, in one place. |
||
indices_info <- named_list(c("ndim", "dimensions", "implied_dimensions", "index_names", "indices", "position")) | ||
indices_info$ndim <- ndim[[x]] | ||
var_i <- vars == x | ||
var_length <- sum(var_i) | ||
if (ndim[x] == 0) { | ||
# single variable, no indices | ||
indices_info$dimensions <- indices_info$implied_dimensions <- NA | ||
indices_info$indices <- 1L | ||
# position in `x`, stored under the documented `position` element | ||
indices_info$position <- which(var_i) | ||
} else { | ||
indices <- sapply(vars_indices[var_i], `[[`, 2) | ||
indices <- as.data.frame(do.call(rbind, strsplit(indices, ",")), | ||
stringsAsFactors = FALSE) | ||
indices_info$dimensions <- unname(apply(indices, 2, function(x){length(unique(x))})) | ||
unique_indices <- vector("list", length(indices)) | ||
.dimnames <- vector("list", length(indices)) | ||
names(unique_indices) <- names(indices) | ||
for (i in seq_along(indices)) { | ||
numeric_index <- suppressWarnings(as.numeric(indices[[i]])) | ||
if (!anyNA(numeric_index) && rlang::is_integerish(numeric_index)) { | ||
# for integer indices, we need to convert them to integers | ||
# so that we can sort them in numerical order (not string order) | ||
indices[[i]] <- as.integer(numeric_index) | ||
if (min(numeric_index) >= 1) { | ||
jsocolar marked this conversation as resolved.
|
||
# integer indices >= 1 are forced to lower bound of 1 + no dimnames | ||
unique_indices[[i]] <- seq.int(1, max(numeric_index)) | ||
} else { | ||
# indices with values < 1 are filled in between the min and max | ||
unique_indices[[i]] <- seq.int(min(numeric_index), max(numeric_index)) | ||
} | ||
} else { | ||
# we convert non-numeric indices to factors so that we can force them | ||
# to be ordered as they appear in the data (rather than in alphabetical order) | ||
factor_levels <- unique(indices[[i]]) | ||
indices[[i]] <- factor(indices[[i]], levels = factor_levels) | ||
# these aren't sorted so they appear in original order | ||
unique_indices[[i]] <- factor(factor_levels, levels = factor_levels) | ||
.dimnames[[i]] <- unique_indices[[i]] | ||
} | ||
} | ||
|
||
indices_info$index_names <- .dimnames | ||
indices_info$implied_dimensions <- unname(lengths(unique_indices)) | ||
|
||
# sort indices and fill in missing indices as NA to ensure | ||
# (1) even if the order of the variables is something weird (like | ||
# x[2,2] comes before x[1,1]) the result | ||
# places those columns in the correct cells of the array | ||
# (2) if some combination of indices is missing (say x[2,1] isn't | ||
# in the input) that cell in the array gets an NA | ||
|
||
# Use expand.grid to get all cells in output array. We reverse indices | ||
# here because it helps us do the sort after the merge, where | ||
# we need to sort in reverse order of the indices (because | ||
# the value of the last index should move slowest) | ||
all_indices <- expand.grid(rev(unique_indices)) | ||
# merge with all.x = TRUE (left join) to fill in missing cells with NA | ||
indices <- merge(all_indices, cbind(indices, index = seq_len(nrow(indices))), | ||
all.x = TRUE, sort = FALSE) | ||
# need to do the sort manually after merge because when sort = TRUE, merge | ||
# sorts factors as if they were strings, and we need factors to be sorted as factors | ||
indices <- indices[do.call(order, as.list(indices[, -ncol(indices), drop = FALSE])),] | ||
Review comment: Even though I wrote this sort originally I'm not actually sure this sort is needed, because I think the |
||
|
||
indices_info$indices <- unname(rev(indices[, names(indices) != "index", drop = FALSE])) | ||
indices_info$position <- which(var_i)[indices$index] | ||
} | ||
indices_info | ||
}) | ||
names(variable_indices_info) <- var_names | ||
variable_indices_info | ||
} |
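The fill-in step near the end of `parse_variable_indices()` (expand all implied cells, left-join the observed ones, then sort so the last index moves slowest) can be tried in isolation. The following is a minimal sketch under stated assumptions, not the PR's code: the input data frame and the column names `V1`/`V2` are hypothetical and stand in for indices already parsed from `c("x[2,2]", "x[1,1]", "x[1,2]")`, with `x[2,1]` deliberately absent.

```r
# Hypothetical parsed indices for c("x[2,2]", "x[1,1]", "x[1,2]").
indices <- data.frame(V1 = c(2L, 1L, 1L), V2 = c(2L, 1L, 2L))

# Values each index position is implied to take (1..max for integers >= 1).
unique_indices <- lapply(indices, function(i) seq.int(1, max(i)))

# Every cell of the implied array. Reversing puts the last index first,
# so it becomes the primary sort key below.
all_indices <- expand.grid(rev(unique_indices))

# Left join: cells that are missing from the input get index = NA.
merged <- merge(all_indices,
                cbind(indices, index = seq_len(nrow(indices))),
                all.x = TRUE, sort = FALSE)

# Sort manually so that the last index (V2) moves slowest.
merged <- merged[do.call(order, as.list(merged[, c("V2", "V1"), drop = FALSE])), ]

# Restore the natural column order.
result <- merged[, c("V1", "V2", "index")]
rownames(result) <- NULL
# result$index is c(2, NA, 3, 1): the x[2,1] cell is NA.
```

The rows come out in column-major order (`x[1,1]`, `x[2,1]`, `x[1,2]`, `x[2,2]`), which is the order R uses when filling an array.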
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,113 @@ | ||
#' "Roll up" `draws_summary` objects by collapsing over nonscalar parameters. | ||
#' | ||
#' By default, all variables with names matched by `\\[.*\\]$` are rolled up, | ||
#' but there is an option to pass a list of parameter names, which will roll up | ||
#' any variables matched by `^parameter_name\\[.*\\]$` | ||
#' | ||
#' @name draws_summary_rollup | ||
#' @param x a `draws_summary` object or a `draws` object to be summarised | ||
#' @param rollup_vars a list of variable names (excluding brackets and indices) to roll up | ||
#' @param min_only a character vector of variable names for which only minimum values are | ||
#' desired in the rollup | ||
#' @param max_only a character vector of variable names for which only maximum values are | ||
#' desired in the rollup | ||
|
||
#' @return | ||
#' The `rollup_summary()` methods return a list of [tibble][tibble::tibble] data frames. | ||
#' The first element is a standard `draws_summary` for the variables that are not rolled up. | ||
#' The second element is a rollup of the variables to be rolled up and contains the max and min | ||
#' values of the summary functions attained by any element of the variable. | ||
#' | ||
#' @details | ||
#' By default, only the maximum value of `rhat` and the minimum values of [ess_bulk()] and | ||
#' [ess_tail()] are returned. # INSERT HOW WE HANDLE NA SUMMARIES | ||
#' | ||
#' @examples | ||
#' ds <- summarise_draws(example_draws()) | ||
#' ds2 <- summarise_draws(2 * example_draws()) | ||
#' ds2$variable <- c("pi", "upsilon", | ||
#' "omega[1,1]", "omega[2,1]", "omega[3,1]", "omega[4,1]", | ||
#' "omega[1,2]", "omega[2,2]", "omega[3,2]", "omega[4,2]") | ||
#' draws_summary <- rbind(ds, ds2) | ||
#' rollup_summary(draws_summary) | ||
#' rollup_summary(draws_summary, rollup_vars = "theta") | ||
#' rollup_summary(example_draws()) | ||
NULL | ||
|
||
#' @rdname draws_summary_rollup | ||
#' @export | ||
rollup_summary <- function(x, rollup_vars = NULL, | ||
min_only = c("ess_bulk", "ess_tail"), | ||
max_only = "rhat") { | ||
UseMethod("rollup_summary") | ||
} | ||
|
||
#' @rdname draws_summary_rollup | ||
#' @export | ||
rollup_summary.default <- function(x, rollup_vars = NULL, | ||
min_only = c("ess_bulk", "ess_tail"), | ||
max_only = "rhat") { | ||
rollup_summary(summarise_draws(x), rollup_vars = rollup_vars, | ||
min_only = min_only, | ||
max_only = max_only) | ||
} | ||
|
||
#' @rdname draws_summary_rollup | ||
#' @export | ||
rollup_summary.draws_summary <- function (x, rollup_vars = NULL, | ||
min_only = c("ess_bulk", "ess_tail"), | ||
max_only = "rhat") { | ||
# get variable names | ||
vars <- x$variable | ||
# Determine which variable names need to be rolled up | ||
if (is.null(rollup_vars)) { | ||
vars_nonscalar <- grepl("\\[", vars) | ||
} else { | ||
vars_nonscalar <- as.logical(colSums(do.call(rbind, | ||
lapply(paste0("^", rollup_vars, "\\["), | ||
function(x){grepl(x, vars)})))) | ||
} | ||
# Separate out draws_summary into the scalar variables to leave alone and the nonscalar | ||
# variables for rollup | ||
ds_scalar <- x[!vars_nonscalar, ] | ||
ds_nonscalar <- x[vars_nonscalar, ] | ||
# Roll up the nonscalar variables | ||
varnames_nonscalar <- gsub("\\[(.*)", "", ds_nonscalar$variable) | ||
summary_names <- names(x)[-1] | ||
names_minmax <- summary_names[!(summary_names %in% c(min_only, max_only))] | ||
split_nonscalar <- split(ds_nonscalar, varnames_nonscalar)[unique(varnames_nonscalar)] | ||
# [unique(varnames_nonscalar)] preserves the order of the names | ||
min_max <- do.call(rbind, lapply(split_nonscalar, rollup_helper_minmax, | ||
names = names_minmax)) | ||
min_only <- do.call(rbind, lapply(split_nonscalar, rollup_helper_min, names = min_only)) | ||
max_only <- do.call(rbind, lapply(split_nonscalar, rollup_helper_max, names = max_only)) | ||
variable_column <- data.frame("variable" = unique(varnames_nonscalar)) | ||
variable_indices <- parse_variable_indices(ds_nonscalar$variable) | ||
dimension_column <- data.frame("dimension" = paste0("(", | ||
sapply(variable_indices, function(x){paste(x$dimensions, collapse = ",")}), | ||
")")) | ||
nonscalar_out <- tibble::as_tibble(cbind(variable_column, dimension_column, min_max, max_only, min_only)) | ||
out <- list(unrolled_vars = ds_scalar, rolled_vars = nonscalar_out) | ||
out | ||
} | ||
|
||
rollup_helper_minmax <- function(x, names){ | ||
x <- x[, names] | ||
mm <- c(apply(x, 2, function(x) {c(min(x), max(x))})) | ||
names(mm) <- paste0(rep(names(x), each = 2), c("_min", "_max")) | ||
mm | ||
} | ||
|
||
rollup_helper_min <- function(x, names){ | ||
x <- x[, names] | ||
min_only <- apply(x, 2, min) | ||
names(min_only) <- paste0(names(x), "_min") | ||
min_only | ||
} | ||
|
||
rollup_helper_max <- function(x, names){ | ||
x <- x[, names] | ||
max_only <- apply(x, 2, max) | ||
names(max_only) <- paste0(names(x), "_max") | ||
max_only | ||
} |
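To make the shape of the rolled-up output concrete, here is a small base-R sketch of the same idea: strip the bracketed suffix, group by base name, and keep the worst case per group. It does not call `rollup_summary()`; the table `smry` and its values are made up for illustration, and the `dimension` string is simply the element count, which is only right for one-dimensional variables.

```r
# Hypothetical summary table; the numbers are invented for illustration.
smry <- data.frame(
  variable = c("mu", "theta[1]", "theta[2]", "theta[3]"),
  rhat     = c(1.00, 1.01, 1.05, 1.02),
  ess_bulk = c(900, 400, 250, 610),
  stringsAsFactors = FALSE
)

# Strip a trailing bracketed index to get the base name; scalars are unchanged.
base_name <- sub("\\[.*\\]$", "", smry$variable)
is_nonscalar <- base_name != smry$variable

# Worst case per variable: largest rhat, smallest ess_bulk.
groups <- split(smry[is_nonscalar, ], base_name[is_nonscalar])
rolled <- do.call(rbind, lapply(names(groups), function(v) {
  g <- groups[[v]]
  data.frame(variable = v,
             dimension = paste0("(", nrow(g), ")"),
             rhat_max = max(g$rhat),
             ess_bulk_min = min(g$ess_bulk),
             stringsAsFactors = FALSE)
}))

# Scalars are left as they are.
unrolled <- smry[!is_nonscalar, ]
```

Here `rolled` has a single row for `theta` with `rhat_max` 1.05 and `ess_bulk_min` 250, while `unrolled` keeps the row for `mu`.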
Review comment: You use a number of regexes in a row here to parse `x`. Can you combine them into a single regex and use capture groups to pull out the relevant pieces? That should be faster, and will also be more flexible if at some point in the future we want to provide people with options to parse different formats of indices (like different separators or whatever).
Review comment: To expand on this: here's a regular expression that matches well-formed indexed variable names. Starting/ending the regex with "^" and "$" guarantees it must match the whole name, so failures imply an ill-formed name:
Then you can check for ill-formed names with something like this:
Variables without indices will have the fourth element of their corresponding entry in `vars_matches` be empty, and variables with indices will have it be the indices, which can then be split.
This regex is fairly strict, and I'm not sure how strict we want to be; e.g., do we want names like `"x]"` to fail, or do we want them just to be parsed as non-indexed variables? I can see an argument for a looser version of this parsing that rarely (or maybe never) fails but instead lets weird-looking names go by and treats them as non-indexed. E.g. something like this:
i.e. in this example `"b["`, `"d[]"`, and `"e]"` all get treated as weird-looking variable names.
This is also a bit more flexible in that it would be pretty easy to allow people to pass in different choices for the opening/closing bracket and update the regex accordingly.
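The code snippets from the comment above are not preserved on this page, so the following is a hedged sketch of the approach it describes rather than the reviewer's actual code. The patterns `strict` and `loose` and the input vector are assumptions. Only the name `vars_matches` and the "fourth element holds the indices" layout come from the comment. This sketch does not reject empty positions such as `"x[,1]"`.

```r
# Hypothetical input mixing well-formed and odd-looking names.
x <- c("a", "b[", "c[2,3]", "d[]", "e]", "f[1]")

# Strict (assumed) pattern: a bracket-free name, optionally followed by one
# non-empty bracketed index group that ends the string.
strict <- "^([^][]+)(\\[([^][]+)\\])?$"
vars_matches <- regmatches(x, regexec(strict, x))

# A name that does not match yields a zero-length entry.
ill_formed <- lengths(vars_matches) == 0

# Element 2 is the base name; element 4 is the raw index string ("" if none).
ok <- vars_matches[!ill_formed]
var_name  <- vapply(ok, `[[`, character(1), 2)
index_str <- vapply(ok, `[[`, character(1), 4)

# Split once; the number of pieces is the number of dimensions.
indices <- strsplit(index_str, ",", fixed = TRUE)
ndim <- lengths(indices)

# Looser (assumed) pattern: anything that does not end in a non-empty
# bracketed group is treated as a plain, non-indexed name instead of failing.
loose <- "^(.*)\\[([^][]+)\\]$"
is_indexed <- grepl(loose, x)
loose_name <- ifelse(is_indexed, sub(loose, "\\1", x), x)
```

With the strict pattern, `"b["`, `"d[]"`, and `"e]"` are flagged in `ill_formed`; with the loose one they pass through unchanged in `loose_name`. Because the indices are split in one place, elementwise `ndim` and any consistency check can reuse the same `indices` list.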