Skip to content

Commit e998a07

Browse files
committed
Refactor YAML conversion logic directly into configuration classes and remove YAMLConverters
- Moved YAML encoding and decoding logic from `YAMLConverters.h` into the respective configuration classes (`ModelSettings`, `PopulationDemographic`, and `TransmissionSettings`). - Removed unnecessary includes and deleted `YAMLConverters.h` since its functionality is now embedded in the configuration classes. - Updated the `Config.cpp` and `Config.h` to remove mutex-related includes and macros (`GETTER` and `SETTER`), simplifying field access. - Refactored test files to reflect the updated structure by removing references to the now-deleted `YAMLConverters.h`. - Introduced direct YAML serialization and deserialization for `date::year_month_day` inside the `ModelSettings` and `PopulationDemographic` classes.
1 parent 762c8ce commit e998a07

9 files changed

+191
-208
lines changed

src/Configuration/Config.cpp

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,6 @@
33

44
#include <yaml-cpp/yaml.h>
55

6-
#include <mutex>
7-
8-
#include "YAMLConverters.h"
9-
106
void Config::load(const std::string &filename) {
117
config_file_path_ = filename;
128
YAML::Node config = YAML::LoadFile(filename);

src/Configuration/Config.h

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -9,17 +9,6 @@
99

1010
#include "ConfigData.h"
1111

12-
#define GETTER(Type, StructName, FieldName, Validator) \
13-
[[nodiscard]] Type get_##FieldName() const { \
14-
return get_field(config_data_.StructName.FieldName); \
15-
}
16-
17-
#define SETTER(Type, StructName, FieldName, Validator) \
18-
void set_##FieldName(Type FieldName) { \
19-
Validator(FieldName); \
20-
set_field(config_data_.StructName.FieldName, FieldName); \
21-
}
22-
2312
class Config {
2413
public:
2514
// Constructor and Destructor

src/Configuration/ModelSettings.h

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
#define MODEL_SETTINGS_H
33

44
#include <date/date.h>
5+
#include <yaml-cpp/yaml.h>
56

67
#include <stdexcept>
78

@@ -72,5 +73,86 @@ class ModelSettings {
7273
date::year_month_day ending_date_;
7374
int start_collect_data_day_;
7475
};
76+
77+
namespace YAML {
78+
template <>
79+
struct convert<date::year_month_day> {
80+
static Node encode(const date::year_month_day &rhs) {
81+
std::stringstream ss;
82+
ss << rhs;
83+
return Node(ss.str());
84+
}
85+
86+
static bool decode(const Node &node, date::year_month_day &rhs) {
87+
if (!node.IsScalar()) {
88+
throw std::runtime_error("Invalid date format: not a scalar.");
89+
}
90+
91+
std::stringstream ss(node.as<std::string>());
92+
date::year_month_day ymd{};
93+
ss >> date::parse("%F", ymd); // %F matches YYYY-MM-DD format
94+
95+
if (ss.fail()) {
96+
throw std::runtime_error("Invalid date format: failed to parse.");
97+
}
98+
99+
rhs = ymd;
100+
return true;
101+
}
102+
};
103+
104+
template <>
105+
struct convert<ModelSettings> {
106+
static Node encode(const ModelSettings &rhs) {
107+
Node node;
108+
node["days_between_stdout_output"] = rhs.get_days_between_stdout_output();
109+
node["initial_seed_number"] = rhs.get_initial_seed_number();
110+
node["record_genome_db"] = rhs.get_record_genome_db();
111+
node["starting_date"] = rhs.get_starting_date();
112+
node["start_of_comparison_period"] = rhs.get_start_of_comparison_period();
113+
node["ending_date"] = rhs.get_ending_date();
114+
node["start_collect_data_day"] = rhs.get_start_collect_data_day();
115+
return node;
116+
}
117+
118+
static bool decode(const Node &node, ModelSettings &rhs) {
119+
if (!node["days_between_stdout_output"]) {
120+
throw std::runtime_error("Missing 'days_between_stdout_output' field.");
121+
}
122+
if (!node["initial_seed_number"]) {
123+
throw std::runtime_error("Missing 'initial_seed_number' field.");
124+
}
125+
if (!node["record_genome_db"]) {
126+
throw std::runtime_error("Missing 'record_genome_db' field.");
127+
}
128+
if (!node["starting_date"]) {
129+
throw std::runtime_error("Missing 'starting_date' field.");
130+
}
131+
if (!node["start_of_comparison_period"]) {
132+
throw std::runtime_error("Missing 'start_of_comparison_period' field.");
133+
}
134+
if (!node["ending_date"]) {
135+
throw std::runtime_error("Missing 'ending_date' field.");
136+
}
137+
if (!node["start_collect_data_day"]) {
138+
throw std::runtime_error("Missing 'start_collect_data_day' field.");
139+
}
140+
141+
// TODO: Add more error checking for each field
142+
143+
rhs.set_days_between_stdout_output(
144+
node["days_between_stdout_output"].as<int>());
145+
rhs.set_initial_seed_number(node["initial_seed_number"].as<int>());
146+
rhs.set_record_genome_db(node["record_genome_db"].as<bool>());
147+
rhs.set_starting_date(node["starting_date"].as<date::year_month_day>());
148+
rhs.set_start_of_comparison_period(
149+
node["start_of_comparison_period"].as<date::year_month_day>());
150+
rhs.set_ending_date(node["ending_date"].as<date::year_month_day>());
151+
rhs.set_start_collect_data_day(node["start_collect_data_day"].as<int>());
152+
return true;
153+
}
154+
};
155+
} // namespace YAML
156+
75157
#endif // MODEL_SETTINGS_H
76158

src/Configuration/PopulationDemographic.h

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
#ifndef POPULATION_DEMOGRAPHIC_H
22
#define POPULATION_DEMOGRAPHIC_H
33

4+
#include <yaml-cpp/yaml.h>
5+
46
#include <stdexcept>
57
#include <vector>
68

@@ -91,5 +93,78 @@ class PopulationDemographic {
9193
double artificial_rescaling_of_population_size_;
9294
};
9395

96+
namespace YAML {
97+
template <>
98+
struct convert<PopulationDemographic> {
99+
static Node encode(const PopulationDemographic &rhs) {
100+
Node node;
101+
node["number_of_age_classes"] = rhs.get_number_of_age_classes();
102+
node["age_structure"] = rhs.get_age_structure();
103+
node["initial_age_structure"] = rhs.get_initial_age_structure();
104+
node["birth_rate"] = rhs.get_birth_rate();
105+
node["death_rate_by_age_class"] = rhs.get_death_rate_by_age_class();
106+
node["mortality_when_treatment_fail_by_age_class"] =
107+
rhs.get_mortality_when_treatment_fail_by_age_class();
108+
node["artificial_rescaling_of_population_size"] =
109+
rhs.get_artificial_rescaling_of_population_size();
110+
return node;
111+
}
112+
113+
static bool decode(const Node &node, PopulationDemographic &rhs) {
114+
if (!node["number_of_age_classes"]) {
115+
throw std::runtime_error("Missing 'number_of_age_classes' field.");
116+
}
117+
if (!node["age_structure"]) {
118+
throw std::runtime_error("Missing 'age_structure' field.");
119+
}
120+
if (!node["initial_age_structure"]) {
121+
throw std::runtime_error("Missing 'initial_age_structure' field.");
122+
}
123+
if (!node["birth_rate"]) {
124+
throw std::runtime_error("Missing 'birth_rate' field.");
125+
}
126+
if (!node["death_rate_by_age_class"]) {
127+
throw std::runtime_error("Missing 'death_rate_by_age_class' field.");
128+
}
129+
if (!node["mortality_when_treatment_fail_by_age_class"]) {
130+
throw std::runtime_error(
131+
"Missing 'mortality_when_treatment_fail_by_age_class' field.");
132+
}
133+
if (!node["artificial_rescaling_of_population_size"]) {
134+
throw std::runtime_error(
135+
"Missing 'artificial_rescaling_of_population_size' field.");
136+
}
137+
138+
int number_of_age_classes = node["number_of_age_classes"].as<int>();
139+
rhs.set_number_of_age_classes(number_of_age_classes);
140+
141+
// Validate and assign age structure vectors
142+
auto age_structure = node["age_structure"].as<std::vector<int>>();
143+
rhs.set_age_structure(age_structure);
144+
145+
auto initial_age_structure =
146+
node["initial_age_structure"].as<std::vector<int>>();
147+
rhs.set_initial_age_structure(initial_age_structure);
148+
149+
rhs.set_birth_rate(node["birth_rate"].as<double>());
150+
151+
auto death_rate_by_age_class =
152+
node["death_rate_by_age_class"].as<std::vector<double>>();
153+
rhs.set_death_rate_by_age_class(death_rate_by_age_class);
154+
155+
auto mortality_when_treatment_fail_by_age_class =
156+
node["mortality_when_treatment_fail_by_age_class"]
157+
.as<std::vector<double>>();
158+
rhs.set_mortality_when_treatment_fail_by_age_class(
159+
mortality_when_treatment_fail_by_age_class);
160+
161+
rhs.set_artificial_rescaling_of_population_size(
162+
node["artificial_rescaling_of_population_size"].as<double>());
163+
164+
return true;
165+
}
166+
};
167+
} // namespace YAML
168+
94169
#endif // POPULATION_DEMOGRAPHIC_H
95170

src/Configuration/TransmissionSettings.h

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
#ifndef TRANSMISSION_SETTINGS_H
22
#define TRANSMISSION_SETTINGS_H
33

4+
#include <yaml-cpp/yaml.h>
5+
46
#include <stdexcept>
57

68
class TransmissionSettings {
@@ -35,5 +37,36 @@ class TransmissionSettings {
3537
double transmission_parameter_;
3638
double p_infection_from_an_infectious_bite_;
3739
};
40+
41+
namespace YAML {
42+
43+
template <>
44+
struct convert<TransmissionSettings> {
45+
static Node encode(const TransmissionSettings &rhs) {
46+
Node node;
47+
node["transmission_parameter"] = rhs.get_transmission_parameter();
48+
node["p_infection_from_an_infectious_bite"] =
49+
rhs.get_p_infection_from_an_infectious_bite();
50+
return node;
51+
}
52+
53+
static bool decode(const Node &node, TransmissionSettings &rhs) {
54+
if (!node["transmission_parameter"]) {
55+
throw std::runtime_error("Missing 'transmission_parameter' field.");
56+
}
57+
if (!node["p_infection_from_an_infectious_bite"]) {
58+
throw std::runtime_error(
59+
"Missing 'p_infection_from_an_infectious_bite' field.");
60+
}
61+
62+
rhs.set_transmission_parameter(node["transmission_parameter"].as<double>());
63+
rhs.set_p_infection_from_an_infectious_bite(
64+
node["p_infection_from_an_infectious_bite"].as<double>());
65+
return true;
66+
}
67+
};
68+
69+
} // namespace YAML
70+
3871
#endif // TRANSMISSION_SETTINGS_H
3972

0 commit comments

Comments
 (0)