diff --git a/include/libnuraft/asio_service_options.hxx b/include/libnuraft/asio_service_options.hxx index 7598dd8a..c719ec6a 100644 --- a/include/libnuraft/asio_service_options.hxx +++ b/include/libnuraft/asio_service_options.hxx @@ -19,6 +19,7 @@ limitations under the License. #include #include +#include namespace nuraft { @@ -59,6 +60,12 @@ struct asio_service_meta_cb_params { uint64_t commit_idx_; }; +/** + * Response callback function for customer resolvers. + */ +using asio_service_custom_resolver_response = + std::function< void(const std::string&, const std::string&, std::error_code) >; + /** * Options used for initialization of Asio service. */ @@ -76,56 +83,102 @@ struct asio_service_options { , read_resp_meta_(nullptr) , invoke_resp_cb_on_empty_meta_(true) , verify_sn_(nullptr) + , custom_resolver_(nullptr) {} - // Number of ASIO worker threads. - // If zero, it will be automatically set to number of cores. + /** + * Number of ASIO worker threads. + * If zero, it will be automatically set to number of cores. + */ size_t thread_pool_size_; - // Lifecycle callback function on worker thread start. + /** + * Lifecycle callback function on worker thread start. + */ std::function< void(uint32_t) > worker_start_; - // Lifecycle callback function on worker thread stop. + /** + * Lifecycle callback function on worker thread stop. + */ std::function< void(uint32_t) > worker_stop_; - // If `true`, enable SSL/TLS secure connection. + /** + * If `true`, enable SSL/TLS secure connection. + */ bool enable_ssl_; - // If `true`, skip certificate verification. + /** + * If `true`, skip certificate verification. + */ bool skip_verification_; - // Path to certification & key files. + /** + * Path to server certificate file. + */ std::string server_cert_file_; + + /** + * Path to server key file. + */ std::string server_key_file_; + + /** + * Path to root certificate file. + */ std::string root_cert_file_; - // Callback function for writing Raft RPC request metadata. + /** + * Callback function for writing Raft RPC request metadata. + */ std::function< std::string(const asio_service_meta_cb_params&) > write_req_meta_; - // Callback function for reading and verifying Raft RPC request metadata. - // If it returns false, the request will be discarded. + /** + * Callback function for reading and verifying Raft RPC request metadata. + * If it returns `false`, the request will be discarded. + */ std::function< bool( const asio_service_meta_cb_params&, const std::string& ) > read_req_meta_; - // If `true`, it will invoke `read_req_meta_` even though - // the received meta is empty. + /** + * If `true`, it will invoke `read_req_meta_` even though + * the received meta is empty. + */ bool invoke_req_cb_on_empty_meta_; - // Callback function for writing Raft RPC response metadata. + /** + * Callback function for writing Raft RPC response metadata. + */ std::function< std::string(const asio_service_meta_cb_params&) > write_resp_meta_; - // Callback function for reading and verifying Raft RPC response metadata. - // If it returns false, the response will be ignored. + /** + * Callback function for reading and verifying Raft RPC response metadata. + * If it returns false, the response will be ignored. + */ std::function< bool( const asio_service_meta_cb_params&, const std::string& ) > read_resp_meta_; - // If `true`, it will invoke `read_resp_meta_` even though - // the received meta is empty. + /** + * If `true`, it will invoke `read_resp_meta_` even though + * the received meta is empty. + */ bool invoke_resp_cb_on_empty_meta_; - // Callback function for verifying certificate subject name. - // If not given, subject name will not be verified. + /** + * Callback function for verifying certificate subject name. + * If not given, subject name will not be verified. + */ std::function< bool(const std::string&) > verify_sn_; + + /** + * Custom IP address resolver. If given, it will be invoked + * before the connection is established. + * + * If you want to selectively bypass some hosts, just pass the given + * host and port to the response function as they are. + */ + std::function< void( const std::string&, + const std::string&, + asio_service_custom_resolver_response ) > custom_resolver_; }; } diff --git a/src/asio_service.cxx b/src/asio_service.cxx index d16645ab..b8b8b403 100644 --- a/src/asio_service.cxx +++ b/src/asio_service.cxx @@ -1017,39 +1017,37 @@ class asio_rpc_client break; } - asio::ip::tcp::resolver::query q - ( host_, port_, asio::ip::tcp::resolver::query::all_matching ); - - resolver_.async_resolve - ( q, - [self, this, req, when_done, send_timeout_ms] - ( std::error_code err, - asio::ip::tcp::resolver::iterator itor ) -> void - { - if (!err) { - asio::async_connect - ( socket(), - itor, - std::bind( &asio_rpc_client::connected, - self, - req, - when_done, - send_timeout_ms, - std::placeholders::_1, - std::placeholders::_2 ) ); - } else { - ptr rsp; - ptr except - ( cs_new - ( lstrfmt("failed to resolve host %s " - "due to error %d, %s") - .fmt( host_.c_str(), - err.value(), - err.message().c_str() ), - req ) ); - when_done(rsp, except); - } - } ); + if (impl_->get_options().custom_resolver_) { + impl_->get_options().custom_resolver_( + host_, + port_, + [this, self, req, when_done, send_timeout_ms] + ( const std::string& resolved_host, + const std::string& resolved_port, + std::error_code err ) { + if (!err) { + p_in( "custom resolver: %s:%s to %s:%s", + host_.c_str(), port_.c_str(), + resolved_host.c_str(), resolved_port.c_str() ); + execute_resolver(self, req, resolved_host, resolved_port, + when_done, send_timeout_ms); + } else { + ptr rsp; + ptr except + ( cs_new + ( lstrfmt("failed to resolve host %s by given " + "custom resolver " + "due to error %d, %s") + .fmt( host_.c_str(), + err.value(), + err.message().c_str() ), + req ) ); + when_done(rsp, except); + } + } ); + } else { + execute_resolver(self, req, host_, port_, when_done, send_timeout_ms); + } return; } @@ -1174,6 +1172,47 @@ class asio_rpc_client std::placeholders::_2 ) ); } private: + void execute_resolver(ptr self, + ptr req, + const std::string& host, + const std::string& port, + rpc_handler when_done, + uint64_t send_timeout_ms) { + asio::ip::tcp::resolver::query q + ( host, port, asio::ip::tcp::resolver::query::all_matching ); + + resolver_.async_resolve + ( q, + [self, this, req, when_done, host, port, send_timeout_ms] + ( std::error_code err, + asio::ip::tcp::resolver::iterator itor ) -> void + { + if (!err) { + asio::async_connect + ( socket(), + itor, + std::bind( &asio_rpc_client::connected, + self, + req, + when_done, + send_timeout_ms, + std::placeholders::_1, + std::placeholders::_2 ) ); + } else { + ptr rsp; + ptr except + ( cs_new + ( lstrfmt("failed to resolve host %s " + "due to error %d, %s") + .fmt( host.c_str(), + err.value(), + err.message().c_str() ), + req ) ); + when_done(rsp, except); + } + } ); + } + void set_busy_flag(bool to) { if (to == true) { bool exp = false; diff --git a/tests/unit/asio_service_test.cxx b/tests/unit/asio_service_test.cxx index d0eb6b70..48d7a9cb 100644 --- a/tests/unit/asio_service_test.cxx +++ b/tests/unit/asio_service_test.cxx @@ -2463,6 +2463,86 @@ int parallel_log_append_test() { return 0; } +int custom_resolver_test() { + reset_log_files(); + + std::string s1_addr = "S1:1234"; + std::string s2_addr = "S2:1234"; + std::string s3_addr = "S3:1234"; + + RaftAsioPkg s1(1, s1_addr); + RaftAsioPkg s2(2, s2_addr); + RaftAsioPkg s3(3, s3_addr); + std::vector pkgs = {&s1, &s2, &s3}; + + // Enable custom resolver. + s1.useCustomResolver = s2.useCustomResolver = s3.useCustomResolver = true; + + _msg("launching asio-raft servers\n"); + CHK_Z( launch_servers(pkgs, false) ); + + _msg("organizing raft group\n"); + CHK_Z( make_group(pkgs) ); + + // Set async mode. + for (auto& entry: pkgs) { + RaftAsioPkg* pp = entry; + raft_params param = pp->raftServer->get_current_params(); + param.return_method_ = raft_params::async_handler; + param.parallel_log_appending_ = true; + pp->raftServer->update_params(param); + } + + // Append messages asynchronously. + const size_t NUM = 10; + std::list< ptr< cmd_result< ptr > > > handlers; + std::list idx_list; + std::mutex idx_list_lock; + auto do_async_append = [&]() { + handlers.clear(); + idx_list.clear(); + for (size_t ii=0; ii msg = buffer::alloc(test_msg.size() + 1); + msg->put(test_msg); + ptr< cmd_result< ptr > > ret = + s1.raftServer->append_entries( {msg} ); + + cmd_result< ptr >::handler_type my_handler = + std::bind( async_handler, + &idx_list, + &idx_list_lock, + std::placeholders::_1, + std::placeholders::_2 ); + ret->when_ready( my_handler ); + + handlers.push_back(ret); + } + }; + do_async_append(); + + TestSuite::sleep_sec(1, "wait for replication"); + + // Now all async handlers should have result. + { + std::lock_guard l(idx_list_lock); + CHK_EQ(NUM, idx_list.size()); + } + + // State machine should be identical. + CHK_OK( s2.getTestSm()->isSame( *s1.getTestSm() ) ); + CHK_OK( s3.getTestSm()->isSame( *s1.getTestSm() ) ); + + s1.raftServer->shutdown(); + s2.raftServer->shutdown(); + s3.raftServer->shutdown(); + TestSuite::sleep_sec(1, "shutting down"); + + SimpleLogger::shutdown(); + return 0; +} + + } // namespace asio_service_test; using namespace asio_service_test; @@ -2565,6 +2645,9 @@ int main(int argc, char** argv) { ts.doTest( "parallel log append test", parallel_log_append_test ); + ts.doTest( "custom resolver test", + custom_resolver_test ); + #ifdef ENABLE_RAFT_STATS _msg("raft stats: ENABLED\n"); #else diff --git a/tests/unit/raft_package_asio.hxx b/tests/unit/raft_package_asio.hxx index 6f196009..61bbe445 100644 --- a/tests/unit/raft_package_asio.hxx +++ b/tests/unit/raft_package_asio.hxx @@ -15,6 +15,7 @@ See the License for the specific language governing permissions and limitations under the License. **************************************************************************/ +#include "asio_service_options.hxx" #include "raft_functional_common.hxx" #include "internal_timer.hxx" @@ -51,6 +52,7 @@ public: , readReqMeta(nullptr) , writeReqMeta(nullptr) , alwaysInvokeCb(true) + , useCustomResolver(false) , myLogWrapper(nullptr) , myLog(nullptr) {} @@ -100,6 +102,21 @@ public: asio_opt.server_key_file_ = "./key.pem"; } + if (useCustomResolver) { + asio_opt.custom_resolver_ = + []( const std::string& host, + const std::string& port, + asio_service_custom_resolver_response when_done ) { + if (host.substr(0, 2) == "S1") { + when_done("127.0.0.1", "20010", std::error_code()); + } else if (host.substr(0, 2) == "S2") { + when_done("127.0.0.1", "20020", std::error_code()); + } else { + when_done("127.0.0.1", "20030", std::error_code()); + } + }; + } + if (readReqMeta) asio_opt.read_req_meta_ = readReqMeta; if (writeReqMeta) asio_opt.write_req_meta_ = writeReqMeta; @@ -233,6 +250,8 @@ public: bool alwaysInvokeCb; + bool useCustomResolver; + ptr myLogWrapper; ptr myLog; };