Skip to content

Commit

Permalink
Merge pull request #14 from bpuchala/2.X_sampling_fixture_params_json_io
Browse files Browse the repository at this point in the history
2.x sampling fixture params json io
  • Loading branch information
bpuchala authored Jul 16, 2024
2 parents a130350 + 442354e commit dec4619
Show file tree
Hide file tree
Showing 17 changed files with 360 additions and 94 deletions.
19 changes: 18 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,24 @@ All notable changes to `libcasm-monte` will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [v2.0a1] - 2024-03-15

## [2.0a2] - 2024-07-16

### Added

- Added to_json for CompletionCheckParams, SamplingFixtureParams, SamplingParams, jsonResultsIO
- Added "json_quantities" option to SamplingParams
- Added Conversions::species_list()

### Changed

- Use shared_ptr to hold sampling fixtures in RunManager
- Output scalar quantities under "value" key in JSON results output
- Allow MethodLog to output to stdout
- Allow constructing libcasm.monte.ValueMap from dict


## [2.0a1] - 2024-03-15

The libcasm-monte package provides useful building blocks for Monte Carlo simulations. This includes:

Expand Down
9 changes: 5 additions & 4 deletions include/casm/monte/Conversions.hh
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ class Conversions {

Index species_size() const;
Index species_index(std::string species_name) const;
std::vector<xtal::Molecule> const &species_list() const;
xtal::Molecule const &species_to_mol(Index species_index) const;
std::string const &species_name(Index species_index) const;
Index components_size(Index species_index) const;
Expand All @@ -122,14 +123,14 @@ class Conversions {

Index m_Nasym;
std::vector<Index> m_unitl_to_asym;
std::vector<std::set<Index> > m_asym_to_unitl;
std::vector<std::set<Index> > m_asym_to_b;
std::vector<std::set<Index>> m_asym_to_unitl;
std::vector<std::set<Index>> m_asym_to_b;

/// m_occ_to_species[asym][occ_index] -> species_index
std::vector<std::vector<Index> > m_occ_to_species;
std::vector<std::vector<Index>> m_occ_to_species;

/// m_species_to_occ[asym][species_index] -> occ_index
std::vector<std::vector<Index> > m_species_to_occ;
std::vector<std::vector<Index>> m_species_to_occ;
};

} // namespace monte
Expand Down
5 changes: 5 additions & 0 deletions include/casm/monte/MethodLog.hh
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,11 @@ struct MethodLog {
log.reset(*fout);
}
}

void reset_to_stdout() {
fout.reset();
log.reset();
}
};

} // namespace monte
Expand Down
69 changes: 60 additions & 9 deletions include/casm/monte/checks/io/json/CompletionCheck_json_io.hh
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,13 @@ template <typename StatisticsType>
struct CompletionCheckParams;

/// \brief Construct CompletionCheckParams<BasicStatistics> from JSON
inline void parse(InputParser<CompletionCheckParams<BasicStatistics>> &parser,
StateSamplingFunctionMap const &sampling_functions);
void parse(InputParser<CompletionCheckParams<BasicStatistics>> &parser,
StateSamplingFunctionMap const &sampling_functions);

/// \brief Convert CompletionCheckParams<BasicStatistics> to JSON
jsonParser &to_json(
CompletionCheckParams<BasicStatistics> const &completion_check_params,
jsonParser &json);

/// \brief CompletionCheckResults to JSON
template <typename StatisticsType>
Expand Down Expand Up @@ -116,17 +121,15 @@ void _parse_components(
(parser.self.find_at(option / "component_index") != parser.self.end());
bool has_name =
(parser.self.find_at(option / "component_name") != parser.self.end());
if (has_index && has_name) {
parser.insert_error(option,
"Error: cannot specify both \"component_index\" and "
"\"component_name\"");
} else if (has_index) {
if (has_index) {
_parse_component_index(parser, option, function, precision,
requested_precision);
} else if (has_name) {
}
if (has_name) {
_parse_component_name(parser, option, function, precision,
requested_precision);
} else {
}
if (!has_index && !has_name) {
// else, converge all components
for (Index index = 0; index < function.component_names.size(); ++index) {
requested_precision.emplace(
Expand Down Expand Up @@ -365,6 +368,54 @@ inline void parse(InputParser<CompletionCheckParams<BasicStatistics>> &parser,
}
}

/// \brief Convert CompletionCheckParams<BasicStatistics> to JSON
inline jsonParser &to_json(
CompletionCheckParams<BasicStatistics> const &completion_check_params,
jsonParser &json) {
// TODO: write out calc_statistics_f parameters
// to_json["calc_statistics_f_confidence"] = 0.95;
// to_json["calc_statistics_f_weighted_observations_method"] = 1;
// to_json["calc_statistics_f_n_resamples"] = 10000;

json["cutoff"] = completion_check_params.cutoff_params;

json["convergence"] = jsonParser::array();
for (auto const &pair : completion_check_params.requested_precision) {
auto const &key = pair.first;
auto const &req_prec = pair.second;

jsonParser tmp;
tmp["quantity"] = key.sampler_name;

std::vector<int> component_index;
component_index.push_back(key.component_index);
tmp["component_index"] = component_index;

std::vector<std::string> component_name;
component_name.push_back(key.component_name);
tmp["component_name"] = component_name;

to_json(req_prec, tmp);

json["convergence"].push_back(tmp);
}

// "spacing"
if (completion_check_params.log_spacing == false) {
json["spacing"] = "linear";
json["begin"] = completion_check_params.check_begin;
json["period"] = completion_check_params.check_period;
} else {
json["spacing"] = "log";
json["begin"] = completion_check_params.check_begin;
json["base"] = completion_check_params.check_base;
json["shift"] = completion_check_params.check_shift;
json["period_max"] = completion_check_params.check_period_max;
}

return json;
}

/// \brief CompletionCheckResults to JSON
template <typename StatisticsType>
jsonParser &to_json(CompletionCheckResults<StatisticsType> const &value,
Expand Down
6 changes: 3 additions & 3 deletions include/casm/monte/methods/kinetic_monte_carlo.hh
Original file line number Diff line number Diff line change
Expand Up @@ -137,9 +137,9 @@ void kinetic_monte_carlo(
kmc_data.time = 0.0;
kmc_data.atom_positions_cart = occ_location.atom_positions_cart();
kmc_data.prev_atom_positions_cart.clear();
for (auto &fixture : run_manager.sampling_fixtures) {
kmc_data.prev_time.emplace(fixture.label(), kmc_data.time);
kmc_data.prev_atom_positions_cart.emplace(fixture.label(),
for (auto &fixture_ptr : run_manager.sampling_fixtures) {
kmc_data.prev_time.emplace(fixture_ptr->label(), kmc_data.time);
kmc_data.prev_atom_positions_cart.emplace(fixture_ptr->label(),
kmc_data.atom_positions_cart);
}

Expand Down
13 changes: 10 additions & 3 deletions include/casm/monte/run_management/ResultsAnalysisFunction.hh
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,8 @@ template <typename ConfigType, typename StatisticsType>
std::map<std::string, Eigen::VectorXd> make_analysis(
Results<ConfigType, StatisticsType> const &results,
ResultsAnalysisFunctionMap<ConfigType, StatisticsType> const
&analysis_functions);
&analysis_functions,
std::vector<std::string> analysis_names);

// --- Implementation ---

Expand Down Expand Up @@ -110,9 +111,15 @@ template <typename ConfigType, typename StatisticsType>
std::map<std::string, Eigen::VectorXd> make_analysis(
Results<ConfigType, StatisticsType> const &results,
ResultsAnalysisFunctionMap<ConfigType, StatisticsType> const
&analysis_functions) {
&analysis_functions,
std::vector<std::string> analysis_names) {
std::map<std::string, Eigen::VectorXd> analysis;
for (auto const &pair : analysis_functions) {
for (std::string name : analysis_names) {
auto it = analysis_functions.find(name);
if (it == analysis_functions.end()) {
continue;
}
auto const &pair = *it;
auto const &f = pair.second;
try {
analysis.emplace(f.name, f(results));
Expand Down
43 changes: 23 additions & 20 deletions include/casm/monte/run_management/RunManager.hh
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ struct RunManager {
std::shared_ptr<engine_type> engine;

/// Sampling fixtures
std::vector<sampling_fixture_type> sampling_fixtures;
std::vector<std::shared_ptr<sampling_fixture_type>> sampling_fixtures;

/// \brief If true, the run is complete if any sampling fixture
/// is complete. Otherwise, all sampling fixtures must be
Expand Down Expand Up @@ -83,13 +83,14 @@ struct RunManager {
next_sample_time(0.0),
break_point_set(false) {
for (auto const &params : _sampling_fixture_params) {
sampling_fixtures.emplace_back(params, engine);
sampling_fixtures.emplace_back(
std::make_shared<sampling_fixture_type>(params, engine));
}
}

void initialize(Index steps_per_pass) {
for (auto &fixture : sampling_fixtures) {
fixture.initialize(steps_per_pass);
for (auto &fixture_ptr : sampling_fixtures) {
fixture_ptr->initialize(steps_per_pass);
}
break_point_set = false;
}
Expand All @@ -102,8 +103,8 @@ struct RunManager {
// check results
bool all_complete = true;
bool any_complete = false;
for (auto &fixture : sampling_fixtures) {
if (fixture.is_complete()) {
for (auto &fixture_ptr : sampling_fixtures) {
if (fixture_ptr->is_complete()) {
any_complete = true;
} else {
all_complete = false;
Expand All @@ -116,32 +117,32 @@ struct RunManager {
}

void write_status_if_due() {
for (auto &fixture : sampling_fixtures) {
fixture.write_status_if_due(run_index);
for (auto &fixture_ptr : sampling_fixtures) {
fixture_ptr->write_status_if_due(run_index);
}
}

void increment_n_accept() {
for (auto &fixture : sampling_fixtures) {
fixture.increment_n_accept();
for (auto &fixture_ptr : sampling_fixtures) {
fixture_ptr->increment_n_accept();
}
}

void increment_n_reject() {
for (auto &fixture : sampling_fixtures) {
fixture.increment_n_reject();
for (auto &fixture_ptr : sampling_fixtures) {
fixture_ptr->increment_n_reject();
}
}

void increment_step() {
for (auto &fixture : sampling_fixtures) {
fixture.increment_step();
for (auto &fixture_ptr : sampling_fixtures) {
fixture_ptr->increment_step();
}
}

void set_time(double event_time) {
for (auto &fixture : sampling_fixtures) {
fixture.set_time(event_time);
for (auto &fixture_ptr : sampling_fixtures) {
fixture_ptr->set_time(event_time);
}
}

Expand All @@ -151,7 +152,8 @@ struct RunManager {
state_type const &state,
PreSampleActionType pre_sample_f = PreSampleActionType(),
PostSampleActionType post_sample_f = PostSampleActionType()) {
for (auto &fixture : sampling_fixtures) {
for (auto &fixture_ptr : sampling_fixtures) {
auto &fixture = *fixture_ptr;
if (fixture.params().sampling_params.sample_mode !=
SAMPLE_MODE::BY_TIME) {
if (fixture.counter().count == fixture.next_sample_count()) {
Expand Down Expand Up @@ -193,7 +195,8 @@ struct RunManager {
void update_next_sampling_fixture() {
// update next_sample_time and next_sampling_fixture
next_sampling_fixture = nullptr;
for (auto &fixture : sampling_fixtures) {
for (auto &fixture_ptr : sampling_fixtures) {
auto &fixture = *fixture_ptr;
if (fixture.params().sampling_params.sample_mode ==
SAMPLE_MODE::BY_TIME) {
if (next_sampling_fixture == nullptr ||
Expand All @@ -210,8 +213,8 @@ struct RunManager {
/// Notes:
/// - Calls `finalize` for all sampling fixtures
void finalize(state_type const &final_state) {
for (auto &fixture : sampling_fixtures) {
fixture.finalize(final_state, run_index);
for (auto &fixture_ptr : sampling_fixtures) {
fixture_ptr->finalize(final_state, run_index);
}
}
};
Expand Down
8 changes: 7 additions & 1 deletion include/casm/monte/run_management/SamplingFixture.hh
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ struct SamplingFixtureParams {
_analysis_functions,
monte::SamplingParams _sampling_params,
monte::CompletionCheckParams<StatisticsType> _completion_check_params,
std::vector<std::string> _analysis_names,
std::unique_ptr<results_io_type> _results_io = nullptr,
monte::MethodLog _method_log = monte::MethodLog())
: label(_label),
Expand All @@ -44,6 +45,7 @@ struct SamplingFixtureParams {
analysis_functions(_analysis_functions),
sampling_params(_sampling_params),
completion_check_params(_completion_check_params),
analysis_names(_analysis_names),
results_io(std::move(_results_io)),
method_log(_method_log) {
for (auto const &name : sampling_params.sampler_names) {
Expand Down Expand Up @@ -84,6 +86,9 @@ struct SamplingFixtureParams {
/// Completion check params
monte::CompletionCheckParams<StatisticsType> completion_check_params;

/// Analysis functions to evaluate
std::vector<std::string> analysis_names;

/// Results I/O implementation -- May be empty
notstd::cloneable_ptr<results_io_type> results_io;

Expand Down Expand Up @@ -408,7 +413,8 @@ class SamplingFixture {
Log &log = m_params.method_log.log;
m_results.elapsed_clocktime = log.time_s();
m_results.completion_check_results = m_completion_check.results();
m_results.analysis = make_analysis(m_results, m_params.analysis_functions);
m_results.analysis = make_analysis(m_results, m_params.analysis_functions,
m_params.analysis_names);
m_results.n_accept = m_counter.n_accept;
m_results.n_reject = m_counter.n_reject;

Expand Down
3 changes: 3 additions & 0 deletions include/casm/monte/run_management/io/ResultsIO.hh
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

#include <vector>

#include "casm/casm_io/json/jsonParser.hh"
#include "casm/global/definitions.hh"
#include "casm/misc/cloneable_ptr.hh"

Expand All @@ -19,6 +20,8 @@ class ResultsIO : public notstd::Cloneable {

virtual void write(results_type const &results, ValueMap const &conditions,
Index run_index) = 0;

virtual jsonParser to_json() = 0;
};

} // namespace monte
Expand Down
Loading

0 comments on commit dec4619

Please sign in to comment.