Skip to content

Commit

Permalink
Cherry-pick diagnose changes from #1290
Browse files — browse the repository at this point in the history
  • Loading branch information
WardBrian committed Oct 31, 2024
1 parent c7aa451 commit c8bcb99
Show file tree
Hide file tree
Showing 6 changed files with 67 additions and 67 deletions.
110 changes: 56 additions & 54 deletions src/cmdstan/diagnose.cpp
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
#include <cmdstan/return_codes.hpp>
#include <cmdstan/stansummary_helper.hpp>
#include <stan/mcmc/chains.hpp>
#include <stan/mcmc/chainset.hpp>
#include <algorithm>
#include <fstream>
#include <iomanip>
#include <ios>
#include <iostream>

double RHAT_MAX = 1.05;
using cmdstan::return_codes;

// Threshold for the R-hat check: a parameter is reported when its
// rank-normalized split R-hat (the larger of the bulk and tail values)
// exceeds this value. It is set just under 1.015 so that, with the
// two-decimal fixed formatting applied to std::cout in main(), the
// threshold is displayed as 1.01 in the diagnostic message.
double RHAT_MAX = 1.01499; // round to 1.01

void diagnose_usage() {
std::cout << "USAGE: diagnose <filename 1> [<filename 2> ... <filename N>]"
Expand All @@ -26,7 +29,7 @@ void diagnose_usage() {
int main(int argc, const char *argv[]) {
if (argc == 1) {
diagnose_usage();
return 0;
return return_codes::OK;
}

// Parse any arguments specifying filenames
Expand All @@ -45,49 +48,47 @@ int main(int argc, const char *argv[]) {

if (!filenames.size()) {
std::cout << "No valid input files, exiting." << std::endl;
return 0;
return return_codes::NOT_OK;
}

std::cout << std::fixed << std::setprecision(2);

// Parse specified files
std::cout << "Processing csv files: " << filenames[0];
ifstream.open(filenames[0].c_str());

stan::io::stan_csv stan_csv
= stan::io::stan_csv_reader::parse(ifstream, &std::cout);
stan::mcmc::chains<> chains(stan_csv);
ifstream.close();

if (filenames.size() > 1)
std::cout << ", ";
else
std::cout << std::endl << std::endl;

for (std::vector<std::string>::size_type chain = 1; chain < filenames.size();
++chain) {
std::cout << filenames[chain];
ifstream.open(filenames[chain].c_str());
stan_csv = stan::io::stan_csv_reader::parse(ifstream, &std::cout);
chains.add(stan_csv);
ifstream.close();
if (chain < filenames.size() - 1)
std::cout << ", ";
else
std::cout << std::endl << std::endl;
std::vector<stan::io::stan_csv> csv_parsed;
for (int i = 0; i < filenames.size(); ++i) {
std::ifstream infile;
std::stringstream out;
stan::io::stan_csv sample;
infile.open(filenames[i].c_str());
try {
sample = stan::io::stan_csv_reader::parse(infile, &out);
// csv_reader warnings are errors - fail fast.
if (!out.str().empty()) {
throw std::invalid_argument(out.str());
}
csv_parsed.push_back(sample);
} catch (const std::invalid_argument &e) {
std::cout << "Cannot parse input csv file: " << filenames[i] << e.what()
<< "." << std::endl;
return return_codes::NOT_OK;
}
}

stan::mcmc::chainset chains(csv_parsed);
stan::io::stan_csv_metadata metadata = csv_parsed[0].metadata;
std::vector<std::string> param_names = csv_parsed[0].header;
size_t num_params = param_names.size();
int num_samples = chains.num_samples();
std::vector<std::string> bad_n_eff_names;
std::vector<std::string> bad_rhat_names;
bool has_errors = false;

for (int i = 0; i < chains.num_params(); ++i) {
if (chains.param_name(i) == std::string("treedepth__")) {
for (int i = 0; i < num_params; ++i) {
if (param_names[i] == std::string("treedepth__")) {
std::cout << "Checking sampler transitions treedepth." << std::endl;
int max_limit = stan_csv.metadata.max_depth;
int max_limit = metadata.max_depth;
long n_max = 0;
Eigen::VectorXd t_samples = chains.samples(i);
Eigen::MatrixXd draws = chains.samples(i);
Eigen::VectorXd t_samples
= Eigen::Map<Eigen::VectorXd>(draws.data(), draws.size());
for (long n = 0; n < t_samples.size(); ++n) {
if (t_samples(n) >= max_limit) {
++n_max;
Expand All @@ -109,7 +110,7 @@ int main(int argc, const char *argv[]) {
std::cout << "Treedepth satisfactory for all transitions." << std::endl
<< std::endl;
}
} else if (chains.param_name(i) == std::string("divergent__")) {
} else if (param_names[i] == std::string("divergent__")) {
std::cout << "Checking sampler transitions for divergences." << std::endl;
int n_divergent = chains.samples(i).sum();
if (n_divergent > 0) {
Expand All @@ -129,26 +130,22 @@ int main(int argc, const char *argv[]) {
std::cout << "No divergent transitions found." << std::endl
<< std::endl;
}
} else if (chains.param_name(i) == std::string("energy__")) {
} else if (param_names[i] == std::string("energy__")) {
std::cout << "Checking E-BFMI - sampler transitions HMC potential energy."
<< std::endl;
Eigen::VectorXd e_samples = chains.samples(i);
Eigen::MatrixXd draws = chains.samples(i);
Eigen::VectorXd e_samples
= Eigen::Map<Eigen::VectorXd>(draws.data(), draws.size());
double delta_e_sq_mean = 0;
double e_mean = 0;
double e_var = 0;
e_mean += e_samples(0);
e_var += e_samples(0) * (e_samples(0) - e_mean);
double e_mean = chains.mean(i);
double e_var = chains.variance(i);
for (long n = 1; n < e_samples.size(); ++n) {
double e = e_samples(n);
double delta_e_sq = (e - e_samples(n - 1)) * (e - e_samples(n - 1));
double d = delta_e_sq - delta_e_sq_mean;
delta_e_sq_mean += d / n;
d = e - e_mean;
e_mean += d / (n + 1);
e_var += d * (e - e_mean);
}

e_var /= static_cast<double>(e_samples.size() - 1);
double e_bfmi = delta_e_sq_mean / e_var;
double e_bfmi_threshold = 0.3;
if (e_bfmi < e_bfmi_threshold) {
Expand All @@ -163,14 +160,16 @@ int main(int argc, const char *argv[]) {
} else {
std::cout << "E-BFMI satisfactory." << std::endl << std::endl;
}
} else if (chains.param_name(i).find("__") == std::string::npos) {
double n_eff = chains.effective_sample_size(i);
} else if (param_names[i].find("__") == std::string::npos) {
auto [ess_bulk, ess_tail] = chains.split_rank_normalized_ess(i);
double n_eff = ess_bulk < ess_tail ? ess_bulk : ess_tail;
if (n_eff / num_samples < 0.001)
bad_n_eff_names.push_back(chains.param_name(i));
bad_n_eff_names.push_back(param_names[i]);

double split_rhat = chains.split_potential_scale_reduction(i);
auto [rhat_bulk, rhat_tail] = chains.split_rank_normalized_rhat(i);
double split_rhat = rhat_bulk > rhat_tail ? rhat_bulk : rhat_tail;
if (split_rhat > RHAT_MAX)
bad_rhat_names.push_back(chains.param_name(i));
bad_rhat_names.push_back(param_names[i]);
}
}
if (bad_n_eff_names.size() > 0) {
Expand All @@ -187,13 +186,15 @@ int main(int argc, const char *argv[]) {
<< " may be substantially lower than quoted." << std::endl
<< std::endl;
} else {
std::cout << "Effective sample size satisfactory." << std::endl
std::cout << "Rank-normalized split effective sample size satisfactory "
<< "for all parameters." << std::endl
<< std::endl;
}

if (bad_rhat_names.size() > 0) {
has_errors = true;
std::cout << "The following parameters had split R-hat greater than "
std::cout << "The following parameters had rank-normalized split R-hat "
"greater than "
<< RHAT_MAX << ":" << std::endl;
std::cout << " ";
for (size_t n = 0; n < bad_rhat_names.size() - 1; ++n)
Expand All @@ -207,13 +208,14 @@ int main(int argc, const char *argv[]) {
<< " effective parameterization." << std::endl
<< std::endl;
} else {
std::cout << "Split R-hat values satisfactory all parameters." << std::endl
std::cout << "Rank-normalized split R-hat values satisfactory "
<< "for all parameters." << std::endl
<< std::endl;
}
if (!has_errors)
std::cout << "Processing complete, no problems detected." << std::endl;
else
std::cout << "Processing complete." << std::endl;

return 0;
return return_codes::OK;
}
4 changes: 2 additions & 2 deletions src/test/interface/example_output/corr_gauss.nom
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@ No divergent transitions found.
Checking E-BFMI - sampler transitions HMC potential energy.
E-BFMI satisfactory.

Effective sample size satisfactory.
Rank-normalized split effective sample size satisfactory for all parameters.

Split R-hat values satisfactory all parameters.
Rank-normalized split R-hat values satisfactory for all parameters.

Processing complete.
4 changes: 2 additions & 2 deletions src/test/interface/example_output/corr_gauss_depth15.nom
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ No divergent transitions found.
Checking E-BFMI - sampler transitions HMC potential energy.
E-BFMI satisfactory.

Effective sample size satisfactory.
Rank-normalized split effective sample size satisfactory for all parameters.

Split R-hat values satisfactory all parameters.
Rank-normalized split R-hat values satisfactory for all parameters.

Processing complete, no problems detected.
4 changes: 2 additions & 2 deletions src/test/interface/example_output/corr_gauss_depth8.nom
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@ No divergent transitions found.
Checking E-BFMI - sampler transitions HMC potential energy.
E-BFMI satisfactory.

Effective sample size satisfactory.
Rank-normalized split effective sample size satisfactory for all parameters.

Split R-hat values satisfactory all parameters.
Rank-normalized split R-hat values satisfactory for all parameters.

Processing complete.
4 changes: 2 additions & 2 deletions src/test/interface/example_output/eight_schools.nom
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@ Checking E-BFMI - sampler transitions HMC potential energy.
The E-BFMI, 0.26, is below the nominal threshold of 0.30 which suggests that HMC may have trouble exploring the target distribution.
If possible, try to reparameterize the model.

Effective sample size satisfactory.
Rank-normalized split effective sample size satisfactory for all parameters.

Split R-hat values satisfactory all parameters.
Rank-normalized split R-hat values satisfactory for all parameters.

Processing complete.
8 changes: 3 additions & 5 deletions src/test/interface/example_output/mix.nom
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,10 @@ No divergent transitions found.
Checking E-BFMI - sampler transitions HMC potential energy.
E-BFMI satisfactory.

The following parameters had fewer than 0.001 effective draws per transition:
mu[1], mu[2], theta
Such low values indicate that the effective sample size estimators may be biased high and actual performance may be substantially lower than quoted.
Rank-normalized split effective sample size satisfactory for all parameters.

The following parameters had split R-hat greater than 1.05:
mu[1], mu[2], theta
The following parameters had rank-normalized split R-hat greater than 1.01:
mu[1], mu[2], sigma[1], theta
Such high values indicate incomplete mixing and biased estimation.
You should consider regularizating your model with additional prior information or a more effective parameterization.

Expand Down

0 comments on commit c8bcb99

Please sign in to comment.