Skip to content

Commit

Permalink
redo logic for -i flag
Browse files Browse the repository at this point in the history
  • Loading branch information
mitzimorris committed Oct 30, 2024
1 parent b7ac1bc commit 22f8176
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 11 deletions.
35 changes: 24 additions & 11 deletions src/cmdstan/stansummary.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -150,22 +150,35 @@ Example: stansummary model_chain_1.csv model_chain_2.csv

bool reorder_params = true;
if (requested_params_vec.size() > 0) {
reorder_params = false;
std::set<std::string> requested_params(requested_params_vec.begin(),
requested_params_vec.end());
std::set<std::string> check_requested_params(requested_params_vec.begin(),
requested_params_vec.end());
for (size_t i = 0; i < param_names.size(); ++i) {
check_requested_params.erase(param_names[i]);
auto num_requested = requested_params_vec.size();
std::vector<std::string> valid_params;
std::vector<std::string> invalid_params;

std::set<std::string> pnames(param_names.begin(), param_names.end());
for (std::string request : requested_params_vec) {
auto it = pnames.find(request);
if (it == pnames.end()) {
auto find_dups = std::find(invalid_params.begin(), invalid_params.end(), request);
if (find_dups == invalid_params.end()) {
invalid_params.emplace_back(request);
}
} else {
valid_params.emplace_back(request);
auto find_dups = std::find(valid_params.begin(), valid_params.end(), request);
if (find_dups == valid_params.end()) {
valid_params.emplace_back(request);
}
}
}
if (check_requested_params.size() == 0) {
if (invalid_params.empty()) {
reorder_params = false;
param_names.clear();
std::copy(requested_params.begin(), requested_params.end(),
std::copy(valid_params.begin(), valid_params.end(),
std::back_inserter(param_names));
} else {
std::cout << "--include_param: Unrecognized parameter(s): ";
for (auto param : requested_params) {
std::cout << "'" << param << "' ";
for (size_t i = 0; i < invalid_params.size(); ++i) {
std::cout << "'" << invalid_params[i] << "' ";
}
std::cout << std::endl;
return return_codes::NOT_OK;
Expand Down
34 changes: 34 additions & 0 deletions src/test/interface/stansummary_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -589,3 +589,37 @@ TEST(CommandStansummary, check_csv_output_include_param) {
FAIL();
}
}

TEST(CommandStansummary, check_csv_output_include_param_order) {
std::vector<std::string> expect_names = {"y[1,1]", "y[2,2]", "y[2,1]", "y[1,3]"};
std::string path_separator;
path_separator.push_back(get_path_separator());
std::string csv_file = "src" + path_separator + "test" + path_separator
+ "interface" + path_separator + "matrix_output.csv";
std::stringstream ss_command;
ss_command << "bin" << path_separator << "stansummary "
<< " -i y.1.1 -i y.2.2 -i y.2.1 -i y.1.3 "
<< csv_file;
run_command_output out = run_command(ss_command.str());
ASSERT_FALSE(out.hasError);
std::istringstream target_stream(out.output);
std::string line;

std::getline(target_stream, line); // model name
std::getline(target_stream, line); // chain info
std::getline(target_stream, line); // blank
std::getline(target_stream, line); // warmup time
std::getline(target_stream, line); // sample time
std::getline(target_stream, line); // blank
std::getline(target_stream, line); // header
std::getline(target_stream, line); // blank
boost::char_separator<char> sep(" \t\n");
for (size_t i = 0; i < expect_names.size(); ++i) {
std::getline(target_stream, line);
boost::tokenizer<boost::char_separator<char>> line_tok(line, sep);
for (const auto& token : line_tok) {
EXPECT_EQ(token, expect_names[i]);
break;
}
}
}

0 comments on commit 22f8176

Please sign in to comment.