From c29392b6cab8467388099de5320caec3744f6d9e Mon Sep 17 00:00:00 2001 From: Matthew Kay Date: Sat, 18 Nov 2023 11:29:04 -0600 Subject: [PATCH 01/10] variance and sd for complex numbers --- NAMESPACE | 3 +++ R/rvar-summaries-over-draws.R | 14 ++++++++++++++ tests/testthat/test-rvar-summaries-over-draws.R | 7 +++++++ 3 files changed, 24 insertions(+) diff --git a/NAMESPACE b/NAMESPACE index 2d81d65d..1337b316 100755 --- a/NAMESPACE +++ b/NAMESPACE @@ -275,6 +275,7 @@ S3method(rhat_basic,default) S3method(rhat_basic,rvar) S3method(rhat_nested,default) S3method(rhat_nested,rvar) +S3method(sd,complex) S3method(sd,default) S3method(sd,rvar) S3method(split_chains,draws) @@ -297,6 +298,7 @@ S3method(thin_draws,draws) S3method(thin_draws,rvar) S3method(unique,rvar) S3method(unique,rvar_factor) +S3method(var,complex) S3method(var,default) S3method(var,rvar) S3method(variables,"NULL") @@ -305,6 +307,7 @@ S3method(variables,draws_df) S3method(variables,draws_list) S3method(variables,draws_matrix) S3method(variables,draws_rvars) +S3method(variance,complex) S3method(variance,draws_array) S3method(variance,draws_matrix) S3method(variance,rvar) diff --git a/R/rvar-summaries-over-draws.R b/R/rvar-summaries-over-draws.R index b0adb9ff..f73cd2c0 100755 --- a/R/rvar-summaries-over-draws.R +++ b/R/rvar-summaries-over-draws.R @@ -176,6 +176,12 @@ variance.rvar <- function(x, ...) { x, "variance", matrixStats::colVars, useNames = FALSE, .ordered_okay = FALSE, ... ) } +#' @rdname rvar-summaries-over-draws +#' @export +variance.complex <- function(x, ...) { + variance(Re(c(x), ...)) + variance(Im(c(x), ...)) +} + #' @rdname rvar-summaries-over-draws #' @export @@ -185,6 +191,9 @@ var <- function(x, ...) UseMethod("var") var.default <- function(x, ...) stats::var(x, ...) #' @rdname rvar-summaries-over-draws #' @export +var.complex <- variance.complex +#' @rdname rvar-summaries-over-draws +#' @export var.rvar <- variance.rvar #' @rdname rvar-summaries-over-draws @@ -195,6 +204,11 @@ sd <- function(x, ...) UseMethod("sd") sd.default <- function(x, ...) stats::sd(x, ...) #' @rdname rvar-summaries-over-draws #' @export +sd.complex <- function(x, ...) { + sqrt(variance(c(x), ...)) +} +#' @rdname rvar-summaries-over-draws +#' @export sd.rvar <- function(x, ...) { summarise_rvar_by_element_via_matrix( x, "sd", matrixStats::colSds, useNames = FALSE, .ordered_okay = FALSE, ... diff --git a/tests/testthat/test-rvar-summaries-over-draws.R b/tests/testthat/test-rvar-summaries-over-draws.R index 5d1cd62d..25fbda0e 100755 --- a/tests/testthat/test-rvar-summaries-over-draws.R +++ b/tests/testthat/test-rvar-summaries-over-draws.R @@ -89,6 +89,13 @@ test_that("spread functions work", { expect_equal(mad(y), apply(y_array, 2, mad)) }) +test_that("spread functions work on complex numbers", { + x_array = 1:11 + c(10:1,22) * 1i + + expect_equal(var(x_array), 44) + expect_equal(variance(x_array), 44) + expect_equal(sd(x_array), sqrt(44)) +}) # range ------------------------------------------------------------------- From ecb33d02618babf0f5d73904a321fedf07625160 Mon Sep 17 00:00:00 2001 From: Matthew Kay Date: Sat, 18 Nov 2023 12:36:55 -0600 Subject: [PATCH 02/10] basic support for complex numbers in all formats --- R/as_draws.R | 4 ++-- R/misc.R | 5 +++++ R/rvar-cast.R | 4 ++-- tests/testthat/test-as_draws.R | 26 ++++++++++++++++++++++++++ 4 files changed, 35 insertions(+), 4 deletions(-) diff --git a/R/as_draws.R b/R/as_draws.R index c52e75a5..261b077e 100644 --- a/R/as_draws.R +++ b/R/as_draws.R @@ -109,7 +109,7 @@ check_draws_object <- function(x) { #' @noRd check_variables_are_numeric <- function( x, to = "draws_array", - is_non_numeric = function(x_i) !is.numeric(x_i) && !is.logical(x_i), + is_non_numeric = function(x_i) !is.numeric(x_i) && !is.logical(x_i) && !is.complex(x_i), convert = TRUE ) { @@ -145,7 +145,7 @@ validate_draws_per_variable <- function(...) { # '.nchains' is an additional argument in chain supporting formats stop_no_call("'.nchains' is not supported for this format.") } - out <- lapply(out, as.numeric) + out <- lapply(out, as_numeric_or_complex) ndraws_per_variable <- lengths(out) ndraws <- max(ndraws_per_variable) if (!all(ndraws_per_variable %in% c(1, ndraws))) { diff --git a/R/misc.R b/R/misc.R index e5c681d8..68fac39f 100644 --- a/R/misc.R +++ b/R/misc.R @@ -111,6 +111,11 @@ as_one_character <- function(x, allow_na = FALSE) { x } +# coerce 'x' to a numeric or complex vector +as_numeric_or_complex <- function(x) { + if (is.numeric(x) || is.complex(x)) x else as.numeric(x) +} + # check if all inputs are NULL all_null <- function(...) { all(ulapply(list(...), is.null)) diff --git a/R/rvar-cast.R b/R/rvar-cast.R index 11328947..5f11c720 100755 --- a/R/rvar-cast.R +++ b/R/rvar-cast.R @@ -19,7 +19,7 @@ #' While `as_rvar()` attempts to pick the most suitable subtype of [`rvar`] based on the #' type of `x` (possibly returning an [`rvar_factor`] or [`rvar_ordered`]), #' `as_rvar_numeric()`, `as_rvar_integer()`, and `as_rvar_logical()` always coerce -#' the draws of the output [`rvar`] to be [`numeric`], [`integer`], or [`logical`] +#' the draws of the output [`rvar`] to be [`numeric`] (or [`complex`]), [`integer`], or [`logical`] #' (respectively), and always return a base [`rvar`], never a subtype. #' #' @seealso [rvar()] to construct [`rvar`]s directly. See [rdo()], [rfun()], and @@ -83,7 +83,7 @@ as_rvar <- function(x, dim = NULL, dimnames = NULL, nchains = NULL) { #' @export as_rvar_numeric <- function(x, dim = NULL, dimnames = NULL, nchains = NULL) { out <- as_rvar(x, dim = dim, dimnames = dimnames, nchains = nchains) - draws_of(out) <- while_preserving_dims(as.numeric, draws_of(out)) + draws_of(out) <- while_preserving_dims(as_numeric_or_complex, draws_of(out)) out } diff --git a/tests/testthat/test-as_draws.R b/tests/testthat/test-as_draws.R index c016351d..4e3dda4f 100644 --- a/tests/testthat/test-as_draws.R +++ b/tests/testthat/test-as_draws.R @@ -545,3 +545,29 @@ test_that("lossy conversion to formats that don't support discrete variables wor ) } }) + + +# complex variables ------------------------------------------------------- + +test_that("all formats support complex numbers", { + y_array = array(1:24 + 24:1 * 1i, dim = c(2,2,3,2), dimnames = list(NULL)) + z_array = array(1:12 + 12:1 * 1i, dim = c(2,2,3), dimnames = list(NULL)) + draws_rvars <- draws_rvars( + y = rvar(y_array, with_chains = TRUE), + z = rvar(z_array, with_chains = TRUE) + ) + + expect_equal(draws_of(draws_rvars$y, with_chains = TRUE), y_array) + expect_equal(draws_of(draws_rvars$z, with_chains = TRUE), z_array) + + expect_equal(as_draws_rvars(as_draws_matrix(draws_rvars)), draws_rvars) + expect_equal(as_draws_rvars(as_draws_array(draws_rvars)), draws_rvars) + expect_equal(as_draws_rvars(as_draws_df(draws_rvars)), draws_rvars) + expect_equal(as_draws_rvars(as_draws_list(draws_rvars)), draws_rvars) + + expect_equal(unname(draws_matrix(z = c(z_array))[,"z",drop = TRUE]), c(z_array)) + expect_equal(unname(draws_array(z = c(z_array))[,,"z",drop = TRUE]), c(z_array)) + expect_equal(draws_df(z = c(z_array))$z, c(z_array)) + expect_equal(draws_list(z = c(z_array))[[1]]$z, c(z_array)) + expect_equal(draws_rvars(z = c(z_array))$z, rvar(c(z_array))) +}) From 7756de697da1e604fcc575edc8e9a8828a2b0316 Mon Sep 17 00:00:00 2001 From: Matthew Kay Date: Sat, 18 Nov 2023 13:45:16 -0600 Subject: [PATCH 03/10] casting and type predicates for complex rvars --- NAMESPACE | 9 ++++ R/rvar-cast.R | 77 ++++++++++++++++++++++++++++++-- man/as_rvar.Rd | 9 ++-- man/is_rvar_complex.Rd | 20 +++++++++ man/is_rvar_integer.Rd | 20 +++++++++ man/is_rvar_logical.Rd | 20 +++++++++ man/rvar-summaries-over-draws.Rd | 9 ++++ 7 files changed, 157 insertions(+), 7 deletions(-) create mode 100644 man/is_rvar_complex.Rd create mode 100644 man/is_rvar_integer.Rd create mode 100644 man/is_rvar_logical.Rd diff --git a/NAMESPACE b/NAMESPACE index 1337b316..ff163c68 100755 --- a/NAMESPACE +++ b/NAMESPACE @@ -316,6 +316,7 @@ S3method(vec_cast,character.rvar_factor) S3method(vec_cast,character.rvar_ordered) S3method(vec_cast,distribution.rvar) S3method(vec_cast,rvar.character) +S3method(vec_cast,rvar.complex) S3method(vec_cast,rvar.distribution) S3method(vec_cast,rvar.double) S3method(vec_cast,rvar.factor) @@ -326,6 +327,7 @@ S3method(vec_cast,rvar.rvar) S3method(vec_cast,rvar.rvar_factor) S3method(vec_cast,rvar.rvar_ordered) S3method(vec_cast,rvar_factor.character) +S3method(vec_cast,rvar_factor.complex) S3method(vec_cast,rvar_factor.double) S3method(vec_cast,rvar_factor.factor) S3method(vec_cast,rvar_factor.integer) @@ -335,6 +337,7 @@ S3method(vec_cast,rvar_factor.rvar) S3method(vec_cast,rvar_factor.rvar_factor) S3method(vec_cast,rvar_factor.rvar_ordered) S3method(vec_cast,rvar_ordered.character) +S3method(vec_cast,rvar_ordered.complex) S3method(vec_cast,rvar_ordered.double) S3method(vec_cast,rvar_ordered.factor) S3method(vec_cast,rvar_ordered.integer) @@ -352,6 +355,7 @@ S3method(vec_ptype,rvar_factor) S3method(vec_ptype,rvar_ordered) S3method(vec_ptype2,character.rvar_factor) S3method(vec_ptype2,character.rvar_ordered) +S3method(vec_ptype2,complex.rvar) S3method(vec_ptype2,distribution.rvar) S3method(vec_ptype2,double.rvar) S3method(vec_ptype2,factor.rvar_factor) @@ -360,6 +364,7 @@ S3method(vec_ptype2,integer.rvar) S3method(vec_ptype2,logical.rvar) S3method(vec_ptype2,ordered.rvar_factor) S3method(vec_ptype2,ordered.rvar_ordered) +S3method(vec_ptype2,rvar.complex) S3method(vec_ptype2,rvar.distribution) S3method(vec_ptype2,rvar.double) S3method(vec_ptype2,rvar.integer) @@ -399,6 +404,7 @@ export(as_draws_list) export(as_draws_matrix) export(as_draws_rvars) export(as_rvar) +export(as_rvar_complex) export(as_rvar_factor) export(as_rvar_integer) export(as_rvar_logical) @@ -441,7 +447,10 @@ export(is_draws_list) export(is_draws_matrix) export(is_draws_rvars) export(is_rvar) +export(is_rvar_complex) export(is_rvar_factor) +export(is_rvar_integer) +export(is_rvar_logical) export(is_rvar_ordered) export(iteration_ids) export(mad) diff --git a/R/rvar-cast.R b/R/rvar-cast.R index 5f11c720..5901b664 100755 --- a/R/rvar-cast.R +++ b/R/rvar-cast.R @@ -11,15 +11,15 @@ #' @details For objects that are already [`rvar`]s, returns them (with modified dimensions #' if `dim` is not `NULL`). #' -#' For numeric or logical vectors or arrays, returns an [`rvar`] with a single draw and +#' For [`numeric`], [`complex`], or [`logical`] vectors or arrays, returns an [`rvar`] with a single draw and #' the same dimensions as `x`. This is in contrast to the [rvar()] constructor, which #' treats the first dimension of `x` as the draws dimension. As a result, `as_rvar()` #' is useful for creating constants. #' #' While `as_rvar()` attempts to pick the most suitable subtype of [`rvar`] based on the #' type of `x` (possibly returning an [`rvar_factor`] or [`rvar_ordered`]), -#' `as_rvar_numeric()`, `as_rvar_integer()`, and `as_rvar_logical()` always coerce -#' the draws of the output [`rvar`] to be [`numeric`] (or [`complex`]), [`integer`], or [`logical`] +#' `as_rvar_numeric()`, `as_rvar_complex()`, `as_rvar_integer()`, and `as_rvar_logical()` always coerce +#' the draws of the output [`rvar`] to be [`numeric`], [`complex`], [`integer`], or [`logical`] #' (respectively), and always return a base [`rvar`], never a subtype. #' #' @seealso [rvar()] to construct [`rvar`]s directly. See [rdo()], [rfun()], and @@ -83,7 +83,15 @@ as_rvar <- function(x, dim = NULL, dimnames = NULL, nchains = NULL) { #' @export as_rvar_numeric <- function(x, dim = NULL, dimnames = NULL, nchains = NULL) { out <- as_rvar(x, dim = dim, dimnames = dimnames, nchains = nchains) - draws_of(out) <- while_preserving_dims(as_numeric_or_complex, draws_of(out)) + draws_of(out) <- while_preserving_dims(as.numeric, draws_of(out)) + out +} + +#' @rdname as_rvar +#' @export +as_rvar_complex <- function(x, dim = NULL, dimnames = NULL, nchains = NULL) { + out <- as_rvar(x, dim = dim, dimnames = dimnames, nchains = nchains) + draws_of(out) <- while_preserving_dims(as.complex, draws_of(out)) out } @@ -121,6 +129,51 @@ is_rvar <- function(x) { inherits(x, "rvar") } +#' Is `x` a complex random variable? +#' +#' Test if `x` is an [`rvar`] backed by [`complex`] draws. +#' +#' @inheritParams is_rvar +#' +#' @seealso [as_rvar_complex()] to convert objects to `rvar`s backed by [`complex`] draws. +#' +#' @return `TRUE` if `x` is an [`rvar`] backed by [`complex`] draws, `FALSE` otherwise. +#' +#' @export +is_rvar_complex <- function(x) { + is.complex(draws_of(x)) +} + +#' Is `x` an integer random variable? +#' +#' Test if `x` is an [`rvar`] backed by [`integer`] draws. +#' +#' @inheritParams is_rvar +#' +#' @seealso [as_rvar_integer()] to convert objects to `rvar`s backed by [`integer`] draws. +#' +#' @return `TRUE` if `x` is an [`rvar`] backed by [`integer`] draws, `FALSE` otherwise. +#' +#' @export +is_rvar_integer <- function(x) { + is.integer(draws_of(x)) +} + +#' Is `x` a logical random variable? +#' +#' Test if `x` is an [`rvar`] backed by [`logical`] draws. +#' +#' @inheritParams is_rvar +#' +#' @seealso [as_rvar_logical()] to convert objects to `rvar`s backed by [`logical`] draws. +#' +#' @return `TRUE` if `x` is an [`rvar`] backed by [`logical`] draws, `FALSE` otherwise. +#' +#' @export +is_rvar_logical <- function(x) { + is.logical(draws_of(x)) +} + #' @export is.matrix.rvar <- function(x) { length(dim(draws_of(x))) == 3 @@ -384,6 +437,22 @@ vec_cast.rvar_factor.double <- function(x, to, ...) new_constant_rvar(while_pres #' @export vec_cast.rvar_ordered.double <- function(x, to, ...) new_constant_rvar(while_preserving_dims(as.ordered, x)) +# complex -> rvar +#' @export +vec_ptype2.complex.rvar <- function(x, y, ...) new_rvar() +#' @export +vec_ptype2.rvar.complex <- function(x, y, ...) new_rvar() +#' @export +vec_cast.rvar.complex <- function(x, to, ...) new_constant_rvar(x) + +# complex -> rvar_factor +#' @export +vec_cast.rvar_factor.complex <- function(x, to, ...) new_constant_rvar(while_preserving_dims(as.factor, x)) + +# complex -> rvar_ordered +#' @export +vec_cast.rvar_ordered.complex <- function(x, to, ...) new_constant_rvar(while_preserving_dims(as.ordered, x)) + # integer -> rvar #' @export vec_ptype2.integer.rvar <- function(x, y, ...) new_rvar() diff --git a/man/as_rvar.Rd b/man/as_rvar.Rd index b5fd527f..6481fc45 100755 --- a/man/as_rvar.Rd +++ b/man/as_rvar.Rd @@ -3,6 +3,7 @@ \name{as_rvar} \alias{as_rvar} \alias{as_rvar_numeric} +\alias{as_rvar_complex} \alias{as_rvar_integer} \alias{as_rvar_logical} \title{Coerce to a random variable} @@ -11,6 +12,8 @@ as_rvar(x, dim = NULL, dimnames = NULL, nchains = NULL) as_rvar_numeric(x, dim = NULL, dimnames = NULL, nchains = NULL) +as_rvar_complex(x, dim = NULL, dimnames = NULL, nchains = NULL) + as_rvar_integer(x, dim = NULL, dimnames = NULL, nchains = NULL) as_rvar_logical(x, dim = NULL, dimnames = NULL, nchains = NULL) @@ -44,15 +47,15 @@ Convert \code{x} to an \code{\link{rvar}} object. For objects that are already \code{\link{rvar}}s, returns them (with modified dimensions if \code{dim} is not \code{NULL}). -For numeric or logical vectors or arrays, returns an \code{\link{rvar}} with a single draw and +For \code{\link{numeric}}, \code{\link{complex}}, or \code{\link{logical}} vectors or arrays, returns an \code{\link{rvar}} with a single draw and the same dimensions as \code{x}. This is in contrast to the \code{\link[=rvar]{rvar()}} constructor, which treats the first dimension of \code{x} as the draws dimension. As a result, \code{as_rvar()} is useful for creating constants. While \code{as_rvar()} attempts to pick the most suitable subtype of \code{\link{rvar}} based on the type of \code{x} (possibly returning an \code{\link{rvar_factor}} or \code{\link{rvar_ordered}}), -\code{as_rvar_numeric()}, \code{as_rvar_integer()}, and \code{as_rvar_logical()} always coerce -the draws of the output \code{\link{rvar}} to be \code{\link{numeric}}, \code{\link{integer}}, or \code{\link{logical}} +\code{as_rvar_numeric()}, \code{as_rvar_complex()}, \code{as_rvar_integer()}, and \code{as_rvar_logical()} always coerce +the draws of the output \code{\link{rvar}} to be \code{\link{numeric}}, \code{\link{complex}}, \code{\link{integer}}, or \code{\link{logical}} (respectively), and always return a base \code{\link{rvar}}, never a subtype. } \examples{ diff --git a/man/is_rvar_complex.Rd b/man/is_rvar_complex.Rd new file mode 100644 index 00000000..3765ceed --- /dev/null +++ b/man/is_rvar_complex.Rd @@ -0,0 +1,20 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/rvar-cast.R +\name{is_rvar_complex} +\alias{is_rvar_complex} +\title{Is \code{x} a complex random variable?} +\usage{ +is_rvar_complex(x) +} +\arguments{ +\item{x}{(any object) An object to test.} +} +\value{ +\code{TRUE} if \code{x} is an \code{\link{rvar}} backed by \code{\link{complex}} draws, \code{FALSE} otherwise. +} +\description{ +Test if \code{x} is an \code{\link{rvar}} backed by \code{\link{complex}} draws. +} +\seealso{ +\code{\link[=as_rvar_complex]{as_rvar_complex()}} to convert objects to \code{rvar}s backed by \code{\link{complex}} draws. +} diff --git a/man/is_rvar_integer.Rd b/man/is_rvar_integer.Rd new file mode 100644 index 00000000..4069ecd4 --- /dev/null +++ b/man/is_rvar_integer.Rd @@ -0,0 +1,20 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/rvar-cast.R +\name{is_rvar_integer} +\alias{is_rvar_integer} +\title{Is \code{x} an integer random variable?} +\usage{ +is_rvar_integer(x) +} +\arguments{ +\item{x}{(any object) An object to test.} +} +\value{ +\code{TRUE} if \code{x} is an \code{\link{rvar}} backed by \code{\link{integer}} draws, \code{FALSE} otherwise. +} +\description{ +Test if \code{x} is an \code{\link{rvar}} backed by \code{\link{integer}} draws. +} +\seealso{ +\code{\link[=as_rvar_integer]{as_rvar_integer()}} to convert objects to \code{rvar}s backed by \code{\link{integer}} draws. +} diff --git a/man/is_rvar_logical.Rd b/man/is_rvar_logical.Rd new file mode 100644 index 00000000..f20bc37a --- /dev/null +++ b/man/is_rvar_logical.Rd @@ -0,0 +1,20 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/rvar-cast.R +\name{is_rvar_logical} +\alias{is_rvar_logical} +\title{Is \code{x} a logical random variable?} +\usage{ +is_rvar_logical(x) +} +\arguments{ +\item{x}{(any object) An object to test.} +} +\value{ +\code{TRUE} if \code{x} is an \code{\link{rvar}} backed by \code{\link{logical}} draws, \code{FALSE} otherwise. +} +\description{ +Test if \code{x} is an \code{\link{rvar}} backed by \code{\link{logical}} draws. +} +\seealso{ +\code{\link[=as_rvar_logical]{as_rvar_logical()}} to convert objects to \code{rvar}s backed by \code{\link{logical}} draws. +} diff --git a/man/rvar-summaries-over-draws.Rd b/man/rvar-summaries-over-draws.Rd index 8b7c8caa..50937479 100755 --- a/man/rvar-summaries-over-draws.Rd +++ b/man/rvar-summaries-over-draws.Rd @@ -17,11 +17,14 @@ \alias{any.rvar} \alias{Summary.rvar} \alias{variance.rvar} +\alias{variance.complex} \alias{var} \alias{var.default} +\alias{var.complex} \alias{var.rvar} \alias{sd} \alias{sd.default} +\alias{sd.complex} \alias{sd.rvar} \alias{mad} \alias{mad.default} @@ -64,16 +67,22 @@ Pr(x, ...) \method{variance}{rvar}(x, ...) +\method{variance}{complex}(x, ...) + var(x, ...) \method{var}{default}(x, ...) +\method{var}{complex}(x, ...) + \method{var}{rvar}(x, ...) sd(x, ...) \method{sd}{default}(x, ...) +\method{sd}{complex}(x, ...) + \method{sd}{rvar}(x, ...) mad(x, ...) From 78a1d673016b0cd0c76432b2c48c67f06005b02d Mon Sep 17 00:00:00 2001 From: Matthew Kay Date: Sun, 19 Nov 2023 12:10:58 -0600 Subject: [PATCH 04/10] allow rvar summary functions to work with complex numbers --- R/rvar-.R | 23 +++++++++++++++---- R/rvar-dist.R | 9 ++++---- R/rvar-summaries-over-draws.R | 40 ++++++++++++++++++--------------- R/rvar-summaries-within-draws.R | 29 ++++++++++++------------ 4 files changed, 61 insertions(+), 40 deletions(-) diff --git a/R/rvar-.R b/R/rvar-.R index 129f357c..6b23946f 100755 --- a/R/rvar-.R +++ b/R/rvar-.R @@ -948,12 +948,16 @@ summarise_rvar_within_draws <- function(x, .f, ..., .transpose = FALSE, .when_em #' by first collapsing dimensions into columns of the draws matrix #' (so that .f can be a rowXXX() function) #' @param x an rvar -#' @param name function name to use for error messages +#' @param .name function name to use for error messages #' @param .f a function that takes a matrix and summarises its rows, like rowMeans #' @param ... arguments passed to `.f` #' @param .ordered_okay can this function be applied to rvar_ordereds? #' @noRd summarise_rvar_within_draws_via_matrix <- function(x, .name, .f, ..., .ordered_okay = FALSE) { + if (is_rvar_complex(x)) { + return(summarise_rvar_within_draws(x, match.fun(.name), ...)) + } + .length <- length(x) if (!.length) { x <- rvar() @@ -966,7 +970,7 @@ summarise_rvar_within_draws_via_matrix <- function(x, .name, .f, ..., .ordered_o .draws <- .f(draws_of(as_rvar_numeric(x)), ...) .draws <- while_preserving_dims(function(.draws) ordered(.levels[round(.draws)], .levels), .draws) } else if (is_rvar_factor(x)) { - stop_no_call("Cannot apply `", .name, "` function to rvar_factor objects.") + stop_no_call("Cannot apply `rvar_", .name, "` function to rvar_factor objects.") } else { .draws <- .f(draws_of(x), ...) } @@ -997,18 +1001,29 @@ summarise_rvar_by_element <- function(x, .f, ...) { #' by first collapsing dimensions into columns of the draws matrix, applying the #' function, then restoring dimensions (so that .f can be a colXXX() function) #' @param x an rvar -#' @param name function name to use for error messages +#' @param .name function name to use for error messages, and also function to +#' be used as a backup for complex numbers #' @param .f a function that takes a matrix and summarises its columns, like colMeans #' @param .extra_dim extra dims added by `.f` to the output, e.g. in the case of #' matrixStats::colRanges this is `2` #' @param .extra_dimnames extra dimension names for dims added by `.f` to the output #' @param .ordered_okay can this function be applied to rvar_ordereds? #' @param .factor_okay can this function be applied to rvar_factors? +#' @param .complex_okay can this function be applied to complex rvars? If not, +#' the function match.fun(.name) will be used instead, element-by-element. #' @param ... arguments passed to `.f` #' @noRd summarise_rvar_by_element_via_matrix <- function( - x, .name, .f, .extra_dim = NULL, .extra_dimnames = NULL, .ordered_okay = TRUE, .factor_okay = FALSE, ... + x, .name, .f, + .extra_dim = NULL, .extra_dimnames = NULL, + .ordered_okay = TRUE, .factor_okay = FALSE, + .complex_okay = FALSE, + ... ) { + if (is_rvar_complex(x) && !.complex_okay) { + return(summarise_rvar_by_element(x, match.fun(.name), ...)) + } + .dim <- dim(x) .dimnames <- dimnames(x) .length <- length(x) diff --git a/R/rvar-dist.R b/R/rvar-dist.R index ee57c162..5942953c 100755 --- a/R/rvar-dist.R +++ b/R/rvar-dist.R @@ -93,11 +93,12 @@ cdf.rvar_ordered <- function(x, q, ...) { quantile.rvar <- function(x, probs, ...) { summarise_rvar_by_element_via_matrix(x, "quantile", - function(draws) { - t(matrixStats::colQuantiles(draws, probs = probs, useNames = TRUE, ...)) - }, + function(..., names) t(matrixStats::colQuantiles(..., useNames = FALSE)), .extra_dim = length(probs), - .extra_dimnames = list(NULL) + .extra_dimnames = list(NULL), + probs = probs, + names = FALSE, + ... ) } diff --git a/R/rvar-summaries-over-draws.R b/R/rvar-summaries-over-draws.R index f73cd2c0..7c7dc682 100755 --- a/R/rvar-summaries-over-draws.R +++ b/R/rvar-summaries-over-draws.R @@ -68,7 +68,7 @@ E <- function(x, ...) { #' @export mean.rvar <- function(x, ...) { summarise_rvar_by_element_via_matrix( - x, "mean", matrixStats::colMeans2, useNames = FALSE, .ordered_okay = FALSE, ... + x, "mean", function(...) matrixStats::colMeans2(..., useNames = FALSE), .ordered_okay = FALSE, ... ) } @@ -101,7 +101,7 @@ Pr.rvar <- function(x, ...) { #' @export median.rvar <- function(x, ...) { summarise_rvar_by_element_via_matrix( - x, "median", matrixStats::colMedians, useNames = FALSE, ... + x, "median", function(...) matrixStats::colMedians(..., useNames = FALSE), ... ) } @@ -109,7 +109,7 @@ median.rvar <- function(x, ...) { #' @export min.rvar <- function(x, ...) { summarise_rvar_by_element_via_matrix( - x, "min", matrixStats::colMins, useNames = FALSE, ... + x, "min", function(...) matrixStats::colMins(..., useNames = FALSE), ... ) } @@ -117,7 +117,7 @@ min.rvar <- function(x, ...) { #' @export max.rvar <- function(x, ...) { summarise_rvar_by_element_via_matrix( - x, "max", matrixStats::colMaxs, useNames = FALSE, ... + x, "max", function(...) matrixStats::colMaxs(..., useNames = FALSE), ... ) } @@ -125,7 +125,7 @@ max.rvar <- function(x, ...) { #' @export sum.rvar <- function(x, ...) { summarise_rvar_by_element_via_matrix( - x, "sum", matrixStats::colSums2, useNames = FALSE, .ordered_okay = FALSE, ... + x, "sum", function(...) matrixStats::colSums2(..., useNames = FALSE), .ordered_okay = FALSE, ... ) } @@ -133,7 +133,7 @@ sum.rvar <- function(x, ...) { #' @export prod.rvar <- function(x, ...) { summarise_rvar_by_element_via_matrix( - x, "prod", matrixStats::colProds, useNames = FALSE, .ordered_okay = FALSE, ... + x, "prod", function(...) matrixStats::colProds(..., useNames = FALSE), .ordered_okay = FALSE, ... ) } @@ -141,7 +141,7 @@ prod.rvar <- function(x, ...) { #' @export all.rvar <- function(x, ...) { summarise_rvar_by_element_via_matrix( - x, "all", matrixStats::colAlls, useNames = FALSE, .ordered_okay = FALSE, ... + x, "all", function(...) matrixStats::colAlls(..., useNames = FALSE), .ordered_okay = FALSE, ... ) } @@ -149,7 +149,7 @@ all.rvar <- function(x, ...) { #' @export any.rvar <- function(x, ...) { summarise_rvar_by_element_via_matrix( - x, "any", matrixStats::colAnys, useNames = FALSE, .ordered_okay = FALSE, ... + x, "any", function(...) matrixStats::colAnys(..., useNames = FALSE), .ordered_okay = FALSE, ... ) } @@ -173,13 +173,13 @@ distributional::variance #' @export variance.rvar <- function(x, ...) { summarise_rvar_by_element_via_matrix( - x, "variance", matrixStats::colVars, useNames = FALSE, .ordered_okay = FALSE, ... + x, "variance", function(...) matrixStats::colVars(..., useNames = FALSE), .ordered_okay = FALSE, ... ) } #' @rdname rvar-summaries-over-draws #' @export variance.complex <- function(x, ...) { - variance(Re(c(x), ...)) + variance(Im(c(x), ...)) + variance(Re(c(x)), ...) + variance(Im(c(x)), ...) } @@ -211,7 +211,7 @@ sd.complex <- function(x, ...) { #' @export sd.rvar <- function(x, ...) { summarise_rvar_by_element_via_matrix( - x, "sd", matrixStats::colSds, useNames = FALSE, .ordered_okay = FALSE, ... + x, "sd", function(...) matrixStats::colSds(..., useNames = FALSE), .ordered_okay = FALSE, ... ) } @@ -225,7 +225,7 @@ mad.default <- function(x, ...) stats::mad(x, ...) #' @export mad.rvar <- function(x, ...) { summarise_rvar_by_element_via_matrix( - x, "mad", matrixStats::colMads, useNames = FALSE, .ordered_okay = FALSE, ... + x, "mad", function(...) matrixStats::colMads(..., useNames = FALSE), .ordered_okay = FALSE, ... ) } #' @rdname rvar-summaries-over-draws @@ -241,8 +241,8 @@ mad.rvar_ordered <- function(x, ...) { #' @export range.rvar <- function(x, ...) { summarise_rvar_by_element_via_matrix( - x, "range", function(...) t(matrixStats::colRanges(...)), - useNames = FALSE, .extra_dim = 2, .extra_dimnames = list(NULL), ... + x, "range", function(...) t(matrixStats::colRanges(..., useNames = FALSE)), + .extra_dim = 2, .extra_dimnames = list(NULL), ... ) } @@ -253,7 +253,8 @@ range.rvar <- function(x, ...) { #' @export is.finite.rvar <- function(x) { summarise_rvar_by_element_via_matrix( - x, "is.finite", function(x) matrixStats::colAlls(is.finite(x), useNames = FALSE), .factor_okay = TRUE + x, "is.finite", function(x) matrixStats::colAlls(is.finite(x), useNames = FALSE), + .factor_okay = TRUE, .complex_okay = TRUE ) } @@ -261,7 +262,8 @@ is.finite.rvar <- function(x) { #' @export is.infinite.rvar <- function(x) { summarise_rvar_by_element_via_matrix( - x, "is.inifite", function(x) matrixStats::colAnys(is.infinite(x), useNames = FALSE), .factor_okay = TRUE + x, "is.infinite", function(x) matrixStats::colAnys(is.infinite(x), useNames = FALSE), + .factor_okay = TRUE, .complex_okay = TRUE ) } @@ -269,7 +271,8 @@ is.infinite.rvar <- function(x) { #' @export is.nan.rvar <- function(x) { summarise_rvar_by_element_via_matrix( - x, "is.nan", function(x) matrixStats::colAnys(is.nan(x), useNames = FALSE), .factor_okay = TRUE + x, "is.nan", function(x) matrixStats::colAnys(is.nan(x), useNames = FALSE), + .factor_okay = TRUE, .complex_okay = TRUE ) } @@ -277,7 +280,8 @@ is.nan.rvar <- function(x) { #' @export is.na.rvar <- function(x) { summarise_rvar_by_element_via_matrix( - x, "is.na", matrixStats::colAnyNAs, useNames = FALSE, .factor_okay = TRUE + x, "is.na", matrixStats::colAnyNAs, useNames = FALSE, + .factor_okay = TRUE, .complex_okay = TRUE ) } diff --git a/R/rvar-summaries-within-draws.R b/R/rvar-summaries-within-draws.R index 75388455..e511091f 100755 --- a/R/rvar-summaries-within-draws.R +++ b/R/rvar-summaries-within-draws.R @@ -50,7 +50,7 @@ #' @export rvar_mean <- function(..., na.rm = FALSE) { summarise_rvar_within_draws_via_matrix( - c(...), "rvar_mean", matrixStats::rowMeans2, na.rm = na.rm + c(...), "mean", matrixStats::rowMeans2, na.rm = na.rm ) } @@ -60,7 +60,7 @@ rvar_mean <- function(..., na.rm = FALSE) { #' @export rvar_median <- function(..., na.rm = FALSE) { summarise_rvar_within_draws_via_matrix( - c(...), "rvar_median", matrixStats::rowMedians, na.rm = na.rm, .ordered_okay = TRUE + c(...), "median", matrixStats::rowMedians, na.rm = na.rm, .ordered_okay = TRUE ) } @@ -68,7 +68,7 @@ rvar_median <- function(..., na.rm = FALSE) { #' @export rvar_sum <- function(..., na.rm = FALSE) { summarise_rvar_within_draws_via_matrix( - c(...), "rvar_sum", matrixStats::rowSums2, na.rm = na.rm + c(...), "sum", matrixStats::rowSums2, na.rm = na.rm ) } @@ -76,7 +76,7 @@ rvar_sum <- function(..., na.rm = FALSE) { #' @export rvar_prod <- function(..., na.rm = FALSE) { summarise_rvar_within_draws_via_matrix( - c(...), "rvar_prod", matrixStats::rowProds, na.rm = na.rm + c(...), "prod", matrixStats::rowProds, na.rm = na.rm ) } @@ -84,7 +84,7 @@ rvar_prod <- function(..., na.rm = FALSE) { #' @export rvar_min <- function(..., na.rm = FALSE) { summarise_rvar_within_draws_via_matrix( - c(...), "rvar_min", matrixStats::rowMins, na.rm = na.rm, .ordered_okay = TRUE + c(...), "min", matrixStats::rowMins, na.rm = na.rm, .ordered_okay = TRUE ) } @@ -92,7 +92,7 @@ rvar_min <- function(..., na.rm = FALSE) { #' @export rvar_max <- function(..., na.rm = FALSE) { summarise_rvar_within_draws_via_matrix( - c(...), "rvar_max", matrixStats::rowMaxs, na.rm = na.rm, .ordered_okay = TRUE + c(...), "max", matrixStats::rowMaxs, na.rm = na.rm, .ordered_okay = TRUE ) } @@ -103,7 +103,7 @@ rvar_max <- function(..., na.rm = FALSE) { #' @export rvar_sd <- function(..., na.rm = FALSE) { summarise_rvar_within_draws_via_matrix( - c(...), "rvar_sd", matrixStats::rowSds, na.rm = na.rm + c(...), "sd", matrixStats::rowSds, na.rm = na.rm ) } @@ -111,7 +111,7 @@ rvar_sd <- function(..., na.rm = FALSE) { #' @export rvar_var <- function(..., na.rm = FALSE) { summarise_rvar_within_draws_via_matrix( - c(...), "rvar_var", matrixStats::rowVars, na.rm = na.rm + c(...), "var", matrixStats::rowVars, na.rm = na.rm ) } @@ -123,7 +123,7 @@ rvar_mad <- function(..., constant = 1.4826, na.rm = FALSE) { x <- as_rvar_numeric(x) } summarise_rvar_within_draws_via_matrix( - x, "rvar_mad", matrixStats::rowMads, constant = constant, na.rm = na.rm + x, "mad", matrixStats::rowMads, constant = constant, na.rm = na.rm ) } @@ -134,7 +134,7 @@ rvar_mad <- function(..., constant = 1.4826, na.rm = FALSE) { #' @export rvar_range <- function(..., na.rm = FALSE) { summarise_rvar_within_draws_via_matrix( - c(...), "rvar_range", matrixStats::rowRanges, na.rm = na.rm, .ordered_okay = TRUE + c(...), "range", matrixStats::rowRanges, na.rm = na.rm, .ordered_okay = TRUE ) } @@ -154,8 +154,9 @@ rvar_quantile <- function(..., probs, names = FALSE, na.rm = FALSE) { type <- if (is_rvar_ordered(x)) 1 else 7 out <- summarise_rvar_within_draws_via_matrix( - x, "rvar_quantile", matrixStats::rowQuantiles, probs = probs, type = type, - na.rm = na.rm, drop = FALSE, .ordered_okay = TRUE + x, "quantile", function(...) matrixStats::rowQuantiles(..., drop = FALSE), + probs = probs, type = type, na.rm = na.rm, + .ordered_okay = TRUE ) if (!names) { @@ -172,7 +173,7 @@ rvar_quantile <- function(..., probs, names = FALSE, na.rm = FALSE) { #' @export rvar_all <- function(..., na.rm = FALSE) { summarise_rvar_within_draws_via_matrix( - c(...), "rvar_all", matrixStats::rowAlls, na.rm = na.rm + c(...), "all", matrixStats::rowAlls, na.rm = na.rm ) } @@ -180,7 +181,7 @@ rvar_all <- function(..., na.rm = FALSE) { #' @export rvar_any <- function(..., na.rm = FALSE) { summarise_rvar_within_draws_via_matrix( - c(...), "rvar_any", matrixStats::rowAnys, na.rm = na.rm + c(...), "any", matrixStats::rowAnys, na.rm = na.rm ) } From bcd2039f0d27e39dc13a468ca783a283e8df9d64 Mon Sep 17 00:00:00 2001 From: Matthew Kay Date: Sun, 19 Nov 2023 12:21:29 -0600 Subject: [PATCH 05/10] add complex number operators to rvar --- NAMESPACE | 1 + R/rvar-math.R | 6 ++++++ tests/testthat/test-rvar-math.R | 11 +++++++++++ vignettes/rvar.Rmd | 1 + 4 files changed, 19 insertions(+) diff --git a/NAMESPACE b/NAMESPACE index ff163c68..0679d0aa 100755 --- a/NAMESPACE +++ b/NAMESPACE @@ -23,6 +23,7 @@ S3method(.subset_draws,draws_df) S3method(.subset_draws,draws_list) S3method(.subset_draws,draws_matrix) S3method(.subset_draws,draws_rvars) +S3method(Complex,rvar) S3method(Math,rvar) S3method(Math,rvar_factor) S3method(Ops,rvar) diff --git a/R/rvar-math.R b/R/rvar-math.R index 11a8ca3c..a0fe453f 100755 --- a/R/rvar-math.R +++ b/R/rvar-math.R @@ -106,6 +106,12 @@ Math.rvar_factor <- function(x, ...) { stop_no_call("Cannot apply `", .Generic, "` function to rvar_factor objects.") } +#' @export +Complex.rvar <- function(z) { + f <- get(.Generic) + rvar_apply_vec_fun(f, z) +} + # matrix stuff --------------------------------------------------- #' Matrix multiplication of random variables diff --git a/tests/testthat/test-rvar-math.R b/tests/testthat/test-rvar-math.R index 597b7805..1a9d848e 100755 --- a/tests/testthat/test-rvar-math.R +++ b/tests/testthat/test-rvar-math.R @@ -202,6 +202,17 @@ test_that("cumulative functions work", { expect_equal(cummin(x), cummin_ref) }) +test_that("complex number operators work", { + x_array <- array(1:11 + 11:1 * 1i, dim = c(2,2,3)) + x <- rvar(x_array) + + expect_equal(Arg(x), new_rvar(Arg(x_array))) + expect_equal(Conj(x), new_rvar(Conj(x_array))) + expect_equal(Mod(x), new_rvar(Mod(x_array))) + expect_equal(Im(x), new_rvar(Im(x_array))) + expect_equal(Re(x), new_rvar(Re(x_array))) +}) + # matrix stuff ------------------------------------------------------------ test_that("matrix multiplication works", { diff --git a/vignettes/rvar.Rmd b/vignettes/rvar.Rmd index fe61835a..5664bbbf 100755 --- a/vignettes/rvar.Rmd +++ b/vignettes/rvar.Rmd @@ -237,6 +237,7 @@ includes: | Array transposition | `t()`, `aperm()` | | Matrix decomposition | `chol()` | | Matrix diagonals | `diag()` | +| Complex numbers | `Arg()`, `Conj()`, `Im()`, `Mod()`, `Re()` | ## Expectations and summary functions From 254893ecb36e030ef56d21f7f0dc7f71077758f7 Mon Sep 17 00:00:00 2001 From: Matthew Kay Date: Sun, 19 Nov 2023 13:03:05 -0600 Subject: [PATCH 06/10] ensure cumulative functions work on rvar of length 1 --- R/rvar-math.R | 8 ++++++-- tests/testthat/test-rvar-math.R | 4 ++++ 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/R/rvar-math.R b/R/rvar-math.R index a0fe453f..762bc132 100755 --- a/R/rvar-math.R +++ b/R/rvar-math.R @@ -95,10 +95,14 @@ Math.rvar <- function(x, ...) { if (.Generic %in% c("cumsum", "cumprod", "cummax", "cummin")) { # cumulative functions need to be handled differently # from other functions in this generic - new_rvar(t(apply(draws_of(x), 1, f)), .nchains = nchains(x)) + if (length(x) > 1) { + draws_of(x) <- t(apply(draws_of(x), 1, f)) + } } else { - new_rvar(f(draws_of(x), ...), .nchains = nchains(x)) + draws_of(x) <- f(draws_of(x), ...) } + + x } #' @export diff --git a/tests/testthat/test-rvar-math.R b/tests/testthat/test-rvar-math.R index 1a9d848e..5c4e6916 100755 --- a/tests/testthat/test-rvar-math.R +++ b/tests/testthat/test-rvar-math.R @@ -182,24 +182,28 @@ test_that("cumulative functions work", { cumsum(draws_of(x)[2,,]) )) expect_equal(cumsum(x), cumsum_ref) + expect_equal(cumsum(x[[1]]), x[[1]]) cumprod_ref = new_rvar(rbind( cumprod(draws_of(x)[1,,]), cumprod(draws_of(x)[2,,]) )) expect_equal(cumprod(x), cumprod_ref) + expect_equal(cumprod(x[[1]]), x[[1]]) cummax_ref = new_rvar(rbind( cummax(draws_of(x)[1,,]), cummax(draws_of(x)[2,,]) )) expect_equal(cummax(x), cummax_ref) + expect_equal(cummax(x[[1]]), x[[1]]) cummin_ref = new_rvar(rbind( cummin(draws_of(x)[1,,]), cummin(draws_of(x)[2,,]) )) expect_equal(cummin(x), cummin_ref) + expect_equal(cummin(x[[1]]), x[[1]]) }) test_that("complex number operators work", { From 30442fa8ce9a68a10ff3a579adb65f4ad53f8d8b Mon Sep 17 00:00:00 2001 From: Matthew Kay Date: Sun, 19 Nov 2023 20:34:27 -0600 Subject: [PATCH 07/10] guards and tests for operations on complex rvars --- R/rvar-.R | 8 ++++ R/rvar-dist.R | 3 ++ R/rvar-summaries-over-draws.R | 9 ++++- R/rvar-summaries-within-draws.R | 26 ++++++++++--- tests/testthat/test-discrete-summaries.R | 4 ++ tests/testthat/test-rvar-cast.R | 13 ++++++- tests/testthat/test-rvar-dist.R | 8 ++++ tests/testthat/test-rvar-math.R | 17 +++++++++ .../testthat/test-rvar-summaries-over-draws.R | 37 ++++++++++++++++++- .../test-rvar-summaries-within-draws.R | 28 ++++++++++++++ 10 files changed, 143 insertions(+), 10 deletions(-) diff --git a/R/rvar-.R b/R/rvar-.R index 6b23946f..6783893c 100755 --- a/R/rvar-.R +++ b/R/rvar-.R @@ -474,6 +474,14 @@ setOldClass(get_rvar_class(ordered(NULL))) # helpers: validation ----------------------------------------------------------------- +# check the given rvar is not complex +check_rvar_not_complex <- function(x, f = NULL) { + if (is_rvar_complex(x)) { + f <- if (is.null(f)) "" else paste0("`", f, "` ") + stop_no_call("Cannot apply ", f, "function to complex rvars.") + } +} + # Check the passed yank index (for x[[...]]) is valid check_rvar_yank_index = function(x, i, ...) { index <- dots_list(i, ..., .preserve_empty = TRUE, .ignore_empty = "none") diff --git a/R/rvar-dist.R b/R/rvar-dist.R index 5942953c..310e36eb 100755 --- a/R/rvar-dist.R +++ b/R/rvar-dist.R @@ -40,6 +40,7 @@ #' @name rvar-dist #' @export density.rvar <- function(x, at, ...) { + check_rvar_not_complex(x, "density") summarise_rvar_by_element(x, function(draws) { d <- density(draws, cut = 0, ...) f <- approxfun(d$x, d$y, yleft = 0, yright = 0) @@ -66,6 +67,7 @@ distributional::cdf #' @rdname rvar-dist #' @export cdf.rvar <- function(x, q, ...) { + check_rvar_not_complex(x, "cdf") summarise_rvar_by_element(x, function(draws) { ecdf(draws)(q) }) @@ -91,6 +93,7 @@ cdf.rvar_ordered <- function(x, q, ...) { #' @rdname rvar-dist #' @export quantile.rvar <- function(x, probs, ...) { + check_rvar_not_complex(x, "quantile") summarise_rvar_by_element_via_matrix(x, "quantile", function(..., names) t(matrixStats::colQuantiles(..., useNames = FALSE)), diff --git a/R/rvar-summaries-over-draws.R b/R/rvar-summaries-over-draws.R index 7c7dc682..54b7a785 100755 --- a/R/rvar-summaries-over-draws.R +++ b/R/rvar-summaries-over-draws.R @@ -88,7 +88,7 @@ Pr.logical <- function(x, ...) { #' @rdname rvar-summaries-over-draws #' @export Pr.rvar <- function(x, ...) { - if (!is.logical(draws_of(x))) { + if (!is_rvar_logical(x)) { stop_no_call("Can only use `Pr()` on logical random variables.") } mean(x, ...) @@ -100,6 +100,7 @@ Pr.rvar <- function(x, ...) { #' @rdname rvar-summaries-over-draws #' @export median.rvar <- function(x, ...) { + check_rvar_not_complex(x, "median") summarise_rvar_by_element_via_matrix( x, "median", function(...) matrixStats::colMedians(..., useNames = FALSE), ... ) @@ -108,6 +109,7 @@ median.rvar <- function(x, ...) { #' @rdname rvar-summaries-over-draws #' @export min.rvar <- function(x, ...) { + check_rvar_not_complex(x, "min") summarise_rvar_by_element_via_matrix( x, "min", function(...) matrixStats::colMins(..., useNames = FALSE), ... ) @@ -116,6 +118,7 @@ min.rvar <- function(x, ...) { #' @rdname rvar-summaries-over-draws #' @export max.rvar <- function(x, ...) { + check_rvar_not_complex(x, "max") summarise_rvar_by_element_via_matrix( x, "max", function(...) matrixStats::colMaxs(..., useNames = FALSE), ... ) @@ -140,6 +143,7 @@ prod.rvar <- function(x, ...) { #' @rdname rvar-summaries-over-draws #' @export all.rvar <- function(x, ...) { + check_rvar_not_complex(x, "all") summarise_rvar_by_element_via_matrix( x, "all", function(...) matrixStats::colAlls(..., useNames = FALSE), .ordered_okay = FALSE, ... ) @@ -148,6 +152,7 @@ all.rvar <- function(x, ...) { #' @rdname rvar-summaries-over-draws #' @export any.rvar <- function(x, ...) { + check_rvar_not_complex(x, "any") summarise_rvar_by_element_via_matrix( x, "any", function(...) matrixStats::colAnys(..., useNames = FALSE), .ordered_okay = FALSE, ... ) @@ -224,6 +229,7 @@ mad.default <- function(x, ...) stats::mad(x, ...) #' @rdname rvar-summaries-over-draws #' @export mad.rvar <- function(x, ...) { + check_rvar_not_complex(x, "mad") summarise_rvar_by_element_via_matrix( x, "mad", function(...) matrixStats::colMads(..., useNames = FALSE), .ordered_okay = FALSE, ... ) @@ -240,6 +246,7 @@ mad.rvar_ordered <- function(x, ...) { #' @rdname rvar-summaries-over-draws #' @export range.rvar <- function(x, ...) { + check_rvar_not_complex(x, "range") summarise_rvar_by_element_via_matrix( x, "range", function(...) t(matrixStats::colRanges(..., useNames = FALSE)), .extra_dim = 2, .extra_dimnames = list(NULL), ... diff --git a/R/rvar-summaries-within-draws.R b/R/rvar-summaries-within-draws.R index e511091f..40057df3 100755 --- a/R/rvar-summaries-within-draws.R +++ b/R/rvar-summaries-within-draws.R @@ -59,8 +59,10 @@ rvar_mean <- function(..., na.rm = FALSE) { #' @rdname rvar-summaries-within-draws #' @export rvar_median <- function(..., na.rm = FALSE) { + x <- c(...) + check_rvar_not_complex(x, "rvar_median") summarise_rvar_within_draws_via_matrix( - c(...), "median", matrixStats::rowMedians, na.rm = na.rm, .ordered_okay = TRUE + x, "median", matrixStats::rowMedians, na.rm = na.rm, .ordered_okay = TRUE ) } @@ -83,16 +85,20 @@ rvar_prod <- function(..., na.rm = FALSE) { #' @rdname rvar-summaries-within-draws #' @export rvar_min <- function(..., na.rm = FALSE) { + x <- c(...) + check_rvar_not_complex(x, "rvar_min") summarise_rvar_within_draws_via_matrix( - c(...), "min", matrixStats::rowMins, na.rm = na.rm, .ordered_okay = TRUE + x, "min", matrixStats::rowMins, na.rm = na.rm, .ordered_okay = TRUE ) } #' @rdname rvar-summaries-within-draws #' @export rvar_max <- function(..., na.rm = FALSE) { + x <- c(...) + check_rvar_not_complex(x, "rvar_max") summarise_rvar_within_draws_via_matrix( - c(...), "max", matrixStats::rowMaxs, na.rm = na.rm, .ordered_okay = TRUE + x, "max", matrixStats::rowMaxs, na.rm = na.rm, .ordered_okay = TRUE ) } @@ -122,6 +128,7 @@ rvar_mad <- function(..., constant = 1.4826, na.rm = FALSE) { if (is_rvar_ordered(x)) { x <- as_rvar_numeric(x) } + check_rvar_not_complex(x, "rvar_mad") summarise_rvar_within_draws_via_matrix( x, "mad", matrixStats::rowMads, constant = constant, na.rm = na.rm ) @@ -133,8 +140,10 @@ rvar_mad <- function(..., constant = 1.4826, na.rm = FALSE) { #' @rdname rvar-summaries-within-draws #' @export rvar_range <- function(..., na.rm = FALSE) { + x <- c(...) + check_rvar_not_complex(x, "rvar_range") summarise_rvar_within_draws_via_matrix( - c(...), "range", matrixStats::rowRanges, na.rm = na.rm, .ordered_okay = TRUE + x, "range", matrixStats::rowRanges, na.rm = na.rm, .ordered_okay = TRUE ) } @@ -147,6 +156,7 @@ rvar_quantile <- function(..., probs, names = FALSE, na.rm = FALSE) { names <- as_one_logical(names) na.rm <- as_one_logical(na.rm) x <- c(...) + check_rvar_not_complex(x, "rvar_quantile") # use type 1 for ordered rvar because it is the inverse of the ECDF of # a discrete distribution (hence appropriate for ordinal samples), otherwise @@ -172,16 +182,20 @@ rvar_quantile <- function(..., probs, names = FALSE, na.rm = FALSE) { #' @rdname rvar-summaries-within-draws #' @export rvar_all <- function(..., na.rm = FALSE) { + x <- c(...) + check_rvar_not_complex(x, "rvar_all") summarise_rvar_within_draws_via_matrix( - c(...), "all", matrixStats::rowAlls, na.rm = na.rm + x, "all", matrixStats::rowAlls, na.rm = na.rm ) } #' @rdname rvar-summaries-within-draws #' @export rvar_any <- function(..., na.rm = FALSE) { + x <- c(...) + check_rvar_not_complex(x, "rvar_any") summarise_rvar_within_draws_via_matrix( - c(...), "any", matrixStats::rowAnys, na.rm = na.rm + x, "any", matrixStats::rowAnys, na.rm = na.rm ) } diff --git a/tests/testthat/test-discrete-summaries.R b/tests/testthat/test-discrete-summaries.R index 78ce902b..b43d5837 100644 --- a/tests/testthat/test-discrete-summaries.R +++ b/tests/testthat/test-discrete-summaries.R @@ -5,22 +5,26 @@ test_that("modal_category works on vectors", { expect_equal(modal_category(logical()), logical()) expect_equal(modal_category(double()), double()) expect_equal(modal_category(integer()), integer()) + expect_equal(modal_category(complex()), complex()) expect_equal(modal_category(character()), character()) expect_equal(modal_category(factor()), character()) expect_equal(modal_category(ordered(NULL)), character()) expect_equal(modal_category(NA), NA) expect_equal(modal_category(c(1,2.1,2.1,3,3)), 2.1) + expect_equal(modal_category(c(1,2.1i,2.1i,3,3)), 2.1i) expect_equal(modal_category(c("a","b","b","c","c")), "b") expect_equal(modal_category(factor(c("a","b","b","c","c"))), "b") }) test_that("modal_category works on rvars", { expect_equal(modal_category(rvar()), double()) + expect_equal(modal_category(rvar(complex())), complex()) expect_equal(modal_category(rvar_factor()), character()) expect_equal(modal_category(rvar_ordered()), character()) expect_equal(modal_category(c(rvar(c(1,2.1,2.1,3,3)), rvar(1))), c(2.1, 1)) + expect_equal(modal_category(c(rvar(c(1,2.1i,2.1i,3,3)), rvar(1))), c(2.1i, 1)) expect_equal(modal_category(c(rvar(c("a","b","b","c","c")), rvar("c"))), c("b","c")) }) diff --git a/tests/testthat/test-rvar-cast.R b/tests/testthat/test-rvar-cast.R index f411c56a..d4330f7a 100755 --- a/tests/testthat/test-rvar-cast.R +++ b/tests/testthat/test-rvar-cast.R @@ -36,7 +36,7 @@ test_that("as_rvar preserves dimension names", { }) -# as_rvar_numeric/integer/logical ------------------------------------------------ +# as_rvar_numeric/complex/integer/logical ----------------------------------------- test_that("as_rvar_numeric works", { x_array = array( @@ -57,6 +57,17 @@ test_that("as_rvar_numeric works", { expect_type(draws_of(as_rvar_numeric(x_fct)), "double") }) +test_that("as_rvar_complex works", { + x_array = array( + 1:24, dim = c(2,4,3), + dimnames = list(NULL, A = paste0("a", 1:4), B = paste0("b", 1:3)) + ) + x <- rvar(x_array) + + expect_equal(as_rvar_complex(x), new_rvar(x_array + 0i)) + expect_type(draws_of(as_rvar_complex(x)), "complex") +}) + test_that("as_rvar_integer works", { x_array = array( 1L:24L, dim = c(2,4,3), diff --git a/tests/testthat/test-rvar-dist.R b/tests/testthat/test-rvar-dist.R index 09c2460c..9e51cc43 100755 --- a/tests/testthat/test-rvar-dist.R +++ b/tests/testthat/test-rvar-dist.R @@ -60,3 +60,11 @@ test_that("distributional functions work on an rvar_ordered", { expect_equal(quantile(x, c(.3, .5, .9, 1)), letters[2:5]) }) + +test_that("distributional functions throw an error on complex rvars", { + x <- rvar(1:10 + 10:1 * 1i) + + expect_error(density(x, 5 + 1i), "Cannot apply.*to complex rvars") + expect_error(cdf(x, 5 + 1i), "Cannot apply.*to complex rvars") + expect_error(quantile(x, 0.5), "Cannot apply.*to complex rvars") +}) diff --git a/tests/testthat/test-rvar-math.R b/tests/testthat/test-rvar-math.R index 5c4e6916..9fc5e110 100755 --- a/tests/testthat/test-rvar-math.R +++ b/tests/testthat/test-rvar-math.R @@ -30,6 +30,8 @@ test_that("math operators works", { expect_equal(2 ^ x, new_rvar(2 ^ (x_array))) expect_equal(x ^ y, new_rvar(x_array ^ y_array)) + expect_equal((x * 1i) ^ 2, new_rvar(-(x_array ^ 2) + 0i)) + # ensure broadcasting of constants retains shape z2 <- new_rvar(array(1, dim = c(1,1))) z4 <- new_rvar(array(2, dim = c(1,1,1,1))) @@ -94,6 +96,8 @@ test_that("comparison operators work", { expect_equal(x != 5, new_rvar(x_array != 5)) expect_equal(5 != x, new_rvar(5 != x_array)) expect_equal(x != y, new_rvar(x_array != y_array)) + + expect_error(new_rvar(1i) > new_rvar(1i)) }) test_that("comparison operators work on rvar_factors", { @@ -183,6 +187,7 @@ test_that("cumulative functions work", { )) expect_equal(cumsum(x), cumsum_ref) expect_equal(cumsum(x[[1]]), x[[1]]) + expect_equal(cumsum(x + x * 1i), cumsum_ref + cumsum_ref * 1i) cumprod_ref = new_rvar(rbind( cumprod(draws_of(x)[1,,]), @@ -191,12 +196,19 @@ test_that("cumulative functions work", { expect_equal(cumprod(x), cumprod_ref) expect_equal(cumprod(x[[1]]), x[[1]]) + complex_cumprod_ref <- new_rvar(rbind( + cumprod(draws_of(x)[1,,] + draws_of(x)[1,,] * 1i), + cumprod(draws_of(x)[2,,] + draws_of(x)[2,,] * 1i) + )) + expect_equal(cumprod(x + x * 1i), complex_cumprod_ref) + cummax_ref = new_rvar(rbind( cummax(draws_of(x)[1,,]), cummax(draws_of(x)[2,,]) )) expect_equal(cummax(x), cummax_ref) expect_equal(cummax(x[[1]]), x[[1]]) + expect_error(cummax(x * 1i)) cummin_ref = new_rvar(rbind( cummin(draws_of(x)[1,,]), @@ -204,6 +216,7 @@ test_that("cumulative functions work", { )) expect_equal(cummin(x), cummin_ref) expect_equal(cummin(x[[1]]), x[[1]]) + expect_error(cummin(x * 1i)) }) test_that("complex number operators work", { @@ -232,6 +245,7 @@ test_that("matrix multiplication works", { x_array[4,,] %*% y_array[4,,] )) expect_equal(x %**% y, xy_ref) + expect_equal((x * 1i) %**% (y * 1i), -xy_ref + 0i) x_array = array(1:6, dim = c(2,3)) @@ -244,6 +258,7 @@ test_that("matrix multiplication works", { x_array[2,] %*% y_array[2,] )) expect_equal(x %**% y, xy_ref) + expect_equal((x * 1i) %**% (y * 1i), -xy_ref + 0i) # automatic promotion to row/col vector of numeric vectors x_meany_ref = new_rvar(abind::abind(along = 0, @@ -251,12 +266,14 @@ test_that("matrix multiplication works", { x_array[2,] %*% colMeans(y_array) )) expect_equal(x %**% colMeans(y_array), x_meany_ref) + expect_equal((x * 1i) %**% (colMeans(y_array) * 1i), -x_meany_ref + 0i) meanx_y_ref = new_rvar(abind::abind(along = 0, colMeans(x_array) %*% y_array[1,], colMeans(x_array) %*% y_array[2,] )) expect_equal(colMeans(x_array) %**% y, meanx_y_ref) + expect_equal((colMeans(x_array) * 1i) %**% (y * 1i), -meanx_y_ref + 0i) # dimension name preservation m1 <- as_rvar(diag(1:3)) diff --git a/tests/testthat/test-rvar-summaries-over-draws.R b/tests/testthat/test-rvar-summaries-over-draws.R index 25fbda0e..011efed2 100755 --- a/tests/testthat/test-rvar-summaries-over-draws.R +++ b/tests/testthat/test-rvar-summaries-over-draws.R @@ -3,6 +3,8 @@ test_that("numeric summaries work", { x_array <- array(1:24, dim = c(4,2,3), dimnames = list(NULL, a = c("a1", "a2"), b = c("b1", "b2", "b3"))) x <- new_rvar(x_array) + x_cmp_array <- x_array + x_array*1i + x_cmp <- new_rvar(x_cmp_array) x_letters <- array(letters[1:24], dim = c(4,2,3), dimnames = list(NULL, a = c("a1", "a2"), b = c("b1", "b2", "b3"))) x_ord <- rvar_ordered(x_letters, levels = letters) x_fct <- rvar_factor(x_letters, levels = letters) @@ -13,6 +15,12 @@ test_that("numeric summaries work", { expect_equal(min(x), apply(x_array, c(2,3), min)) expect_equal(max(x), apply(x_array, c(2,3), max)) + expect_error(median(x_cmp)) + expect_equal(sum(x_cmp), apply(x_cmp_array, c(2,3), sum)) + expect_equal(prod(x_cmp), apply(x_cmp_array, c(2,3), prod)) + expect_error(min(x_cmp)) + expect_error(max(x_cmp)) + ordered_out <- function(x) structure( x, dim = c(2,3), dimnames = list(a = c("a1", "a2"), b = c("b1", "b2", "b3")), levels = letters, class = c("ordered", "factor") @@ -46,6 +54,10 @@ test_that("means work", { expect_error(Pr(x_array)) expect_equal(E(x_array), mean(x_array)) + x_cmp_array <- x_array + x_array*1i + x_cmp <- new_rvar(x_cmp_array) + expect_equal(mean(x_cmp), apply(x_cmp_array, c(2,3), mean)) + # test vector rvars as well since these should be summarized down to vectors # (not one-dimensional arrays) y_array <- array(1:24, dim = c(4,6), dimnames = list(NULL, paste0("a", 1:6))) @@ -62,6 +74,8 @@ test_that("means work", { test_that("spread functions work", { x_array <- array(1:24, dim = c(4,2,3), dimnames = list(NULL, a = c("a1", "a2"), b = c("b1", "b2", "b3"))) x <- new_rvar(x_array) + x_cmp_array <- x_array + x_array*1i + x_cmp <- new_rvar(x_cmp_array) x_letters <- array(letters[1:24], dim = c(4,2,3), dimnames = list(NULL, a = c("a1", "a2"), b = c("b1", "b2", "b3"))) x_ord <- rvar_ordered(x_letters, levels = letters) x_fct <- rvar_factor(x_letters, levels = letters) @@ -71,6 +85,11 @@ test_that("spread functions work", { expect_equal(var(x), apply(x_array, c(2,3), var)) expect_equal(mad(x), apply(x_array, c(2,3), mad)) + expect_equal(sd(x_cmp), apply(x_cmp_array, c(2,3), sd)) + expect_equal(variance(x_cmp), apply(x_cmp_array, c(2,3), var)) + expect_equal(var(x_cmp), apply(x_cmp_array, c(2,3), var)) + expect_error(mad(x_cmp), "Cannot apply.*complex rvars") + expect_error(sd(x_ord)) expect_error(variance(x_ord)) expect_error(var(x_ord)) @@ -114,6 +133,7 @@ test_that("range works", { ) ) expect_error(range(x_fct)) + expect_error(range(rvar(1i)), "Cannot apply.*complex rvars") y_array <- array(1:24, dim = c(4,6), dimnames = list(NULL, paste0("a", 1:6))) y <- new_rvar(y_array) @@ -140,8 +160,10 @@ test_that("logical summaries work", { expect_equal(all(y > 10), apply(y_array > 10, 2, all)) expect_equal(any(y > 10), apply(y_array > 10, 2, any)) - expect_error(all(rvar("a"))) - expect_error(any(rvar("a"))) + expect_error(all(rvar("a")), "Cannot apply.*rvar_factor") + expect_error(any(rvar("a")), "Cannot apply.*rvar_factor") + expect_error(all(rvar(1i)), "Cannot apply.*complex rvars") + expect_error(any(rvar(1i)), "Cannot apply.*complex rvars") }) @@ -152,15 +174,21 @@ test_that("special value predicates work", { dim = c(4,2,3), dimnames = list(NULL, a = c("a1", "a2"), b = c("b1", "b2", "b3")) ) x <- new_rvar(x_array) + x_cmp_array <- x_array + 1i + x_cmp <- new_rvar(x_cmp_array) x_letters <- array(c("a",NA,letters[3:12], NaN, letters[14:24]), dim = c(4,2,3), dimnames = list(NULL, a = c("a1", "a2"), b = c("b1", "b2", "b3"))) x_ord <- rvar_ordered(x_letters, levels = letters) x_fct <- rvar_factor(x_letters, levels = letters) .dimnames = list(a = c("a1", "a2"), b = c("b1", "b2", "b3")) expect_equal(is.finite(x), array(c(rep(FALSE, 4), rep(TRUE, 2)), dim = c(2,3), dimnames = .dimnames)) + expect_equal(is.finite(x_cmp), array(c(rep(FALSE, 4), rep(TRUE, 2)), dim = c(2,3), dimnames = .dimnames)) expect_equal(is.infinite(x), array(c(FALSE, TRUE, TRUE, FALSE, FALSE, FALSE), dim = c(2,3), dimnames = .dimnames)) + expect_equal(is.infinite(x_cmp), array(c(FALSE, TRUE, TRUE, FALSE, FALSE, FALSE), dim = c(2,3), dimnames = .dimnames)) expect_equal(is.nan(x), array(c(FALSE, FALSE, FALSE, TRUE, FALSE, FALSE), dim = c(2,3), dimnames = .dimnames)) + expect_equal(is.nan(x_cmp), array(c(FALSE, FALSE, FALSE, TRUE, FALSE, FALSE), dim = c(2,3), dimnames = .dimnames)) expect_equal(is.na(x), array(c(TRUE, FALSE, FALSE, TRUE, FALSE, FALSE), dim = c(2,3), dimnames = .dimnames)) + expect_equal(is.na(x_cmp), array(c(TRUE, FALSE, FALSE, TRUE, FALSE, FALSE), dim = c(2,3), dimnames = .dimnames)) .dimnames = list(a = c("a1", "a2"), b = c("b1", "b2", "b3")) expect_equal(is.finite(x_ord), array(c(FALSE, TRUE, TRUE, FALSE, rep(TRUE, 2)), dim = c(2,3), dimnames = .dimnames)) @@ -186,15 +214,20 @@ test_that("special value predicates work", { test_that("anyNA works", { x_array <- array(1:24, dim = c(4,2,3), dimnames = list(NULL, a = c("a1", "a2"), b = c("b1", "b2", "b3"))) x <- new_rvar(x_array) + x_cmp_array <- x_array + 1i + x_cmp <- new_rvar(x_cmp_array) x_letters <- array(letters[1:24], dim = c(4,2,3), dimnames = list(NULL, a = c("a1", "a2"), b = c("b1", "b2", "b3"))) x_ord <- rvar_ordered(x_letters, levels = letters) x_fct <- rvar_factor(x_letters, levels = letters) expect_equal(anyNA(x), FALSE) + expect_equal(anyNA(x_cmp), FALSE) expect_equal(anyNA(x_fct), FALSE) expect_equal(anyNA(x_ord), FALSE) x[2,1] <- NA expect_equal(anyNA(x), TRUE) + x_cmp[2,1] <- NA + expect_equal(anyNA(x_cmp), TRUE) x_fct[2,1] <- NA expect_equal(anyNA(x_fct), TRUE) x_ord[2,1] <- NA diff --git a/tests/testthat/test-rvar-summaries-within-draws.R b/tests/testthat/test-rvar-summaries-within-draws.R index 50ad4faf..7ed38d95 100755 --- a/tests/testthat/test-rvar-summaries-within-draws.R +++ b/tests/testthat/test-rvar-summaries-within-draws.R @@ -3,6 +3,8 @@ test_that("numeric summary functions work", { x_array <- array(1:24, dim = c(4,2,3)) x <- new_rvar(x_array) + x_cmp_array <- x_array + x_array * 1i + x_cmp <- rvar(x_cmp_array) x_letters <- array(letters[1:24], dim = c(4,2,3)) x_ord <- rvar_ordered(x_letters, levels = letters) x_fct <- rvar_factor(x_letters, levels = letters) @@ -14,6 +16,13 @@ test_that("numeric summary functions work", { expect_equal(draws_of(rvar_min(x)), apply(x_array, 1, min), check.attributes = FALSE) expect_equal(draws_of(rvar_max(x)), apply(x_array, 1, max), check.attributes = FALSE) + expect_equal(draws_of(rvar_mean(x_cmp)), as.matrix(apply(x_cmp_array, 1, mean)), check.attributes = FALSE) + expect_error(draws_of(rvar_median(x_cmp))) + expect_equal(draws_of(rvar_sum(x_cmp)), as.matrix(apply(x_cmp_array, 1, sum)), check.attributes = FALSE) + expect_equal(draws_of(rvar_prod(x_cmp)), as.matrix(apply(x_cmp_array, 1, prod)), check.attributes = FALSE) + expect_error(draws_of(rvar_min(x_cmp))) + expect_error(draws_of(rvar_max(x_cmp))) + expect_error(rvar_mean(x_ord)) expect_equal(rvar_median(x_ord), rvar_ordered(letters[apply(x_array, 1, median)], levels = letters)) expect_error(rvar_sum(x_ord)) @@ -51,6 +60,8 @@ test_that("numeric summary functions work", { test_that("spread summary functions work", { x_array <- array(1:24, dim = c(4,2,3)) x <- new_rvar(x_array) + x_cmp_array <- x_array + x_array * 1i + x_cmp <- new_rvar(x_cmp_array) x_letters <- array(letters[1:24], dim = c(4,2,3)) x_ord <- rvar_ordered(x_letters, levels = letters) x_fct <- rvar_factor(x_letters, levels = letters) @@ -60,6 +71,10 @@ test_that("spread summary functions work", { expect_equal(draws_of(rvar_mad(x)), apply(x_array, 1, mad), check.attributes = FALSE) expect_equal(draws_of(rvar_mad(x, constant = 1)), apply(x_array, 1, mad, constant = 1), check.attributes = FALSE) + expect_equal(draws_of(rvar_sd(x_cmp)), apply(x_cmp_array, 1, sd), check.attributes = FALSE) + expect_equal(draws_of(rvar_var(x_cmp)), apply(x_cmp_array, 1, function(x) var(as.vector(x))), check.attributes = FALSE) + expect_error(rvar_mad(x_cmp)) + expect_error(rvar_sd(x_ord)) expect_error(rvar_var(x_ord)) expect_equal(rvar_mad(x_ord, constant = 1), rvar(apply(x_array, 1, mad, constant = 1))) @@ -92,6 +107,7 @@ test_that("rvar_range works", { x_ord <- rvar_ordered(x_letters, levels = letters) expect_equal(draws_of(rvar_range(x)), t(apply(x_array, 1, range)), check.attributes = FALSE) + expect_error(rvar_range(rvar(1i))) expect_equal(rvar_range(x_ord), rvar_ordered(array(letters[t(apply(x_array, 1, range))], dim = c(4, 2)), levels = letters)) expect_error(rvar_range(rvar_factor("a"))) @@ -128,6 +144,9 @@ test_that("rvar_quantile works", { # passing NULL should still result in a vector with length = length(probs) expect_equal(rvar_quantile(NULL, probs = c(0.25, 0.75)), as_rvar(c(NA_real_, NA_real_))) + + expect_error(rvar_quantile(rvar(1i), probs = 0.5)) + expect_error(rvar_quantile(rvar_factor("a"), probs = 0.5)) }) @@ -141,7 +160,9 @@ test_that("logical summaries work", { expect_equal(draws_of(rvar_any(x > 6)), as.matrix(apply(x_array > 6, 1, any)), check.attributes = FALSE) expect_error(rvar_all(rvar("a"))) + expect_error(rvar_all(rvar(1i))) expect_error(rvar_any(rvar("a"))) + expect_error(rvar_any(rvar(1i))) # default values on empty input expect_equal(rvar_all(), as_rvar(TRUE)) @@ -154,6 +175,8 @@ test_that("logical summaries work", { test_that("special value predicates work", { x_array <- c(1, Inf, -Inf, NaN, NA) x <- new_rvar(x_array) + x_cmp_array <- x_array + 1i + x_cmp <- new_rvar(x_cmp_array) x_letters <- factor(letters[c(1, 2, 3, NaN, NA)]) x_ord <- rvar_ordered(x_letters) x_fct <- rvar_factor(x_letters) @@ -163,6 +186,11 @@ test_that("special value predicates work", { expect_equal(draws_of(rvar_is_nan(x)), as.matrix(is.nan(x_array)), check.attributes = FALSE) expect_equal(draws_of(rvar_is_na(x)), as.matrix(is.na(x_array)), check.attributes = FALSE) + expect_equal(draws_of(rvar_is_finite(x_cmp)), as.matrix(is.finite(x_cmp_array)), check.attributes = FALSE) + expect_equal(draws_of(rvar_is_infinite(x_cmp)), as.matrix(is.infinite(x_cmp_array)), check.attributes = FALSE) + expect_equal(draws_of(rvar_is_nan(x_cmp)), as.matrix(is.nan(x_cmp_array)), check.attributes = FALSE) + expect_equal(draws_of(rvar_is_na(x_cmp)), as.matrix(is.na(x_cmp_array)), check.attributes = FALSE) + expect_equal(draws_of(rvar_is_finite(x_ord)), as.matrix(is.finite(x_letters)), check.attributes = FALSE) expect_equal(draws_of(rvar_is_infinite(x_ord)), as.matrix(is.infinite(x_letters)), check.attributes = FALSE) expect_equal(draws_of(rvar_is_nan(x_ord)), as.matrix(is.nan(x_letters)), check.attributes = FALSE) From cf2c2f2da60313235a0990ff4e214b517e5a3834 Mon Sep 17 00:00:00 2001 From: Matthew Kay Date: Sun, 19 Nov 2023 23:00:55 -0600 Subject: [PATCH 08/10] test coverage improvements --- R/as_draws.R | 2 +- R/misc.R | 5 --- tests/testthat/test-rvar-bind.R | 2 ++ tests/testthat/test-rvar-cast.R | 59 +++++++++++++++++++++++++++++++++ 4 files changed, 62 insertions(+), 6 deletions(-) diff --git a/R/as_draws.R b/R/as_draws.R index 261b077e..aa112585 100644 --- a/R/as_draws.R +++ b/R/as_draws.R @@ -145,7 +145,7 @@ validate_draws_per_variable <- function(...) { # '.nchains' is an additional argument in chain supporting formats stop_no_call("'.nchains' is not supported for this format.") } - out <- lapply(out, as_numeric_or_complex) + out <- lapply(out, function(x) if (is.numeric(x) || is.complex(x)) x else as.numeric(x)) ndraws_per_variable <- lengths(out) ndraws <- max(ndraws_per_variable) if (!all(ndraws_per_variable %in% c(1, ndraws))) { diff --git a/R/misc.R b/R/misc.R index 68fac39f..e5c681d8 100644 --- a/R/misc.R +++ b/R/misc.R @@ -111,11 +111,6 @@ as_one_character <- function(x, allow_na = FALSE) { x } -# coerce 'x' to a numeric or complex vector -as_numeric_or_complex <- function(x) { - if (is.numeric(x) || is.complex(x)) x else as.numeric(x) -} - # check if all inputs are NULL all_null <- function(...) { all(ulapply(list(...), is.null)) diff --git a/tests/testthat/test-rvar-bind.R b/tests/testthat/test-rvar-bind.R index 6f580ed5..c89126ea 100755 --- a/tests/testthat/test-rvar-bind.R +++ b/tests/testthat/test-rvar-bind.R @@ -24,6 +24,8 @@ test_that("c works on rvar", { expect_equal(vctrs::vec_c(5, x), rvar(array(c(5, 5, 5, 1:9), dim = c(3,4)))) expect_equal(c(x, 5L), rvar(array(c(1:9, 5, 5, 5), dim = c(3,4)))) expect_equal(vctrs::vec_c(5L, x), rvar(array(c(5, 5, 5, 1:9), dim = c(3,4)))) + expect_equal(c(x, 5i), rvar(array(c(1:9, 5i, 5i, 5i), dim = c(3,4)))) + expect_equal(vctrs::vec_c(5i, x), rvar(array(c(5i, 5i, 5i, 1:9), dim = c(3,4)))) expect_equal(c(x == 1, TRUE), rvar(array(c(1:9 == 1, TRUE, TRUE, TRUE), dim = c(3,4)))) expect_equal(vctrs::vec_c(TRUE, x == TRUE), rvar(array(c(TRUE, TRUE, TRUE, 1:9 == 1), dim = c(3,4)))) diff --git a/tests/testthat/test-rvar-cast.R b/tests/testthat/test-rvar-cast.R index d4330f7a..b697bd46 100755 --- a/tests/testthat/test-rvar-cast.R +++ b/tests/testthat/test-rvar-cast.R @@ -114,6 +114,10 @@ test_that("as_rvar_factor works", { draws_of(as_rvar_factor(array(1:4, dim = c(2,2)))), structure(1:4L, levels = c("1", "2", "3", "4"), dim = c(1, 2, 2), dimnames = list("1", NULL, NULL), class = "factor") ) + expect_equal( + draws_of(as_rvar_factor(array(1:4 * 1i, dim = c(2,2)))), + structure(1:4L, levels = c("0+1i", "0+2i", "0+3i", "0+4i"), dim = c(1, 2, 2), dimnames = list("1", NULL, NULL), class = "factor") + ) expect_equal( draws_of(as_rvar_factor(array(c(TRUE, TRUE, FALSE, FALSE), dim = c(2,2)))), structure(c(2, 2, 1, 1), levels = c("FALSE", "TRUE"), dim = c(1, 2, 2), dimnames = list("1", NULL, NULL), class = "factor") @@ -144,6 +148,10 @@ test_that("as_rvar_ordered works", { draws_of(as_rvar_ordered(array(1:4, dim = c(2,2)))), structure(1:4L, levels = c("1", "2", "3", "4"), dim = c(1, 2, 2), dimnames = list("1", NULL, NULL), class = c("ordered", "factor")) ) + expect_equal( + draws_of(as_rvar_ordered(array(1:4 * 1i, dim = c(2,2)))), + structure(1:4L, levels = c("0+1i", "0+2i", "0+3i", "0+4i"), dim = c(1, 2, 2), dimnames = list("1", NULL, NULL), class = c("ordered", "factor")) + ) expect_equal( draws_of(as_rvar_ordered(as_rvar(array(1:4, dim = c(2,2))))), structure(1:4L, levels = c("1", "2", "3", "4"), dim = c(1, 2, 2), dimnames = list("1", NULL, NULL), class = c("ordered", "factor")) @@ -216,6 +224,57 @@ test_that("casting to/from rvar/distribution objects works", { # type predicates --------------------------------------------------------- +test_that("is_rvar_XXX works", { + x <- rvar() + x_lgl <- rvar(logical()) + x_int <- rvar(integer()) + x_cmp <- rvar(complex()) + x_fct <- rvar(factor()) + x_ord <- rvar(ordered()) + + expect_true(is_rvar(x)) + expect_false(is_rvar_logical(x)) + expect_false(is_rvar_integer(x)) + expect_false(is_rvar_complex(x)) + expect_false(is_rvar_factor(x)) + expect_false(is_rvar_ordered(x)) + + expect_true(is_rvar(x_lgl)) + expect_true(is_rvar_logical(x_lgl)) + expect_false(is_rvar_integer(x_lgl)) + expect_false(is_rvar_complex(x_lgl)) + expect_false(is_rvar_factor(x_lgl)) + expect_false(is_rvar_ordered(x_lgl)) + + expect_true(is_rvar(x_int)) + expect_false(is_rvar_logical(x_int)) + expect_true(is_rvar_integer(x_int)) + expect_false(is_rvar_complex(x_int)) + expect_false(is_rvar_factor(x_int)) + expect_false(is_rvar_ordered(x_int)) + + expect_true(is_rvar(x_cmp)) + expect_false(is_rvar_logical(x_cmp)) + expect_false(is_rvar_integer(x_cmp)) + expect_true(is_rvar_complex(x_cmp)) + expect_false(is_rvar_factor(x_cmp)) + expect_false(is_rvar_ordered(x_cmp)) + + expect_true(is_rvar(x_fct)) + expect_false(is_rvar_logical(x_fct)) + expect_false(is_rvar_integer(x_fct)) + expect_false(is_rvar_complex(x_fct)) + expect_true(is_rvar_factor(x_fct)) + expect_false(is_rvar_ordered(x_fct)) + + expect_true(is_rvar(x_ord)) + expect_false(is_rvar_logical(x_ord)) + expect_false(is_rvar_integer(x_ord)) + expect_false(is_rvar_complex(x_ord)) + expect_true(is_rvar_factor(x_ord)) + expect_true(is_rvar_ordered(x_ord)) +}) + test_that("is.matrix/array on rvar works", { x_mat <- rvar(array(1:24, dim = c(2,2,6))) x_arr <- rvar(array(1:24, dim = c(2,2,3,2))) From 3bab85c1e7e86191ee5e9a102021fc7783bbdd78 Mon Sep 17 00:00:00 2001 From: Matthew Kay Date: Sun, 19 Nov 2023 23:09:07 -0600 Subject: [PATCH 09/10] test fix for old R --- tests/testthat/test-rvar-cast.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/testthat/test-rvar-cast.R b/tests/testthat/test-rvar-cast.R index b697bd46..00c9bc6d 100755 --- a/tests/testthat/test-rvar-cast.R +++ b/tests/testthat/test-rvar-cast.R @@ -230,7 +230,7 @@ test_that("is_rvar_XXX works", { x_int <- rvar(integer()) x_cmp <- rvar(complex()) x_fct <- rvar(factor()) - x_ord <- rvar(ordered()) + x_ord <- rvar(ordered(NULL)) expect_true(is_rvar(x)) expect_false(is_rvar_logical(x)) From 31f82f837cbae8637465096baac363024ee86a9f Mon Sep 17 00:00:00 2001 From: Matthew Kay Date: Fri, 24 Nov 2023 17:12:01 -0600 Subject: [PATCH 10/10] move sd.complex into sd.default because sd.matrix overrides sd.complex --- NAMESPACE | 1 - R/rvar-summaries-over-draws.R | 13 ++++++++----- man/rvar-summaries-over-draws.Rd | 3 --- 3 files changed, 8 insertions(+), 9 deletions(-) diff --git a/NAMESPACE b/NAMESPACE index 0679d0aa..f62d10ec 100755 --- a/NAMESPACE +++ b/NAMESPACE @@ -276,7 +276,6 @@ S3method(rhat_basic,default) S3method(rhat_basic,rvar) S3method(rhat_nested,default) S3method(rhat_nested,rvar) -S3method(sd,complex) S3method(sd,default) S3method(sd,rvar) S3method(split_chains,draws) diff --git a/R/rvar-summaries-over-draws.R b/R/rvar-summaries-over-draws.R index 54b7a785..ba38873b 100755 --- a/R/rvar-summaries-over-draws.R +++ b/R/rvar-summaries-over-draws.R @@ -206,11 +206,14 @@ var.rvar <- variance.rvar sd <- function(x, ...) UseMethod("sd") #' @rdname rvar-summaries-over-draws #' @export -sd.default <- function(x, ...) stats::sd(x, ...) -#' @rdname rvar-summaries-over-draws -#' @export -sd.complex <- function(x, ...) { - sqrt(variance(c(x), ...)) +sd.default <- function(x, ...) { + # because complex matrices do not dispatch on the complex class, check for + # complex numbers here + if (is.complex(x)) { + sqrt(variance(c(x), ...)) + } else { + stats::sd(x, ...) + } } #' @rdname rvar-summaries-over-draws #' @export diff --git a/man/rvar-summaries-over-draws.Rd b/man/rvar-summaries-over-draws.Rd index 50937479..b9cb1e9b 100755 --- a/man/rvar-summaries-over-draws.Rd +++ b/man/rvar-summaries-over-draws.Rd @@ -24,7 +24,6 @@ \alias{var.rvar} \alias{sd} \alias{sd.default} -\alias{sd.complex} \alias{sd.rvar} \alias{mad} \alias{mad.default} @@ -81,8 +80,6 @@ sd(x, ...) \method{sd}{default}(x, ...) -\method{sd}{complex}(x, ...) - \method{sd}{rvar}(x, ...) mad(x, ...)