diff --git a/R/nested_rhat.R b/R/nested_rhat.R index bfca55c6..e62533ae 100644 --- a/R/nested_rhat.R +++ b/R/nested_rhat.R @@ -34,12 +34,34 @@ rhat_nested.rvar <- function(x, superchain_ids, ...) { } .rhat_nested <- function(x, superchain_ids, ...) { + if (should_return_NA(x)) { + return(NA_real_) + } x <- as.matrix(x) niterations <- NROW(x) - nchains_per_superchain <- max(table(superchain_ids)) + nchains <- NCOL(x) + + + # check that all chains are assigned a superchain + if (length(superchain_ids) != nchains) { + warning_no_call("Length of superchain_ids not equal to number of chains, returning NA.") + return(NA_real_) + } + + + # check that superchains are equal length + superchain_id_table <- table(superchain_ids) + nchains_per_superchain <- max(superchain_id_table) + + if (nchains_per_superchain != min(superchain_id_table)) { + warning_no_call("Number of chains per superchain is not the same for each superchain, returning NA.") + return(NA_real_) + } + superchains <- unique(superchain_ids) + # mean and variance of chains calculated as in rhat chain_mean <- matrixStats::colMeans2(x) chain_var <- matrixStats::colVars(x, center = chain_mean) diff --git a/tests/testthat/test-rhat_nested.R b/tests/testthat/test-rhat_nested.R index 1d71ea9e..15b3a929 100644 --- a/tests/testthat/test-rhat_nested.R +++ b/tests/testthat/test-rhat_nested.R @@ -9,3 +9,23 @@ test_that("rhat_nested returns reasonable values", { }) +test_that("rhat_nested handles special cases correctly", { + set.seed(1234) + x <- c(rnorm(10), NA) + expect_true(is.na(rhat_nested(x, superchain_ids = c(1)))) + + x <- c(rnorm(10), Inf) + expect_true(is.na(rhat_nested(x, superchain_ids = c(1,2,1,2)))) + + tau <- extract_variable_matrix(example_draws(), "tau") + expect_warning( + rhat_nested(tau, superchain_ids = c(1,1,1,3)), + "Number of chains per superchain is not the same for each superchain, returning NA." + ) + + tau <- extract_variable_matrix(example_draws(), "tau") + expect_warning( + rhat_nested(tau, superchain_ids = c(1,2)), + "Length of superchain_ids not equal to number of chains, returning NA." + ) +})