diff --git a/src/stan/callbacks/dispatcher.hpp b/src/stan/callbacks/dispatcher.hpp index 3280c105e5..4593a8d51e 100644 --- a/src/stan/callbacks/dispatcher.hpp +++ b/src/stan/callbacks/dispatcher.hpp @@ -12,159 +12,158 @@ #include #include - namespace stan { namespace callbacks { - enum class InfoType { - CONFIG, // series of string messages - SAMPLE, // draw from posterior - METRIC, // struct with kv pairs 'metric_type', 'stepsize', 'inv_metric' - ALGORITHM_STATE, // sampler state for returned draw - }; - - struct InfoTypeHash { - std::size_t operator()(const InfoType& type) const { - return std::hash()(static_cast(type)); - } - }; - - // Base type for type erasure. - class Channel { - public: - virtual ~Channel() = default; - }; - - // Adapter for plain writers. - class WriterChannel : public Channel { - public: - explicit WriterChannel(stan::callbacks::writer* w) : writer_(w) { - if (!w) { - throw std::runtime_error("config error, null writer"); - } - } - - // Handle all types that writer supports via operator() - void dispatch() { (*writer_)(); } - void dispatch(const std::string& value) { (*writer_)(value); } - void dispatch(const std::vector& value) {(*writer_)(value); } - void dispatch(const std::vector& value) { - (*writer_)(value); - } - - // Handle any Eigen Matrix type - template - void dispatch(const Eigen::Matrix& value) { - (*writer_)(value); - } - - // No key-value support for plain writers - template - void dispatch(const std::string&, const T&) {} - - private: - stan::callbacks::writer* writer_; - }; - - // Adapter for structured writers. - class StructuredWriterChannel : public Channel { - public: - explicit StructuredWriterChannel(stan::callbacks::structured_writer* sw) +enum class InfoType { + CONFIG, // series of string messages + SAMPLE, // draw from posterior + METRIC, // struct with kv pairs 'metric_type', 'stepsize', 'inv_metric' + ALGORITHM_STATE, // sampler state for returned draw +}; + +struct InfoTypeHash { + std::size_t operator()(const InfoType& type) const { + return std::hash()(static_cast(type)); + } +}; + +// Base type for type erasure. +class Channel { + public: + virtual ~Channel() = default; +}; + +// Adapter for plain writers. +class WriterChannel : public Channel { + public: + explicit WriterChannel(stan::callbacks::writer* w) : writer_(w) { + if (!w) { + throw std::runtime_error("config error, null writer"); + } + } + + // Handle all types that writer supports via operator() + void dispatch() { (*writer_)(); } + void dispatch(const std::string& value) { (*writer_)(value); } + void dispatch(const std::vector& value) { (*writer_)(value); } + void dispatch(const std::vector& value) { (*writer_)(value); } + + // Handle any Eigen Matrix type + template + void dispatch(const Eigen::Matrix& value) { + (*writer_)(value); + } + + // No key-value support for plain writers + template + void dispatch(const std::string&, const T&) {} + + private: + stan::callbacks::writer* writer_; +}; + +// Adapter for structured writers. +class StructuredWriterChannel : public Channel { + public: + explicit StructuredWriterChannel(stan::callbacks::structured_writer* sw) : writer_(sw) { - if (!sw) throw std::runtime_error("config error, null writer"); - } - // Forward all key-value calls directly to the writer - void dispatch(const std::string& key) { writer_->write(key); } - // Perfect forwarding for any key-value pair - template - void dispatch(const std::string& key, T&& value) { - writer_->write(key, std::forward(value)); - } - void begin_record() { writer_->begin_record(); } - void begin_record(const std::string& key) { writer_->begin_record(key); } - void end_record() { writer_->end_record(); } - - private: - stan::callbacks::structured_writer* writer_; - }; - - // dispatcher class - class dispatcher { - public: - dispatcher() = default; - ~dispatcher() = default; - - void register_channel(InfoType type, std::unique_ptr channel) { - channels_[type] = std::move(channel); - } - - // Empty call - void dispatch(InfoType type) { - if (auto* wc = find_channel(type)) - wc->dispatch(); - } - - // String, vector, vector - template , std::string> || - std::is_same_v, std::vector> || - std::is_same_v, std::vector> - >> - void dispatch(InfoType type, T&& value) { - if (auto* wc = find_channel(type)) - wc->dispatch(std::forward(value)); - } - - // Eigen matrix types - template - void dispatch(InfoType type, const Eigen::Matrix& value) { - if (auto* wc = find_channel(type)) - wc->dispatch(value); - } - - // Key with no value (null) - void dispatch(InfoType type, const std::string& key) { - if (auto* sw = find_channel(type)) - sw->dispatch(key); - } - - // Key-value pairs (forward to structured writers) - template - void dispatch(InfoType type, const std::string& key, T&& value) { - if (auto* sw = find_channel(type)) - sw->dispatch(key, std::forward(value)); - } - - // Record operations - void begin_record(InfoType type) { - if (auto* sw = find_channel(type)) - sw->begin_record(); - } - - void begin_record(InfoType type, const std::string& key) { - if (auto* sw = find_channel(type)) - sw->begin_record(key); - } - - void end_record(InfoType type) { - if (auto* sw = find_channel(type)) - sw->end_record(); - } - - private: - // Helper to find and cast a channel of specific type - template - ChannelType* find_channel(InfoType type) { - auto it = channels_.find(type); - if (it == channels_.end()) return nullptr; - return dynamic_cast(it->second.get()); - } - - std::unordered_map, - InfoTypeHash> channels_; - }; + if (!sw) + throw std::runtime_error("config error, null writer"); + } + // Forward all key-value calls directly to the writer + void dispatch(const std::string& key) { writer_->write(key); } + // Perfect forwarding for any key-value pair + template + void dispatch(const std::string& key, T&& value) { + writer_->write(key, std::forward(value)); + } + void begin_record() { writer_->begin_record(); } + void begin_record(const std::string& key) { writer_->begin_record(key); } + void end_record() { writer_->end_record(); } + + private: + stan::callbacks::structured_writer* writer_; +}; + +// dispatcher class +class dispatcher { + public: + dispatcher() = default; + ~dispatcher() = default; + + void register_channel(InfoType type, std::unique_ptr channel) { + channels_[type] = std::move(channel); + } + + // Empty call + void dispatch(InfoType type) { + if (auto* wc = find_channel(type)) + wc->dispatch(); + } + + // String, vector, vector + template < + typename T, + typename = std::enable_if_t< + std::is_same_v< + std::decay_t, + std:: + string> || std::is_same_v, std::vector> || std::is_same_v, std::vector>>> + void dispatch(InfoType type, T&& value) { + if (auto* wc = find_channel(type)) + wc->dispatch(std::forward(value)); + } + + // Eigen matrix types + template + void dispatch(InfoType type, const Eigen::Matrix& value) { + if (auto* wc = find_channel(type)) + wc->dispatch(value); + } + + // Key with no value (null) + void dispatch(InfoType type, const std::string& key) { + if (auto* sw = find_channel(type)) + sw->dispatch(key); + } + + // Key-value pairs (forward to structured writers) + template + void dispatch(InfoType type, const std::string& key, T&& value) { + if (auto* sw = find_channel(type)) + sw->dispatch(key, std::forward(value)); + } + + // Record operations + void begin_record(InfoType type) { + if (auto* sw = find_channel(type)) + sw->begin_record(); + } + + void begin_record(InfoType type, const std::string& key) { + if (auto* sw = find_channel(type)) + sw->begin_record(key); + } + + void end_record(InfoType type) { + if (auto* sw = find_channel(type)) + sw->end_record(); + } + + private: + // Helper to find and cast a channel of specific type + template + ChannelType* find_channel(InfoType type) { + auto it = channels_.find(type); + if (it == channels_.end()) + return nullptr; + return dynamic_cast(it->second.get()); + } + + std::unordered_map, InfoTypeHash> + channels_; +}; } // namespace callbacks } // namespace stan diff --git a/src/test/unit/callbacks/dispatcher_test.cpp b/src/test/unit/callbacks/dispatcher_test.cpp index 39cdb32025..37c2cf635c 100644 --- a/src/test/unit/callbacks/dispatcher_test.cpp +++ b/src/test/unit/callbacks/dispatcher_test.cpp @@ -39,17 +39,20 @@ class DispatcherTest : public ::testing::Test { ss_metric.str(std::string()); ss_metric.clear(); - dispatcher.register_channel(InfoType::CONFIG, - std::unique_ptr( - new stan::callbacks::WriterChannel(&writer_config))); + dispatcher.register_channel( + InfoType::CONFIG, + std::unique_ptr( + new stan::callbacks::WriterChannel(&writer_config))); - dispatcher.register_channel(InfoType::SAMPLE, - std::unique_ptr( - new stan::callbacks::WriterChannel(&writer_sample))); + dispatcher.register_channel( + InfoType::SAMPLE, + std::unique_ptr( + new stan::callbacks::WriterChannel(&writer_sample))); - dispatcher.register_channel(InfoType::METRIC, - std::unique_ptr( - new stan::callbacks::StructuredWriterChannel(&writer_metric))); + dispatcher.register_channel( + InfoType::METRIC, + std::unique_ptr( + new stan::callbacks::StructuredWriterChannel(&writer_metric))); } void TearDown() {} @@ -148,171 +151,174 @@ TEST_F(DispatcherTest, StructuredBeginEndRecord) { EXPECT_NE(output.find("{"), std::string::npos); EXPECT_NE(output.find("}"), std::string::npos); ======= -TEST_F(DispatcherTest, MetricStructuredKeyValueRecord) { - // For METRIC (structured writer), open a record, dispatch key/value pairs, - // then close the record. - dispatcher.begin_record(InfoType::METRIC); - dispatcher.dispatch(InfoType::METRIC, "metric_type", std::string("diag")); - dispatcher.dispatch(InfoType::METRIC, "stepsize", 0.6789); - // For the inv_metric, assume the caller converts the vector to a - // comma-separated string. - std::vector inv_metric = {0.1, 0.2, 0.3}; - std::string inv_metric_str; - for (size_t i = 0; i < inv_metric.size(); ++i) { - inv_metric_str += std::to_string(inv_metric[i]); - if (i != inv_metric.size() - 1) - inv_metric_str += ","; - } - dispatcher.dispatch(InfoType::METRIC, "inv_metric", inv_metric_str); - dispatcher.end_record(InfoType::METRIC); - // Expected output: - // Begin record marker, followed by key/value pairs each formatted as - // "key:value;" and then end record marker. - std::cout << ss_metric.str() << std::endl; + TEST_F(DispatcherTest, MetricStructuredKeyValueRecord) { + // For METRIC (structured writer), open a record, dispatch key/value pairs, + // then close the record. + dispatcher.begin_record(InfoType::METRIC); + dispatcher.dispatch(InfoType::METRIC, "metric_type", std::string("diag")); + dispatcher.dispatch(InfoType::METRIC, "stepsize", 0.6789); + // For the inv_metric, assume the caller converts the vector to a + // comma-separated string. + std::vector inv_metric = {0.1, 0.2, 0.3}; + std::string inv_metric_str; + for (size_t i = 0; i < inv_metric.size(); ++i) { + inv_metric_str += std::to_string(inv_metric[i]); + if (i != inv_metric.size() - 1) + inv_metric_str += ","; + } + dispatcher.dispatch(InfoType::METRIC, "inv_metric", inv_metric_str); + dispatcher.end_record(InfoType::METRIC); + // Expected output: + // Begin record marker, followed by key/value pairs each formatted as + // "key:value;" and then end record marker. + std::cout << ss_metric.str() << std::endl; >>>>>>> 89d756b23a601c560cb63a09930f5fcbce011efc -} + } -// Test structured writer key-value pairs with string value -TEST_F(DispatcherTest, StructuredKeyStringValue) { - dispatcher.begin_record(InfoType::METRIC); - dispatcher.dispatch(InfoType::METRIC, "key1", std::string("value1")); - dispatcher.end_record(InfoType::METRIC); - - std::string output = ss_metric.str(); - EXPECT_NE(output.find("key1"), std::string::npos); - EXPECT_NE(output.find("value1"), std::string::npos); -} + // Test structured writer key-value pairs with string value + TEST_F(DispatcherTest, StructuredKeyStringValue) { + dispatcher.begin_record(InfoType::METRIC); + dispatcher.dispatch(InfoType::METRIC, "key1", std::string("value1")); + dispatcher.end_record(InfoType::METRIC); + + std::string output = ss_metric.str(); + EXPECT_NE(output.find("key1"), std::string::npos); + EXPECT_NE(output.find("value1"), std::string::npos); + } <<<<<<< HEAD -// Test structured writer with multiple key-value types -TEST_F(DispatcherTest, StructuredMultipleValueTypes) { - dispatcher.begin_record(InfoType::METRIC); - dispatcher.dispatch(InfoType::METRIC, "string_key", std::string("string_value")); - dispatcher.dispatch(InfoType::METRIC, "int_key", 42); - dispatcher.dispatch(InfoType::METRIC, "double_key", 3.14159); - dispatcher.dispatch(InfoType::METRIC, "bool_key", true); - dispatcher.end_record(InfoType::METRIC); - - std::string output = ss_metric.str(); - EXPECT_NE(output.find("string_key"), std::string::npos); - EXPECT_NE(output.find("string_value"), std::string::npos); - EXPECT_NE(output.find("int_key"), std::string::npos); - EXPECT_NE(output.find("42"), std::string::npos); - EXPECT_NE(output.find("double_key"), std::string::npos); - EXPECT_NE(output.find("3.14159"), std::string::npos); - EXPECT_NE(output.find("bool_key"), std::string::npos); - EXPECT_NE(output.find("true"), std::string::npos); -} + // Test structured writer with multiple key-value types + TEST_F(DispatcherTest, StructuredMultipleValueTypes) { + dispatcher.begin_record(InfoType::METRIC); + dispatcher.dispatch(InfoType::METRIC, "string_key", + std::string("string_value")); + dispatcher.dispatch(InfoType::METRIC, "int_key", 42); + dispatcher.dispatch(InfoType::METRIC, "double_key", 3.14159); + dispatcher.dispatch(InfoType::METRIC, "bool_key", true); + dispatcher.end_record(InfoType::METRIC); -// Test structured writer with vector values -TEST_F(DispatcherTest, StructuredVectorValues) { - dispatcher.begin_record(InfoType::METRIC); - - std::vector doubles = {1.1, 2.2, 3.3}; - dispatcher.dispatch(InfoType::METRIC, "doubles", doubles); - - std::vector strings = {"one", "two", "three"}; - dispatcher.dispatch(InfoType::METRIC, "strings", strings); - - dispatcher.end_record(InfoType::METRIC); - - std::string output = ss_metric.str(); - EXPECT_NE(output.find("doubles"), std::string::npos); - EXPECT_NE(output.find("1.1"), std::string::npos); - EXPECT_NE(output.find("strings"), std::string::npos); - EXPECT_NE(output.find("one"), std::string::npos); -} + std::string output = ss_metric.str(); + EXPECT_NE(output.find("string_key"), std::string::npos); + EXPECT_NE(output.find("string_value"), std::string::npos); + EXPECT_NE(output.find("int_key"), std::string::npos); + EXPECT_NE(output.find("42"), std::string::npos); + EXPECT_NE(output.find("double_key"), std::string::npos); + EXPECT_NE(output.find("3.14159"), std::string::npos); + EXPECT_NE(output.find("bool_key"), std::string::npos); + EXPECT_NE(output.find("true"), std::string::npos); + } -// Test structured writer with Eigen values -TEST_F(DispatcherTest, StructuredEigenValues) { - dispatcher.begin_record(InfoType::METRIC); - - Eigen::MatrixXd matrix(2, 2); - matrix << 1.0, 2.0, 3.0, 4.0; - dispatcher.dispatch(InfoType::METRIC, "matrix", matrix); - - Eigen::VectorXd vector(3); - vector << 5.0, 6.0, 7.0; - dispatcher.dispatch(InfoType::METRIC, "vector", vector); - - dispatcher.end_record(InfoType::METRIC); - - std::string output = ss_metric.str(); - EXPECT_NE(output.find("matrix"), std::string::npos); - EXPECT_NE(output.find("1"), std::string::npos); - EXPECT_NE(output.find("4"), std::string::npos); - EXPECT_NE(output.find("vector"), std::string::npos); - EXPECT_NE(output.find("5"), std::string::npos); - EXPECT_NE(output.find("7"), std::string::npos); -} + // Test structured writer with vector values + TEST_F(DispatcherTest, StructuredVectorValues) { + dispatcher.begin_record(InfoType::METRIC); -// Test unregistered channel -TEST_F(DispatcherTest, UnregisteredChannel) { - // Dispatch to unregistered channel should silently do nothing - dispatcher.dispatch(InfoType::ALGORITHM_STATE, std::string("Message")); - dispatcher.dispatch(InfoType::ALGORITHM_STATE, std::vector{1.0, 2.0}); - dispatcher.begin_record(InfoType::ALGORITHM_STATE); - dispatcher.dispatch(InfoType::ALGORITHM_STATE, "key", "value"); - dispatcher.end_record(InfoType::ALGORITHM_STATE); - - // No exceptions should be thrown -} + std::vector doubles = {1.1, 2.2, 3.3}; + dispatcher.dispatch(InfoType::METRIC, "doubles", doubles); -// Test named record -TEST_F(DispatcherTest, NamedRecord) { - dispatcher.begin_record(InfoType::METRIC, "record_name"); - dispatcher.dispatch(InfoType::METRIC, "key", "value"); - dispatcher.end_record(InfoType::METRIC); - - std::string output = ss_metric.str(); - EXPECT_NE(output.find("record_name"), std::string::npos); - EXPECT_NE(output.find("key"), std::string::npos); - EXPECT_NE(output.find("value"), std::string::npos); -} + std::vector strings = {"one", "two", "three"}; + dispatcher.dispatch(InfoType::METRIC, "strings", strings); -// Test that begin_record and end_record on a plain writer channel are silently ignored -TEST_F(DispatcherTest, RecordOperationsOnPlainWriter) { - dispatcher.begin_record(InfoType::CONFIG); - dispatcher.end_record(InfoType::CONFIG); - - // Should not generate any output - EXPECT_EQ(ss_config.str(), ""); -} + dispatcher.end_record(InfoType::METRIC); -// Test complex sampler metric output pattern -TEST_F(DispatcherTest, ComplexSamplerMetricPattern) { - // This test simulates a more complex real-world usage pattern - - // Start a record for a sampling iteration - dispatcher.begin_record(InfoType::METRIC); - - // Add various diagnostic info - dispatcher.dispatch(InfoType::METRIC, "iter", 10); - dispatcher.dispatch(InfoType::METRIC, "lp", -105.2); - dispatcher.dispatch(InfoType::METRIC, "accept_stat", 0.8); - - // Add a nested object for adaptation - dispatcher.begin_record(InfoType::METRIC, "adaptation"); - dispatcher.dispatch(InfoType::METRIC, "step_size", 0.85); - - // Add an inverse metric matrix - Eigen::MatrixXd inv_metric(2, 2); - inv_metric << 1.2, 0.1, 0.1, 0.9; - dispatcher.dispatch(InfoType::METRIC, "inv_metric", inv_metric); - - // End adaptation object - dispatcher.end_record(InfoType::METRIC); - - // End the main record - dispatcher.end_record(InfoType::METRIC); - - // Verify key entries exist in the output - std::string output = ss_metric.str(); - EXPECT_NE(output.find("iter"), std::string::npos); - EXPECT_NE(output.find("10"), std::string::npos); - EXPECT_NE(output.find("lp"), std::string::npos); - EXPECT_NE(output.find("-105.2"), std::string::npos); - EXPECT_NE(output.find("adaptation"), std::string::npos); - EXPECT_NE(output.find("step_size"), std::string::npos); - EXPECT_NE(output.find("inv_metric"), std::string::npos); -} + std::string output = ss_metric.str(); + EXPECT_NE(output.find("doubles"), std::string::npos); + EXPECT_NE(output.find("1.1"), std::string::npos); + EXPECT_NE(output.find("strings"), std::string::npos); + EXPECT_NE(output.find("one"), std::string::npos); + } + + // Test structured writer with Eigen values + TEST_F(DispatcherTest, StructuredEigenValues) { + dispatcher.begin_record(InfoType::METRIC); + + Eigen::MatrixXd matrix(2, 2); + matrix << 1.0, 2.0, 3.0, 4.0; + dispatcher.dispatch(InfoType::METRIC, "matrix", matrix); + + Eigen::VectorXd vector(3); + vector << 5.0, 6.0, 7.0; + dispatcher.dispatch(InfoType::METRIC, "vector", vector); + + dispatcher.end_record(InfoType::METRIC); + + std::string output = ss_metric.str(); + EXPECT_NE(output.find("matrix"), std::string::npos); + EXPECT_NE(output.find("1"), std::string::npos); + EXPECT_NE(output.find("4"), std::string::npos); + EXPECT_NE(output.find("vector"), std::string::npos); + EXPECT_NE(output.find("5"), std::string::npos); + EXPECT_NE(output.find("7"), std::string::npos); + } + + // Test unregistered channel + TEST_F(DispatcherTest, UnregisteredChannel) { + // Dispatch to unregistered channel should silently do nothing + dispatcher.dispatch(InfoType::ALGORITHM_STATE, std::string("Message")); + dispatcher.dispatch(InfoType::ALGORITHM_STATE, + std::vector{1.0, 2.0}); + dispatcher.begin_record(InfoType::ALGORITHM_STATE); + dispatcher.dispatch(InfoType::ALGORITHM_STATE, "key", "value"); + dispatcher.end_record(InfoType::ALGORITHM_STATE); + + // No exceptions should be thrown + } + + // Test named record + TEST_F(DispatcherTest, NamedRecord) { + dispatcher.begin_record(InfoType::METRIC, "record_name"); + dispatcher.dispatch(InfoType::METRIC, "key", "value"); + dispatcher.end_record(InfoType::METRIC); + + std::string output = ss_metric.str(); + EXPECT_NE(output.find("record_name"), std::string::npos); + EXPECT_NE(output.find("key"), std::string::npos); + EXPECT_NE(output.find("value"), std::string::npos); + } + + // Test that begin_record and end_record on a plain writer channel are + // silently ignored + TEST_F(DispatcherTest, RecordOperationsOnPlainWriter) { + dispatcher.begin_record(InfoType::CONFIG); + dispatcher.end_record(InfoType::CONFIG); + + // Should not generate any output + EXPECT_EQ(ss_config.str(), ""); + } + + // Test complex sampler metric output pattern + TEST_F(DispatcherTest, ComplexSamplerMetricPattern) { + // This test simulates a more complex real-world usage pattern + + // Start a record for a sampling iteration + dispatcher.begin_record(InfoType::METRIC); + + // Add various diagnostic info + dispatcher.dispatch(InfoType::METRIC, "iter", 10); + dispatcher.dispatch(InfoType::METRIC, "lp", -105.2); + dispatcher.dispatch(InfoType::METRIC, "accept_stat", 0.8); + + // Add a nested object for adaptation + dispatcher.begin_record(InfoType::METRIC, "adaptation"); + dispatcher.dispatch(InfoType::METRIC, "step_size", 0.85); + + // Add an inverse metric matrix + Eigen::MatrixXd inv_metric(2, 2); + inv_metric << 1.2, 0.1, 0.1, 0.9; + dispatcher.dispatch(InfoType::METRIC, "inv_metric", inv_metric); + + // End adaptation object + dispatcher.end_record(InfoType::METRIC); + + // End the main record + dispatcher.end_record(InfoType::METRIC); + + // Verify key entries exist in the output + std::string output = ss_metric.str(); + EXPECT_NE(output.find("iter"), std::string::npos); + EXPECT_NE(output.find("10"), std::string::npos); + EXPECT_NE(output.find("lp"), std::string::npos); + EXPECT_NE(output.find("-105.2"), std::string::npos); + EXPECT_NE(output.find("adaptation"), std::string::npos); + EXPECT_NE(output.find("step_size"), std::string::npos); + EXPECT_NE(output.find("inv_metric"), std::string::npos); + }