diff --git a/src/agent/CMakeLists.txt b/src/agent/CMakeLists.txt index 97cd7adc0d..40a2249e37 100644 --- a/src/agent/CMakeLists.txt +++ b/src/agent/CMakeLists.txt @@ -12,9 +12,10 @@ project(Agent) set(SOURCES src/agent.cpp - src/task_manager.cpp + src/message_queue_utils.cpp src/register.cpp src/signal_handler.cpp + src/task_manager.cpp ) add_subdirectory(agent_info) diff --git a/src/agent/communicator/include/communicator.hpp b/src/agent/communicator/include/communicator.hpp index 5d1f028921..c7f7efd310 100644 --- a/src/agent/communicator/include/communicator.hpp +++ b/src/agent/communicator/include/communicator.hpp @@ -20,9 +20,13 @@ namespace communicator const std::function GetStringConfigValue); boost::asio::awaitable WaitForTokenExpirationAndAuthenticate(); - boost::asio::awaitable GetCommandsFromManager(std::queue& messageQueue); - boost::asio::awaitable StatefulMessageProcessingTask(std::queue& messageQueue); - boost::asio::awaitable StatelessMessageProcessingTask(std::queue& messageQueue); + boost::asio::awaitable GetCommandsFromManager(std::function onSuccess); + boost::asio::awaitable + StatefulMessageProcessingTask(std::function()> getMessages, + std::function onSuccess); + boost::asio::awaitable + StatelessMessageProcessingTask(std::function()> getMessages, + std::function onSuccess); private: long GetTokenRemainingSecs() const; diff --git a/src/agent/communicator/include/http_client.hpp b/src/agent/communicator/include/http_client.hpp index 0cc45e1b6b..3d3c561800 100644 --- a/src/agent/communicator/include/http_client.hpp +++ b/src/agent/communicator/include/http_client.hpp @@ -46,11 +46,12 @@ namespace http_client std::function onUnauthorized, std::function onSuccess = {}); - boost::asio::awaitable Co_MessageProcessingTask(const std::string& token, - HttpRequestParams params, - std::function messageGetter, - std::function onUnauthorized, - std::function onSuccess = {}); + boost::asio::awaitable + Co_MessageProcessingTask(const std::string& token, + HttpRequestParams params, + std::function()> messageGetter, + std::function onUnauthorized, + std::function onSuccess = {}); boost::beast::http::response PerformHttpRequest(const HttpRequestParams& params); diff --git a/src/agent/communicator/src/communicator.cpp b/src/agent/communicator/src/communicator.cpp index f7446bb48e..e48796f56e 100644 --- a/src/agent/communicator/src/communicator.cpp +++ b/src/agent/communicator/src/communicator.cpp @@ -67,7 +67,7 @@ namespace communicator return std::max(0L, static_cast(m_tokenExpTimeInSeconds - now_seconds)); } - boost::asio::awaitable Communicator::GetCommandsFromManager(std::queue& messageQueue) + boost::asio::awaitable Communicator::GetCommandsFromManager(std::function onSuccess) { auto onAuthenticationFailed = [this]() { @@ -75,7 +75,7 @@ namespace communicator }; const auto reqParams = http_client::HttpRequestParams(boost::beast::http::verb::get, m_managerIp, m_port, "/commands"); - co_await http_client::Co_MessageProcessingTask(m_token, reqParams, {}, onAuthenticationFailed); + co_await http_client::Co_MessageProcessingTask(m_token, reqParams, {}, onAuthenticationFailed, onSuccess); } boost::asio::awaitable Communicator::WaitForTokenExpirationAndAuthenticate() @@ -119,7 +119,9 @@ namespace communicator } } - boost::asio::awaitable Communicator::StatefulMessageProcessingTask(std::queue& messageQueue) + boost::asio::awaitable + Communicator::StatefulMessageProcessingTask(std::function()> getMessages, + std::function onSuccess) { auto onAuthenticationFailed = [this]() { @@ -127,10 +129,13 @@ namespace communicator }; const auto reqParams = http_client::HttpRequestParams(boost::beast::http::verb::post, m_managerIp, m_port, "/stateful"); - co_await http_client::Co_MessageProcessingTask(m_token, reqParams, {}, onAuthenticationFailed); + co_await http_client::Co_MessageProcessingTask( + m_token, reqParams, getMessages, onAuthenticationFailed, onSuccess); } - boost::asio::awaitable Communicator::StatelessMessageProcessingTask(std::queue& messageQueue) + boost::asio::awaitable + Communicator::StatelessMessageProcessingTask(std::function()> getMessages, + std::function onSuccess) { auto onAuthenticationFailed = [this]() { @@ -138,7 +143,8 @@ namespace communicator }; const auto reqParams = http_client::HttpRequestParams(boost::beast::http::verb::post, m_managerIp, m_port, "/stateless"); - co_await http_client::Co_MessageProcessingTask(m_token, reqParams, {}, onAuthenticationFailed); + co_await http_client::Co_MessageProcessingTask( + m_token, reqParams, getMessages, onAuthenticationFailed, onSuccess); } void Communicator::TryReAuthenticate() diff --git a/src/agent/communicator/src/http_client.cpp b/src/agent/communicator/src/http_client.cpp index 3a1f988b53..e8336b1160 100644 --- a/src/agent/communicator/src/http_client.cpp +++ b/src/agent/communicator/src/http_client.cpp @@ -96,11 +96,12 @@ namespace http_client std::cout << "Response body: " << boost::beast::buffers_to_string(res.body().data()) << std::endl; } - boost::asio::awaitable Co_MessageProcessingTask(const std::string& token, - HttpRequestParams reqParams, - std::function messageGetter, - std::function onUnauthorized, - std::function onSuccess) + boost::asio::awaitable + Co_MessageProcessingTask(const std::string& token, + HttpRequestParams reqParams, + std::function()> messageGetter, + std::function onUnauthorized, + std::function onSuccess) { using namespace std::chrono_literals; @@ -131,7 +132,7 @@ namespace http_client if (messageGetter != nullptr) { - reqParams.body = messageGetter(); + reqParams.body = co_await messageGetter(); } else { diff --git a/src/agent/include/agent.hpp b/src/agent/include/agent.hpp index c9f0c0cea3..020ccd9395 100644 --- a/src/agent/include/agent.hpp +++ b/src/agent/include/agent.hpp @@ -3,10 +3,10 @@ #include #include #include +#include #include #include -#include #include class Agent @@ -18,7 +18,7 @@ class Agent void Run(); private: - std::queue m_messageQueue; + MultiTypeQueue m_messageQueue; SignalHandler m_signalHandler; TaskManager m_taskManager; diff --git a/src/agent/include/message_queue_utils.hpp b/src/agent/include/message_queue_utils.hpp new file mode 100644 index 0000000000..f71aff6906 --- /dev/null +++ b/src/agent/include/message_queue_utils.hpp @@ -0,0 +1,15 @@ +#pragma once + +#include + +#include + +#include + +class IMultiTypeQueue; + +boost::asio::awaitable getMessagesFromQueue(IMultiTypeQueue& multiTypeQueue, MessageType messageType); + +void popMessagesFromQueue(IMultiTypeQueue& multiTypeQueue, MessageType messageType); + +void pushCommandsToQueue(IMultiTypeQueue& multiTypeQueue, const std::string& commands); diff --git a/src/agent/multitype_queue/include/message.hpp b/src/agent/multitype_queue/include/message.hpp index 85c49ebeda..83f515a5f8 100644 --- a/src/agent/multitype_queue/include/message.hpp +++ b/src/agent/multitype_queue/include/message.hpp @@ -31,4 +31,10 @@ class Message , moduleName(mN) { } + + // Define equality operator + bool operator==(const Message& other) const + { + return type == other.type && data == other.data; + } }; diff --git a/src/agent/src/agent.cpp b/src/agent/src/agent.cpp index 1e0ede2169..f5f2d281f9 100644 --- a/src/agent/src/agent.cpp +++ b/src/agent/src/agent.cpp @@ -1,6 +1,11 @@ #include +#include +#include + +#include #include +#include Agent::Agent() : m_communicator(m_agentInfo.GetUUID(), @@ -18,9 +23,17 @@ Agent::~Agent() void Agent::Run() { m_taskManager.EnqueueTask(m_communicator.WaitForTokenExpirationAndAuthenticate()); - m_taskManager.EnqueueTask(m_communicator.GetCommandsFromManager(m_messageQueue)); - m_taskManager.EnqueueTask(m_communicator.StatefulMessageProcessingTask(m_messageQueue)); - m_taskManager.EnqueueTask(m_communicator.StatelessMessageProcessingTask(m_messageQueue)); + + m_taskManager.EnqueueTask(m_communicator.GetCommandsFromManager( + [this](const std::string& response) { pushCommandsToQueue(m_messageQueue, response); })); + + m_taskManager.EnqueueTask(m_communicator.StatefulMessageProcessingTask( + [this]() { return getMessagesFromQueue(m_messageQueue, STATEFUL); }, + [this]([[maybe_unused]] const std::string& response) { popMessagesFromQueue(m_messageQueue, STATEFUL); })); + + m_taskManager.EnqueueTask(m_communicator.StatelessMessageProcessingTask( + [this]() { return getMessagesFromQueue(m_messageQueue, STATELESS); }, + [this]([[maybe_unused]] const std::string& response) { popMessagesFromQueue(m_messageQueue, STATELESS); })); m_signalHandler.WaitForSignal(); } diff --git a/src/agent/src/message_queue_utils.cpp b/src/agent/src/message_queue_utils.cpp new file mode 100644 index 0000000000..c23776c8f1 --- /dev/null +++ b/src/agent/src/message_queue_utils.cpp @@ -0,0 +1,49 @@ +#include + +#include + +#include + +#include + +namespace +{ + // This should eventually be replaced with a configuration parameter. + constexpr int NUM_EVENTS = 1; +} // namespace + +boost::asio::awaitable getMessagesFromQueue(IMultiTypeQueue& multiTypeQueue, MessageType messageType) +{ + const auto message = co_await multiTypeQueue.getNextNAwaitable(messageType, NUM_EVENTS); + + nlohmann::json jsonObj; + jsonObj["events"] = nlohmann::json::array(); + jsonObj["events"].push_back(message.data); + + co_return jsonObj.dump(); +} + +void popMessagesFromQueue(IMultiTypeQueue& multiTypeQueue, MessageType messageType) +{ + multiTypeQueue.popN(messageType, NUM_EVENTS); +} + +void pushCommandsToQueue(IMultiTypeQueue& multiTypeQueue, const std::string& commands) +{ + const auto jsonObj = nlohmann::json::parse(commands); + + if (jsonObj.contains("commands") && jsonObj["commands"].is_array()) + { + std::vector messages; + + for (const auto& command : jsonObj["commands"]) + { + messages.emplace_back(MessageType::COMMAND, command); + } + + if (!messages.empty()) + { + multiTypeQueue.push(messages); + } + } +} diff --git a/src/agent/tests/CMakeLists.txt b/src/agent/tests/CMakeLists.txt index 89c01feb5f..705ef54acd 100644 --- a/src/agent/tests/CMakeLists.txt +++ b/src/agent/tests/CMakeLists.txt @@ -15,3 +15,7 @@ add_test(NAME RegisterTest COMMAND register_test) add_executable(signal_handler_test signal_handler_test.cpp) target_link_libraries(signal_handler_test PRIVATE Agent GTest::gtest) add_test(NAME SignalHandlerTest COMMAND signal_handler_test) + +add_executable(message_queue_utils_test message_queue_utils_test.cpp) +target_link_libraries(message_queue_utils_test PRIVATE Agent MultiTypeQueue GTest::gtest GTest::gmock) +add_test(NAME MessageQueueUtilsTest COMMAND message_queue_utils_test) diff --git a/src/agent/tests/message_queue_utils_test.cpp b/src/agent/tests/message_queue_utils_test.cpp new file mode 100644 index 0000000000..a41e3a2b7c --- /dev/null +++ b/src/agent/tests/message_queue_utils_test.cpp @@ -0,0 +1,91 @@ +#include + +#include +#include + +#include +#include +#include +#include + +class MockMultiTypeQueue : public MultiTypeQueue +{ +public: + MOCK_METHOD(boost::asio::awaitable, + getNextNAwaitable, + (MessageType, int, const std::string module), + (override)); + MOCK_METHOD(int, popN, (MessageType, int, const std::string module), (override)); + MOCK_METHOD(int, push, (std::vector), (override)); +}; + +class MessageQueueUtilsTest : public ::testing::Test +{ +protected: + boost::asio::io_context io_context; + MockMultiTypeQueue mockQueue; +}; + +TEST_F(MessageQueueUtilsTest, GetMessagesFromQueueTest) +{ + Message testMessage {MessageType::STATEFUL, "test_data"}; + + EXPECT_CALL(mockQueue, getNextNAwaitable(MessageType::STATEFUL, 1, "")) + .WillOnce([this, &testMessage]() -> boost::asio::awaitable { co_return testMessage; }); + + io_context.restart(); + + auto result = boost::asio::co_spawn( + io_context, getMessagesFromQueue(mockQueue, MessageType::STATEFUL), boost::asio::use_future); + + const auto timeout = std::chrono::steady_clock::now() + std::chrono::milliseconds(1); + io_context.run_until(timeout); + + ASSERT_TRUE(result.wait_for(std::chrono::milliseconds(1)) == std::future_status::ready); + + const auto jsonResult = result.get(); + + nlohmann::json expectedJson; + expectedJson["events"] = nlohmann::json::array(); + expectedJson["events"].push_back("test_data"); + + ASSERT_EQ(jsonResult, expectedJson.dump()); +} + +TEST_F(MessageQueueUtilsTest, PopMessagesFromQueueTest) +{ + EXPECT_CALL(mockQueue, popN(MessageType::STATEFUL, 1, "")).Times(1); + popMessagesFromQueue(mockQueue, MessageType::STATEFUL); +} + +TEST_F(MessageQueueUtilsTest, PushCommandsToQueueTest) +{ + nlohmann::json commandsJson; + commandsJson["commands"] = nlohmann::json::array(); + commandsJson["commands"].push_back("command_1"); + commandsJson["commands"].push_back("command_2"); + + std::vector expectedMessages; + expectedMessages.emplace_back(MessageType::COMMAND, "command_1"); + expectedMessages.emplace_back(MessageType::COMMAND, "command_2"); + + EXPECT_CALL(mockQueue, push(::testing::ContainerEq(expectedMessages))).Times(1); + + pushCommandsToQueue(mockQueue, commandsJson.dump()); +} + +TEST_F(MessageQueueUtilsTest, NoCommandsToPushTest) +{ + nlohmann::json commandsJson; + commandsJson["commands"] = nlohmann::json::array(); + + EXPECT_CALL(mockQueue, push(::testing::_)).Times(0); + + pushCommandsToQueue(mockQueue, commandsJson.dump()); +} + +int main(int argc, char** argv) +{ + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +}