Skip to content

Commit

Permalink
Update stansummary_test
Browse files Browse the repository at this point in the history
  • Loading branch information
WardBrian committed Oct 31, 2024
1 parent b1a1f0e commit c7aa451
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 47 deletions.
1 change: 0 additions & 1 deletion src/cmdstan/stansummary.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
#include <cmdstan/return_codes.hpp>
#include <cmdstan/stansummary_helper.hpp>
#include <stan/mcmc/chains.hpp>
#include <stan/io/ends_with.hpp>
#include <algorithm>
#include <fstream>
Expand Down
12 changes: 9 additions & 3 deletions src/cmdstan/stansummary_helper.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -163,34 +163,40 @@ bool is_container(const std::string &parameter_name) {
/**
* Return parameter name corresponding to column label.
*
* @tparam Tchains - either a mcmc::chains or mcmc::chainset object
* @param in column index
* @return variable name
*/
std::string base_param_name(const stan::mcmc::chainset &chains, int index) {
template <typename Tchains>
std::string base_param_name(const Tchains& chains, int index) {
std::string name = chains.param_name(index);
return name.substr(0, name.find("["));
}

/**
* Return parameter name corresponding to column label.
*
* @tparam Tchains - either a mcmc::chains or mcmc::chainset object
* @param in set of samples from one or more chains
* @param in column index
* @return parameter name
*/
std::string matrix_index(const stan::mcmc::chainset &chains, int index) {
template <typename Tchains>
std::string matrix_index(const Tchains &chains, int index) {
std::string name = chains.param_name(index);
return name.substr(name.find("["));
}

/**
* Return vector of dimensions for container variable.
*
* @tparam Tchains - either a mcmc::chains or mcmc::chainset object
* @param in set of samples from one or more chains
* @param in column index of first container element
* @return vector of dimensions
*/
std::vector<int> dimensions(const stan::mcmc::chainset &chains,
template <typename Tchains>
std::vector<int> dimensions(const Tchains &chains,
int start_index) {
std::vector<int> dims;
int dim;
Expand Down
88 changes: 45 additions & 43 deletions src/test/interface/stansummary_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -136,10 +136,10 @@ TEST(CommandStansummary, matrix_index_2d) {

TEST(CommandStansummary, header_tests) {
std::string expect
= " Mean MCSE StdDev 10% 50% 90% N_Eff "
"N_Eff/s R_hat\n";
= " Mean MCSE StdDev MAD 10% 50% 90%"
" ESS_bulk ESS_tail R_hat\n";
std::string expect_csv
= "name,Mean,MCSE,StdDev,10%,50%,90%,N_Eff,N_Eff/s,R_hat\n";
= "name,Mean,MCSE,StdDev,MAD,10%,50%,90%,ESS_bulk,ESS_tail,R_hat\n";
std::vector<std::string> pcts;
pcts.push_back("10");
pcts.push_back("50");
Expand All @@ -150,12 +150,14 @@ TEST(CommandStansummary, header_tests) {
EXPECT_FLOAT_EQ(probs[2], 0.9);

std::vector<std::string> header = get_header(pcts);
EXPECT_EQ(header.size(), pcts.size() + 6);
EXPECT_EQ(header.size(), pcts.size() + 7);
EXPECT_EQ(header[0], "Mean");
EXPECT_EQ(header[2], "StdDev");
EXPECT_EQ(header[4], "50%");
EXPECT_EQ(header[6], "N_Eff");
EXPECT_EQ(header[8], "R_hat");
EXPECT_EQ(header[3], "MAD");
EXPECT_EQ(header[5], "50%");
EXPECT_EQ(header[7], "ESS_bulk");
EXPECT_EQ(header[8], "ESS_tail");
EXPECT_EQ(header[9], "R_hat");

Eigen::VectorXi column_widths(header.size());
for (size_t i = 0, w = 5; i < header.size(); ++i, ++w) {
Expand Down Expand Up @@ -218,8 +220,8 @@ TEST(CommandStansummary, param_tests) {
Eigen::VectorXd warmup_times(filenames.size());
Eigen::VectorXd sampling_times(filenames.size());
Eigen::VectorXi thin(filenames.size());
stan::mcmc::chains<> chains = parse_csv_files(
filenames, metadata, warmup_times, sampling_times, thin, &std::cout);
auto chains = parse_csv_files(filenames, metadata, warmup_times,
sampling_times, thin, &std::cout);
EXPECT_EQ(chains.num_chains(), 1);
EXPECT_EQ(chains.num_params(), 8);

Expand All @@ -237,7 +239,7 @@ TEST(CommandStansummary, param_tests) {
size_t num_model_params = chains.num_params() - model_params_offset;
EXPECT_EQ(num_model_params, 1);

Eigen::MatrixXd model_params(num_model_params, 9);
Eigen::MatrixXd model_params(num_model_params, 10);
std::vector<int> model_param_idxes(num_model_params);
std::iota(model_param_idxes.begin(), model_param_idxes.end(),
model_params_offset);
Expand All @@ -246,7 +248,7 @@ TEST(CommandStansummary, param_tests) {
double mean_theta = model_params(0, 0);
EXPECT_TRUE(mean_theta > 0.25);
EXPECT_TRUE(mean_theta < 0.27);
double rhat_theta = model_params(0, 8);
double rhat_theta = model_params(0, 9);
EXPECT_TRUE(rhat_theta > 0.999);
EXPECT_TRUE(rhat_theta < 1.01);
}
Expand Down Expand Up @@ -386,7 +388,7 @@ TEST(CommandStansummary, bad_percentiles_arg) {
ASSERT_TRUE(out.hasError)
<< "\"" << out.command << "\" failed to quit with an error";

arg_percentiles = "--percentiles \"0,100\"";
arg_percentiles = "--percentiles \"101\"";
out = run_command(command + " " + arg_percentiles + " " + csv_file);
EXPECT_TRUE(boost::algorithm::contains(out.output, expected_message));
ASSERT_TRUE(out.hasError)
Expand Down Expand Up @@ -423,17 +425,17 @@ TEST(CommandStansummary, bad_include_param_args) {

TEST(CommandStansummary, check_console_output) {
std::string lp
= "lp__ -7.3 3.7e-02 0.77 -9.1 -7.0 -6.8 443 "
"19275 1.0";
= "lp__ -7.3 3.7e-02 0.77 0.30 -9.0 -7.0 -6.8 "
" 519 503 1.0";
std::string theta
= "theta 0.26 6.1e-03 0.12 0.079 0.25 0.47 384 "
"16683 1.00";
= "theta 0.26 6.1e-03 0.12 0.12 0.080 0.25 0.47 "
" 362 396 1.0";
std::string accept_stat
= "accept_stat__ 0.90 4.6e-03 1.5e-01 0.57 0.96 1.0 1026 "
"44597 1.00";
= "accept_stat__ 0.90 4.6e-03 1.5e-01 0.064 0.57 0.96 1.0 "
"1284 941 1.00";
std::string energy
= "energy__ 7.8 5.1e-02 1.0e+00 6.8 7.5 9.9 411 "
"17865 1.0";
= "energy__ 7.8 5.1e-02 1.0e+00 0.75 6.8 7.5 9.9 "
" 490 486 1.0";

std::string path_separator;
path_separator.push_back(get_path_separator());
Expand Down Expand Up @@ -478,17 +480,16 @@ TEST(CommandStansummary, check_console_output) {

TEST(CommandStansummary, check_csv_output) {
std::string csv_header
= "name,Mean,MCSE,StdDev,5%,50%,95%,N_Eff,N_Eff/s,R_hat";
= "name,Mean,MCSE,StdDev,MAD,5%,50%,95%,ESS_bulk,ESS_tail,R_hat";
std::string lp
= "\"lp__\",-7.2719,0.0365168,0.768874,-9.05757,-6.96978,-6.75008,443."
"328,19275.1,1.00037";
= "\"lp__\",-7.2719,0.0365168,0.768874,0.303688,-8.98426,-6.97009,-6."
"75007,519.29,503.309,1.00141";
std::string energy
= "\"energy__\""
",7.78428,0.0508815,1.0314,6.80383,7.46839,9.88601,410."
"898,17865.1,1.00075";
= "\"energy__\",7.78428,0.0508815,1.0314,0.745859,6.80565,7.46758,9.8864,"
"489.874,486.438,1.00495";
std::string theta
= "\"theta\",0.256552,0.00610844,0.119654,0.0786292,0.24996,0.470263,383."
"704,16682.8,0.999309";
= "\"theta\",0.256552,0.00610844,0.119654,0.120965,0.0802982,0.24996,0."
"47034,361.506,395.736,1.00186";

std::string path_separator;
path_separator.push_back(get_path_separator());
Expand Down Expand Up @@ -531,9 +532,9 @@ TEST(CommandStansummary, check_csv_output) {
}

TEST(CommandStansummary, check_csv_output_no_percentiles) {
std::string csv_header = "name,Mean,MCSE,StdDev,N_Eff,N_Eff/s,R_hat";
std::string csv_header = "name,Mean,MCSE,StdDev,MAD,ESS_bulk,ESS_tail,R_hat";
std::string lp
= "\"lp__\",-7.2719,0.0365168,0.768874,443.328,19275.1,1.00037";
= "\"lp__\",-7.2719,0.0365168,0.768874,0.303688,519.29,503.309,1.00141";

std::string path_separator;
path_separator.push_back(get_path_separator());
Expand Down Expand Up @@ -570,11 +571,12 @@ TEST(CommandStansummary, check_csv_output_no_percentiles) {

TEST(CommandStansummary, check_csv_output_sig_figs) {
std::string csv_header
= "name,Mean,MCSE,StdDev,5%,50%,95%,N_Eff,N_Eff/s,R_hat";
std::string lp = "\"lp__\",-7.3,0.037,0.77,-9.1,-7,-6.8,4.4e+02,1.9e+04,1";
std::string energy = "\"energy__\",7.8,0.051,1,6.8,7.5,9.9,4.1e+02,1.8e+04,1";
= "name,Mean,MCSE,StdDev,MAD,5%,50%,95%,ESS_bulk,ESS_tail,R_hat";
std::string lp = "\"lp__\",-7.3,0.037,0.77,0.3,-9,-7,-6.8,5.2e+02,5e+02,1";
std::string energy
= "\"energy__\",7.8,0.051,1,0.75,6.8,7.5,9.9,4.9e+02,4.9e+02,1";
std::string theta
= "\"theta\",0.26,0.0061,0.12,0.079,0.25,0.47,3.8e+02,1.7e+04,1";
= "\"theta\",0.26,0.0061,0.12,0.12,0.08,0.25,0.47,3.6e+02,4e+02,1";

std::string path_separator;
path_separator.push_back(get_path_separator());
Expand Down Expand Up @@ -619,20 +621,20 @@ TEST(CommandStansummary, check_csv_output_sig_figs) {

TEST(CommandStansummary, check_csv_output_include_param) {
std::string csv_header
= "name,Mean,MCSE,StdDev,5%,50%,95%,N_Eff,N_Eff/s,R_hat";
= "name,Mean,MCSE,StdDev,MAD,5%,50%,95%,ESS_bulk,ESS_tail,R_hat";
std::string lp
= "\"lp__\",-15.5617,0.97319,6.05585,-25.3432,-15.7562,-5.48405,38.7217,"
"372.539,1.00208";
= "\"lp__\",-15.5617,0.97319,6.05585,6.3817,-25.3182,-15.7598,-5.47732,"
"41.1897,113.537,1.00153";
std::string energy
= "\"energy__\",20.5888,1.01449,6.43127,10.2634,20.8487,30.9906,40.1879,"
"386.645,1.00109";
= "\"energy__\",20.5888,1.01449,6.43127,6.6161,10.2809,20.8278,30.9921,"
"42.5605,140.171,1.00069";
// note: skipping theta 1-5
std::string theta6
= "\"theta[6]\",5.001,0.365016,5.76072,-4.99688,5.23474,14.1597,249.074,"
"2396.33,1.00109";
= "\"theta[6]\",5.001,0.365016,5.76072,5.37947,-4.95375,5.22746,14.1688,"
"230.645,464.978,1.00054";
std::string theta7
= "\"theta[7]\",8.54125,0.650098,6.22195,-0.841225,8.09613,19.256,91."
"6001,881.279,0.999012";
= "\"theta[7]\",8.54125,0.650098,6.22195,5.35785,-0.814388,8.09342,19."
"2622,92.3075,241.177,1.00244";
// note: skipping theta 8
std::string message = "# Inference for Stan model: eight_schools_cp_model";

Expand Down

0 comments on commit c7aa451

Please sign in to comment.