Skip to content

Commit

Permalink
refactor: move timeout logic to httpclient
Browse files Browse the repository at this point in the history
  • Loading branch information
jr0me committed Dec 2, 2024
1 parent 2336714 commit 7ab3e96
Show file tree
Hide file tree
Showing 8 changed files with 101 additions and 23 deletions.
12 changes: 12 additions & 0 deletions src/agent/communicator/include/communicator.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,15 @@ namespace communicator
LogWarn("batch_interval must be greater than zero. Using default value.");
m_batchInterval = config::agent::DEFAULT_BATCH_INTERVAL;
}

m_batchSize = getConfigValue.template operator()<int>("agent", "batch_size")
.value_or(config::agent::DEFAULT_BATCH_SIZE);

if (m_batchSize < 1)
{
LogWarn("batch_size must be greater than zero. Using default value.");
m_batchSize = config::agent::DEFAULT_BATCH_SIZE;
}
}

/// @brief Waits for the authentication token to expire and authenticates again
Expand Down Expand Up @@ -129,6 +138,9 @@ namespace communicator
/// @brief Time between batch requests
std::time_t m_batchInterval = config::agent::DEFAULT_BATCH_INTERVAL;

/// @brief Maximum number of messages to batch
int m_batchSize = config::agent::DEFAULT_BATCH_SIZE;

/// @brief The server URL
std::string m_serverUrl;

Expand Down
2 changes: 2 additions & 0 deletions src/agent/communicator/include/http_client.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ namespace http_client
/// @param onUnauthorized Callback for unauthorized access
/// @param connectionRetry Time in milliseconds to wait before retrying the connection
/// @param batchingInterval Time to wait between requests
/// @param m_batchSize The maximum number of messages to batch
/// @param onSuccess Callback for successful request completion
/// @param loopRequestCondition Condition to continue looping requests
/// @return Awaitable task for the HTTP request
Expand All @@ -50,6 +51,7 @@ namespace http_client
std::function<void()> onUnauthorized,
std::time_t connectionRetry,
std::time_t batchingInterval,
int m_batchSize,
std::function<void(const int, const std::string&)> onSuccess = {},
std::function<bool()> loopRequestCondition = {}) override;

Expand Down
2 changes: 2 additions & 0 deletions src/agent/communicator/include/ihttp_client.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ namespace http_client
/// @param onUnauthorized Action to take on unauthorized access
/// @param connectionRetry Time to wait before retrying the connection
/// @param batchingInterval Time to wait between requests
/// @param batchSize The maximum number of messages to batch
/// @param onSuccess Action to take on successful request
/// @param loopRequestCondition Condition to continue the request loop
/// @return Awaitable task for the HTTP request
Expand All @@ -48,6 +49,7 @@ namespace http_client
std::function<void()> onUnauthorized,
std::time_t connectionRetry,
std::time_t batchingInterval,
int batchSize,
std::function<void(const int, const std::string&)> onSuccess = {},
std::function<bool()> loopRequestCondition = {}) = 0;

Expand Down
13 changes: 11 additions & 2 deletions src/agent/communicator/src/communicator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,15 @@ namespace communicator

const auto reqParams = http_client::HttpRequestParams(
boost::beast::http::verb::get, m_serverUrl, "/api/v1/commands", m_getHeaderInfo ? m_getHeaderInfo() : "");
co_await m_httpClient->Co_PerformHttpRequest(
m_token, reqParams, {}, onAuthenticationFailed, m_retryInterval, m_batchInterval, onSuccess, loopCondition);
co_await m_httpClient->Co_PerformHttpRequest(m_token,
reqParams,
{},
onAuthenticationFailed,
m_retryInterval,
m_batchInterval,
m_batchSize,
onSuccess,
loopCondition);
}

boost::asio::awaitable<void> Communicator::WaitForTokenExpirationAndAuthenticate()
Expand Down Expand Up @@ -149,6 +156,7 @@ namespace communicator
onAuthenticationFailed,
m_retryInterval,
m_batchInterval,
m_batchSize,
onSuccess,
loopCondition);
}
Expand Down Expand Up @@ -177,6 +185,7 @@ namespace communicator
onAuthenticationFailed,
m_retryInterval,
m_batchInterval,
m_batchSize,
onSuccess,
loopCondition);
}
Expand Down
45 changes: 32 additions & 13 deletions src/agent/communicator/src/http_client.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -95,15 +95,16 @@ namespace http_client
return req;
}

boost::asio::awaitable<void>
HttpClient::Co_PerformHttpRequest(std::shared_ptr<std::string> token,
HttpRequestParams reqParams,
std::function<boost::asio::awaitable<std::string>()> messageGetter,
std::function<void()> onUnauthorized,
std::time_t connectionRetry,
std::time_t batchInterval,
std::function<void(const std::string&)> onSuccess,
std::function<bool()> loopRequestCondition)
boost::asio::awaitable<void> HttpClient::Co_PerformHttpRequest(
std::shared_ptr<std::string> token,
HttpRequestParams reqParams,
std::function<boost::asio::awaitable<std::tuple<int, std::string>>()> messageGetter,
std::function<void()> onUnauthorized,
std::time_t connectionRetry,
std::time_t batchInterval,
int batchSize,
std::function<void(const int, const std::string&)> onSuccess,
std::function<bool()> loopRequestCondition)
{
using namespace std::chrono_literals;

Expand Down Expand Up @@ -155,10 +156,28 @@ namespace http_client

if (messageGetter != nullptr)
{
const auto messages = co_await messageGetter();
messagesCount = std::get<0>(messages);
LogTrace("Messages count: {}", messagesCount);
reqParams.Body = std::get<1>(messages);
boost::asio::steady_timer refreshTimer(co_await boost::asio::this_coro::executor);
boost::asio::steady_timer batchTimeoutTimer(co_await boost::asio::this_coro::executor);
batchTimeoutTimer.expires_after(std::chrono::milliseconds(batchInterval));

while (loopRequestCondition != nullptr && loopRequestCondition())
{
LogError("Loop request condition is true.");
// print batch size
LogError("Batch size: {}", batchSize);
const auto messages = co_await messageGetter();
messagesCount = std::get<0>(messages);

if (messagesCount >= batchSize || batchTimeoutTimer.expiry() <= std::chrono::steady_clock::now())
{
LogTrace("Messages count: {}", messagesCount);
reqParams.Body = std::get<1>(messages);
break;
}

refreshTimer.expires_after(std::chrono::milliseconds(100));
co_await refreshTimer.async_wait(boost::asio::use_awaitable);
}
}
else
{
Expand Down
9 changes: 6 additions & 3 deletions src/agent/communicator/tests/communicator_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,14 +61,15 @@ TEST(CommunicatorTest, StatefulMessageProcessingTask_Success)
EXPECT_EQ(message, "message-content");
};

EXPECT_CALL(*mockHttpClient, Co_PerformHttpRequest(_, _, _, _, _, _, _, _))
EXPECT_CALL(*mockHttpClient, Co_PerformHttpRequest(_, _, _, _, _, _, _, _, _))
.WillOnce(Invoke(
[](std::shared_ptr<std::string>,
http_client::HttpRequestParams,
std::function<boost::asio::awaitable<std::tuple<int, std::string>>()> pGetMessages,
std::function<void()>,
[[maybe_unused]] std::time_t connectionRetry,
[[maybe_unused]] std::time_t batchingInterval,
[[maybe_unused]] int batchSize,
std::function<void(const int, const std::string&)> pOnSuccess,
[[maybe_unused]] std::function<bool()> loopRequestCondition) -> boost::asio::awaitable<void>
{
Expand Down Expand Up @@ -109,14 +110,15 @@ TEST(CommunicatorTest, WaitForTokenExpirationAndAuthenticate_FailedAuthenticatio
}));

// A following call to Co_PerformHttpRequest should not have a token
EXPECT_CALL(*mockHttpClientPtr, Co_PerformHttpRequest(_, _, _, _, _, _, _, _))
EXPECT_CALL(*mockHttpClientPtr, Co_PerformHttpRequest(_, _, _, _, _, _, _, _, _))
.WillOnce(Invoke(
[](std::shared_ptr<std::string> token,
http_client::HttpRequestParams,
[[maybe_unused]] std::function<boost::asio::awaitable<std::tuple<int, std::string>>()> getMessages,
[[maybe_unused]] std::function<void()> onUnauthorized,
[[maybe_unused]] std::time_t connectionRetry,
[[maybe_unused]] std::time_t batchingInterval,
[[maybe_unused]] int batchSize,
[[maybe_unused]] std::function<void(const int, const std::string&)> onSuccess,
[[maybe_unused]] std::function<bool()> loopCondition) -> boost::asio::awaitable<void>
{
Expand Down Expand Up @@ -165,7 +167,7 @@ TEST(CommunicatorTest, StatelessMessageProcessingTask_CallsWithValidToken)
}));

std::string capturedToken;
EXPECT_CALL(*mockHttpClientPtr, Co_PerformHttpRequest(_, _, _, _, _, _, _, _))
EXPECT_CALL(*mockHttpClientPtr, Co_PerformHttpRequest(_, _, _, _, _, _, _, _, _))
.WillOnce(Invoke(
[&capturedToken](
std::shared_ptr<std::string> token,
Expand All @@ -174,6 +176,7 @@ TEST(CommunicatorTest, StatelessMessageProcessingTask_CallsWithValidToken)
[[maybe_unused]] std::function<void()> onUnauthorized,
[[maybe_unused]] std::time_t connectionRetry,
[[maybe_unused]] std::time_t batchingInterval,
[[maybe_unused]] int batchSize,
[[maybe_unused]] std::function<void(const int, const std::string&)> onSuccess,
[[maybe_unused]] std::function<bool()> loopCondition) -> boost::asio::awaitable<void>
{
Expand Down
38 changes: 34 additions & 4 deletions src/agent/communicator/tests/http_client_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include <functional>
#include <memory>
#include <string>
#include <utility>

// NOLINTBEGIN(cppcoreguidelines-avoid-capturing-lambda-coroutines,cppcoreguidelines-avoid-reference-coroutine-parameters)

Expand Down Expand Up @@ -233,6 +234,12 @@ TEST_F(HttpClientTest, Co_PerformHttpRequest_Success)
unauthorizedCalled = true;
};

auto loopCondition = true;
std::function<bool()> loopRequestCondition = [&loopCondition]()
{
return std::exchange(loopCondition, false);
};

const auto reqParams =
http_client::HttpRequestParams(boost::beast::http::verb::get, "https://localhost:8080", "/", "Wazuh 5.0.0");

Expand All @@ -242,8 +249,9 @@ TEST_F(HttpClientTest, Co_PerformHttpRequest_Success)
onUnauthorized,
5, // NOLINT
1, // NOLINT
1, // NOLINT
onSuccess,
nullptr);
loopRequestCondition);

boost::asio::io_context ioContext;
boost::asio::co_spawn(ioContext, std::move(task), boost::asio::detached);
Expand Down Expand Up @@ -288,6 +296,7 @@ TEST_F(HttpClientTest, Co_PerformHttpRequest_CallbacksNotCalledIfCannotConnect)
onUnauthorized,
5, // NOLINT
1, // NOLINT
1, // NOLINT
onSuccess,
nullptr);

Expand Down Expand Up @@ -327,6 +336,12 @@ TEST_F(HttpClientTest, Co_PerformHttpRequest_OnSuccessNotCalledIfAsyncWriteFails
unauthorizedCalled = true;
};

auto loopCondition = true;
std::function<bool()> loopRequestCondition = [&loopCondition]()
{
return std::exchange(loopCondition, false);
};

const auto reqParams =
http_client::HttpRequestParams(boost::beast::http::verb::get, "https://localhost:8080", "/", "Wazuh 5.0.0");
auto task = client->Co_PerformHttpRequest(std::make_shared<std::string>("token"),
Expand All @@ -335,8 +350,9 @@ TEST_F(HttpClientTest, Co_PerformHttpRequest_OnSuccessNotCalledIfAsyncWriteFails
onUnauthorized,
5, // NOLINT
1, // NOLINT
1, // NOLINT
onSuccess,
nullptr);
loopRequestCondition);

boost::asio::io_context ioContext;
boost::asio::co_spawn(ioContext, std::move(task), boost::asio::detached);
Expand Down Expand Up @@ -376,6 +392,12 @@ TEST_F(HttpClientTest, Co_PerformHttpRequest_OnSuccessNotCalledIfAsyncReadFails)
unauthorizedCalled = true;
};

auto loopCondition = true;
std::function<bool()> loopRequestCondition = [&loopCondition]()
{
return std::exchange(loopCondition, false);
};

const auto reqParams =
http_client::HttpRequestParams(boost::beast::http::verb::get, "https://localhost:8080", "/", "Wazuh 5.0.0");
auto task = client->Co_PerformHttpRequest(std::make_shared<std::string>("token"),
Expand All @@ -384,8 +406,9 @@ TEST_F(HttpClientTest, Co_PerformHttpRequest_OnSuccessNotCalledIfAsyncReadFails)
onUnauthorized,
5, // NOLINT
1, // NOLINT
1, // NOLINT
onSuccess,
nullptr);
loopRequestCondition);

boost::asio::io_context ioContext;
boost::asio::co_spawn(ioContext, std::move(task), boost::asio::detached);
Expand Down Expand Up @@ -424,6 +447,12 @@ TEST_F(HttpClientTest, Co_PerformHttpRequest_UnauthorizedCalledWhenAuthorization
unauthorizedCalled = true;
};

auto loopCondition = true;
std::function<bool()> loopRequestCondition = [&loopCondition]()
{
return std::exchange(loopCondition, false);
};

const auto reqParams =
http_client::HttpRequestParams(boost::beast::http::verb::get, "https://localhost:8080", "/", "Wazuh 5.0.0");
auto task = client->Co_PerformHttpRequest(std::make_shared<std::string>("token"),
Expand All @@ -432,8 +461,9 @@ TEST_F(HttpClientTest, Co_PerformHttpRequest_UnauthorizedCalledWhenAuthorization
onUnauthorized,
5, // NOLINT
1, // NOLINT
1, // NOLINT
onSuccess,
nullptr);
loopRequestCondition);

boost::asio::io_context ioContext;
boost::asio::co_spawn(ioContext, std::move(task), boost::asio::detached);
Expand Down
3 changes: 2 additions & 1 deletion src/agent/communicator/tests/mocks/mock_http_client.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@ class MockHttpClient : public http_client::IHttpClient
std::function<void()> onUnauthorized,
std::time_t connectionRetry,
std::time_t batchInterval,
std::function<void(const std::string&)> onSuccess,
int batchSize,
std::function<void(const int, const std::string&)> onSuccess,
std::function<bool()> loopRequestCondition),
(override));

Expand Down

0 comments on commit 7ab3e96

Please sign in to comment.