-
-
Notifications
You must be signed in to change notification settings - Fork 375
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
31dbf30
commit 37cbdfb
Showing
2 changed files
with
454 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,207 @@ | ||
#ifndef STAN_CALLBACKS_DISPATCHER_HPP | ||
#define STAN_CALLBACKS_DISPATCHER_HPP | ||
|
||
#include <stan/callbacks/writer.hpp> | ||
#include <stan/callbacks/structured_writer.hpp> | ||
#include <memory> | ||
#include <unordered_map> | ||
#include <string> | ||
#include <vector> | ||
#include <stdexcept> | ||
#include <utility> | ||
#include <type_traits> | ||
#include <variant> | ||
|
||
namespace stan { | ||
namespace callbacks { | ||
|
||
/** | ||
* The <code>dispatcher</code> class manages a set of callbacks | ||
* for all outputs of one run of a Stan service. | ||
* Calls to the dispatcher's `dispatch` method are forwarded to | ||
* the callback registered on the channel. | ||
*/ | ||
|
||
|
||
/** | ||
* Enum `info_type` holds output type labels which are used by | ||
* the dispatcher class to map outputs to output channels. | ||
*/ | ||
enum class info_type { | ||
CONFIG, // series of string messages | ||
SAMPLE, // draw from posterior | ||
SAMPLE_RAW, // draw from posterior | ||
METRIC, // struct with kv pairs 'metric_type', 'stepsize', 'inv_metric' | ||
ALGORITHM_STATE, // sampler state for returned draw | ||
}; | ||
|
||
/** | ||
* Efficient enum type lookups | ||
*/ | ||
struct info_type_hash { | ||
std::size_t operator()(const info_type& type) const { | ||
return std::hash<int>()(static_cast<int>(type)); | ||
} | ||
}; | ||
|
||
|
||
/** | ||
* Base type for all callbacks, needed for type erasure. | ||
*/ | ||
class channel { | ||
public: | ||
virtual ~channel() = default; | ||
}; | ||
|
||
|
||
/** | ||
* A `writer_channel` holds a reference to a stan::callbacks::writer object | ||
* and forwards information to the writer's operator (). | ||
*/ | ||
class writer_channel : public channel { | ||
private: | ||
stan::callbacks::writer* writer_; | ||
|
||
public: | ||
explicit writer_channel(stan::callbacks::writer* w) : writer_(w) { | ||
if (!w) { | ||
throw std::runtime_error("config error, null writer"); | ||
} | ||
} | ||
|
||
// Handle all types that writer supports via operator() | ||
void dispatch() { (*writer_)(); } | ||
void dispatch(const std::string& value) { (*writer_)(value); } | ||
void dispatch(const std::vector<double>& value) { (*writer_)(value); } | ||
void dispatch(const std::vector<std::string>& value) { (*writer_)(value); } | ||
|
||
// Handle any Eigen Matrix type | ||
template <int R, int C> | ||
void dispatch(const Eigen::Matrix<double, R, C>& value) { | ||
(*writer_)(value); | ||
} | ||
|
||
// No key-value support for plain writers | ||
template <typename T> | ||
void dispatch(const std::string&, const T&) {} | ||
}; | ||
|
||
/** | ||
* A `structured writer_channel` holds a reference to a | ||
* stan::callbacks::structured_writer object and forwards | ||
* information to the appropriate method. | ||
*/ | ||
class structured_writer_channel : public channel { | ||
private: | ||
stan::callbacks::structured_writer* writer_; | ||
|
||
public: | ||
explicit structured_writer_channel(stan::callbacks::structured_writer* sw) | ||
: writer_(sw) { | ||
if (!sw) | ||
throw std::runtime_error("config error, null writer"); | ||
} | ||
// Forward all key-value calls directly to the writer | ||
void dispatch(const std::string& key) { writer_->write(key); } | ||
// Perfect forwarding for any key-value pair | ||
template <typename T> | ||
void dispatch(const std::string& key, T&& value) { | ||
writer_->write(key, std::forward<T>(value)); | ||
} | ||
void begin_record() { writer_->begin_record(); } | ||
void begin_record(const std::string& key) { writer_->begin_record(key); } | ||
void end_record() { writer_->end_record(); } | ||
}; | ||
|
||
/** | ||
* The `dispatcher` class provides methods to register and find output channels | ||
* and overloads method `dispatch` which forwards outputs to callbacks. | ||
*/ | ||
class dispatcher { | ||
private: | ||
/* Lookup registered channels for info_type. | ||
* Returns nullptr if no channel found. | ||
*/ | ||
template <typename channel_type> | ||
channel_type* find_channel(info_type type) { | ||
auto it = channels_.find(type); | ||
if (it == channels_.end()) | ||
return nullptr; | ||
return dynamic_cast<channel_type*>(it->second.get()); | ||
} | ||
|
||
std::unordered_map<info_type, std::unique_ptr<channel>, info_type_hash> | ||
channels_; | ||
|
||
public: | ||
dispatcher() = default; | ||
~dispatcher() = default; | ||
|
||
/* Add channel to map. | ||
* Assumes a 1:1 mapping between info type and callback. | ||
*/ | ||
void register_channel(info_type type, std::unique_ptr<channel> channel) { | ||
channels_[type] = std::move(channel); | ||
} | ||
|
||
// no-arg call to writer operator () | ||
void dispatch(info_type type) { | ||
if (auto* wc = find_channel<writer_channel>(type)) | ||
wc->dispatch(); | ||
} | ||
|
||
// Value is std::vector<double>, or std::vector<std::string> | ||
// clang-format off | ||
template <typename T, | ||
typename = std::enable_if_t< | ||
std::is_same_v<std::decay_t<T>, std::vector<double>> | ||
|| std::is_same_v<std::decay_t<T>, std::vector<std::string>>>> // NOLINT | ||
// clang-format on | ||
void dispatch(info_type type, T&& value) { | ||
if (auto* wc = find_channel<writer_channel>(type)) | ||
wc->dispatch(std::forward<T>(value)); | ||
} | ||
|
||
// Value is Eigen vector or matrix | ||
template <int R, int C> | ||
void dispatch(info_type type, const Eigen::Matrix<double, R, C>& value) { | ||
if (auto* wc = find_channel<writer_channel>(type)) | ||
wc->dispatch(value); | ||
} | ||
|
||
// Value is std::string | ||
void dispatch(info_type type, const std::string& value) { | ||
if (auto* wc = find_channel<writer_channel>(type)) | ||
wc->dispatch(value); | ||
else if (auto* sw = find_channel<structured_writer_channel>(type)) | ||
sw->dispatch(value); // (sic: actually the key part of k-v pair) | ||
} | ||
|
||
// Key-value pairs (forward to structured writers) | ||
template <typename T> | ||
void dispatch(info_type type, const std::string& key, T&& value) { | ||
if (auto* sw = find_channel<structured_writer_channel>(type)) | ||
sw->dispatch(key, std::forward<T>(value)); | ||
} | ||
|
||
// Record operations | ||
void begin_record(info_type type) { | ||
if (auto* sw = find_channel<structured_writer_channel>(type)) | ||
sw->begin_record(); | ||
} | ||
|
||
void begin_record(info_type type, const std::string& key) { | ||
if (auto* sw = find_channel<structured_writer_channel>(type)) | ||
sw->begin_record(key); | ||
} | ||
|
||
void end_record(info_type type) { | ||
if (auto* sw = find_channel<structured_writer_channel>(type)) | ||
sw->end_record(); | ||
} | ||
}; | ||
|
||
} // namespace callbacks | ||
} // namespace stan | ||
|
||
#endif // STAN_CALLBACKS_DISPATCHER_HPP |
Oops, something went wrong.