Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adapt command structure to new definition #408

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,7 @@ namespace centralized_configuration
if (m_setGroupIdFunction && m_downloadGroupFilesFunction && m_validateFileFunction &&
m_reloadModulesFunction)
{
if (parameters.empty())
{
LogWarn("Group set failed, no group list");
co_return module_command::CommandExecutionResult {
module_command::Status::FAILURE,
"CentralizedConfiguration group set failed, no group list"};
}

groupIds = parameters[0].get<std::vector<std::string>>();
groupIds = parameters.get<std::vector<std::string>>();

if (!m_setGroupIdFunction(groupIds))
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,31 +74,6 @@ TEST(CentralizedConfiguration, ExecuteCommandReturnsFailureOnUnrecognizedCommand
io_context.run();
}

TEST(CentralizedConfiguration, ExecuteCommandReturnsFailureOnEmptyList)
{
boost::asio::io_context io_context;

boost::asio::co_spawn(
io_context,
[]() -> boost::asio::awaitable<void>
{
CentralizedConfiguration centralizedConfiguration;
centralizedConfiguration.SetGroupIdFunction([](const std::vector<std::string>&) { return true; });
centralizedConfiguration.SetDownloadGroupFilesFunction([](const std::string&, const std::string&)
{ return true; });
centralizedConfiguration.ValidateFileFunction([](const std::filesystem::path&) { return true; });
centralizedConfiguration.ReloadModulesFunction([]() {});
co_await TestExecuteCommand(centralizedConfiguration,
"set-group",
{},
module_command::Status::FAILURE,
"CentralizedConfiguration group set failed, no group list");
}(),
boost::asio::detached);

io_context.run();
}

TEST(CentralizedConfiguration, ExecuteCommandReturnsFailureOnParseParameters)
{
boost::asio::io_context io_context;
Expand All @@ -114,7 +89,7 @@ TEST(CentralizedConfiguration, ExecuteCommandReturnsFailureOnParseParameters)
centralizedConfiguration.ValidateFileFunction([](const std::filesystem::path&) { return true; });
centralizedConfiguration.ReloadModulesFunction([]() {});

const std::vector<std::string> parameterList = {true, "group2"};
const nlohmann::json parameterList = nlohmann::json::parse(R"([true, "group2"])");
co_await TestExecuteCommand(centralizedConfiguration,
"set-group",
parameterList,
Expand Down Expand Up @@ -153,7 +128,7 @@ TEST(CentralizedConfiguration, ExecuteCommandHandlesRecognizedCommands)
centralizedConfiguration.ValidateFileFunction([](const std::filesystem::path&) { return true; });
centralizedConfiguration.ReloadModulesFunction([]() {});

const nlohmann::json groupsList = nlohmann::json::parse(R"([["group1", "group2"]])");
const nlohmann::json groupsList = nlohmann::json::parse(R"(["group1", "group2"])");

co_await TestExecuteCommand(centralizedConfiguration,
"set-group",
Expand Down Expand Up @@ -199,7 +174,7 @@ TEST(CentralizedConfiguration, SetFunctionsAreCalledAndReturnsCorrectResultsForS

CentralizedConfiguration centralizedConfiguration(std::move(mockFileSystem));

const nlohmann::json groupsList = nlohmann::json::parse(R"([["group1", "group2"]])");
const nlohmann::json groupsList = nlohmann::json::parse(R"(["group1", "group2"])");

bool wasSetGroupIdFunctionCalled = false;
bool wasDownloadGroupFilesFunctionCalled = false;
Expand Down
25 changes: 25 additions & 0 deletions src/agent/command_handler/include/command_handler.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,19 @@ namespace command_handler
continue;
}

if (!CheckCommand(cmd.value()))
{
cmd.value().ExecutionResult.ErrorCode = module_command::Status::FAILURE;
cmd.value().ExecutionResult.Message = "Command is not valid";
LogError("Error checking module and args for command: {} {}. Error: {}",
cmd.value().Id,
cmd.value().Command,
cmd.value().ExecutionResult.Message);
ReportCommandResult(cmd.value());
PopCommandFromQueue();
continue;
}

if (!m_commandStore.StoreCommand(cmd.value()))
{
cmd.value().ExecutionResult.ErrorCode = module_command::Status::FAILURE;
Expand Down Expand Up @@ -118,6 +131,18 @@ namespace command_handler
}
}

/// @brief Check if the command is valid
///
/// This function checks if the given command is valid by looking it up in
/// the map of valid commands. If the command is valid, it checks if the
/// parameters are valid. If the command is not valid, it logs an error and
/// returns false. If the command is valid, it sets the module for the
/// command and returns true.
///
/// @param cmd The command to check
/// @return True if the command is valid, false otherwise
bool CheckCommand(module_command::CommandEntry& cmd);

/// @brief Indicates whether the command handler is running or not
std::atomic<bool> m_keepRunning = true;

Expand Down
30 changes: 30 additions & 0 deletions src/agent/command_handler/src/command_handler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,38 @@

namespace command_handler
{
const std::unordered_map<std::string, std::string> VALID_COMMANDS_MAP = {
{"set-group", "CentralizedConfiguration"}, {"update-group", "CentralizedConfiguration"}};

void CommandHandler::Stop()
{
m_keepRunning.store(false);
}

bool CommandHandler::CheckCommand(module_command::CommandEntry& cmd)
{
auto it = VALID_COMMANDS_MAP.find(cmd.Command);
if (it != VALID_COMMANDS_MAP.end())
{
if (!cmd.Parameters.empty())
{
for (const auto& param : cmd.Parameters.items())
{
if (!param.value().is_string() || param.value().get<std::string>().empty())
{
LogError("The command {} parameters must be non-empty strings.", cmd.Command);
return false;
}
}
}

cmd.Module = it->second;
return true;
}
else
{
LogError("The command {} is not valid.", cmd.Command);
return false;
}
}
} // namespace command_handler
4 changes: 2 additions & 2 deletions src/agent/src/agent.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ void Agent::ReloadModules()
m_configurationParser->ReloadConfiguration();
m_moduleManager.Stop();
m_moduleManager.Setup();
m_taskManager.EnqueueTask([this]() { m_moduleManager.Start(); }, "StartModuleManager");
m_moduleManager.Start();
}

void Agent::Run()
Expand Down Expand Up @@ -122,7 +122,7 @@ void Agent::Run()
"Stateless");

m_moduleManager.AddModules();
m_taskManager.EnqueueTask([this]() { m_moduleManager.Start(); }, "StartModuleManager");
m_moduleManager.Start();

m_taskManager.EnqueueTask(
m_commandHandler.CommandsProcessingTask<module_command::CommandEntry>(
Expand Down
41 changes: 19 additions & 22 deletions src/agent/src/command_handler_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,48 +13,51 @@ namespace
template<typename ExecuteFunction>
boost::asio::awaitable<void> ExecuteCommandTask(ExecuteFunction executeFunction,
module_command::CommandEntry commandEntry,
std::shared_ptr<module_command::CommandExecutionResult> result,
std::shared_ptr<bool> commandCompleted,
std::shared_ptr<boost::asio::steady_timer> timer)
std::shared_ptr<module_command::CommandExecutionResult> result)
{
try
{
*result = co_await executeFunction(commandEntry.Command, commandEntry.Parameters);
*commandCompleted = true;
timer->cancel();
}
catch (const std::exception& e)
{
result->ErrorCode = module_command::Status::FAILURE;
result->Message = "Error during command execution: " + std::string(e.what());
if (result->ErrorCode == module_command::Status::UNKNOWN)
{
result->ErrorCode = module_command::Status::FAILURE;
result->Message = "Error during command execution: " + std::string(e.what());
}
}
}

boost::asio::awaitable<void> TimerTask(std::shared_ptr<boost::asio::steady_timer> timer,
std::shared_ptr<module_command::CommandExecutionResult> result,
std::shared_ptr<bool> commandCompleted)
boost::asio::awaitable<void> TimerTask(std::shared_ptr<module_command::CommandExecutionResult> result)
{
try
{
constexpr auto timeout = std::chrono::minutes {60};
auto timer = std::make_shared<boost::asio::steady_timer>(co_await boost::asio::this_coro::executor);
timer->expires_after(timeout);
co_await timer->async_wait(boost::asio::use_awaitable);

if (!(*commandCompleted))
if (result->ErrorCode == module_command::Status::UNKNOWN)
{
result->ErrorCode = module_command::Status::TIMEOUT;
result->Message = "Command timed out";
}
}
catch (const boost::system::system_error& e)
{
if (!(*commandCompleted) && e.code() != boost::asio::error::operation_aborted)
if (e.code() != boost::asio::error::operation_aborted)
{
result->ErrorCode = module_command::Status::FAILURE;
result->Message = "System error: " + std::string(e.what());
if (result->ErrorCode == module_command::Status::UNKNOWN)
{
result->ErrorCode = module_command::Status::FAILURE;
result->Message = "System error: " + std::string(e.what());
}
}
}
catch (const std::exception& e)
{
if (!(*commandCompleted))
if (result->ErrorCode == module_command::Status::UNKNOWN)
{
result->ErrorCode = module_command::Status::FAILURE;
result->Message = "Unexpected error: " + std::string(e.what());
Expand All @@ -73,15 +76,9 @@ DispatchCommand(module_command::CommandEntry commandEntry,

LogInfo("Dispatching command {}({})", commandEntry.Command, commandEntry.Module);

const auto timeout = std::chrono::minutes(60);
auto timer = std::make_shared<boost::asio::steady_timer>(co_await boost::asio::this_coro::executor);
timer->expires_after(timeout);

auto result = std::make_shared<module_command::CommandExecutionResult>();
auto commandCompleted = std::make_shared<bool>(false);

co_await (TimerTask(timer, result, commandCompleted) ||
ExecuteCommandTask(executeFunction, commandEntry, result, commandCompleted, timer));
co_await (TimerTask(result) || ExecuteCommandTask(executeFunction, commandEntry, result));

commandEntry.ExecutionResult.ErrorCode = result->ErrorCode;
commandEntry.ExecutionResult.Message = result->Message;
Expand Down
29 changes: 11 additions & 18 deletions src/agent/src/message_queue_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,37 +62,30 @@ std::optional<module_command::CommandEntry> GetCommandFromQueue(std::shared_ptr<
nlohmann::json jsonData = m.data;

std::string id;
std::string module;
std::string command;
nlohmann::json parameters = nlohmann::json::array();

if (jsonData.contains("id") && jsonData["id"].is_string())
if (jsonData.contains("document_id") && jsonData["document_id"].is_string())
{
id = jsonData["id"].get<std::string>();
id = jsonData["document_id"].get<std::string>();
}

if (jsonData.contains("args") && jsonData["args"].is_array())
if (jsonData.contains("action") && jsonData["action"].is_object())
{
int index = 0;
for (const auto& arg : jsonData["args"])
if (jsonData["action"].contains("name") && jsonData["action"]["name"].is_string())
{
switch (index++)
command = jsonData["action"]["name"].get<std::string>();
}
if (jsonData["action"].contains("args") && jsonData["action"]["args"].is_array())
{
for (const auto& arg : jsonData["action"]["args"])
{
case 0:
if (arg.is_string())
module = arg.get<std::string>();
break;
case 1:
if (arg.is_string())
command = arg.get<std::string>();
break;
default: parameters.push_back(arg); break;
parameters.push_back(arg);
}
}
}

module_command::CommandEntry cmd(id, module, command, parameters, "", module_command::Status::IN_PROGRESS);

module_command::CommandEntry cmd(id, "", command, parameters, "", module_command::Status::IN_PROGRESS);
TomasTurina marked this conversation as resolved.
Show resolved Hide resolved
return cmd;
}

Expand Down
5 changes: 2 additions & 3 deletions src/agent/tests/message_queue_utils_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
#include <gtest/gtest.h>
#include <nlohmann/json.hpp>

const nlohmann::json BASE_DATA_CONTENT = R"({"id":"112233", "args": ["origin_test",
"command_test", "parameters_test"]})"_json;
const nlohmann::json BASE_DATA_CONTENT =
R"({"document_id":"112233", "action":{"name":"command_test","args":["parameters_test"]}})"_json;

class MockMultiTypeQueue : public MultiTypeQueue
{
Expand Down Expand Up @@ -198,7 +198,6 @@ TEST_F(MessageQueueUtilsTest, GetCommandFromQueueTest)
auto cmd = GetCommandFromQueue(mockQueue);

ASSERT_EQ(cmd.has_value() ? cmd.value().Id : "", "112233");
ASSERT_EQ(cmd.has_value() ? cmd.value().Module : "", "origin_test");
ASSERT_EQ(cmd.has_value() ? cmd.value().Command : "", "command_test");
ASSERT_EQ(cmd.has_value() ? cmd.value().Parameters : nlohmann::json::array({""}),
nlohmann::json::array({"parameters_test"}));
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
{"commands":[{"id":"id","status":"sent","info":"string","args":["module","command",["arg1"]],"agent":{"id":"agentID"}}]}
{"commands":[{"document_id":"id","action":{"name":"command","version":"version","args":["arg1","arg2"]},"target":{"type":"type","id":"agentID"},"status":"sent"}]}

# CentralizedConfiguration set-group
{"commands":[{"id":"id1234","status":"sent","info":"string","args":["CentralizedConfiguration","set-group",["validYaml", "invalidYaml"]],"agent":{"id":"agentID"}}]}
{"commands":[{"document_id":"id","action":{"name":"set-group","version":"v5.0.0","args":["validYaml","invalidYaml"]},"target":{"type":"agent","id":"agentID"},"status":"sent"}]}

# CentralizedConfiguration update-group
{"commands":[{"id":"id123456","status":"sent","info":"string","args":["CentralizedConfiguration","update-group"],"agent":{"id":"agentID"}}]}
{"commands":[{"document_id":"id","action":{"name":"update-group","version":"v5.0.0","args":[]},"target":{"type":"agent","id":"agentID"},"status":"sent"}]}
3 changes: 2 additions & 1 deletion src/tests/mock-server/config/commands-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
---
plugin: rest
resources:

- path: "/api/v1/commands"
method: GET
response:
statusCode: 401
scriptFile: commands.groovy
50 changes: 50 additions & 0 deletions src/tests/mock-server/config/commands.groovy
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import groovy.json.JsonOutput
import java.time.Instant

def generateUUIDv7() {
def now = Instant.now().toEpochMilli()
def timeHex = String.format("%012x", now)
def randomBits = UUID.randomUUID().toString().replace("-", "").substring(12)
return "${timeHex.substring(0, 8)}-${timeHex.substring(8)}-7${randomBits.substring(0, 3)}-" +
"${randomBits.substring(3, 7)}-${randomBits.substring(7)}"

}

def actions = [
["name": "set-group", "version": "v5.0.0", "args": ["validYaml", "invalidYaml"]],
["name": "set-group", "version": "v5.0.0", "args": []],
["name": "set-group", "version": "v5.0.0", "args": [""]],
["name": "set-group", "version": "v5.0.0", "args": ["validYaml", 8]],
["name": "set-group", "version": "v5.0.0", "args": ["validYaml"]],
["name": "update-group", "version": "v5.0.0", "args": ["noNeedArgs"]],
["name": "update-group", "version": "v5.0.0"],
["name": "update-group", "version": "v5.0.0", "args": ""]
]

def numCommands = new Random().nextInt(3)
jr0me marked this conversation as resolved.
Show resolved Hide resolved

def commands = []
if (numCommands > 0) {
for (int i = 0; i < numCommands; i++) {
def action = actions[new Random().nextInt(actions.size())]
def command = [
"document_id": generateUUIDv7(),
"action": action,
"target": ["type": "agent", "id": "agentID"],
TomasTurina marked this conversation as resolved.
Show resolved Hide resolved
"status": "sent"
]
commands << command
}
}

if (commands.isEmpty()) {
respond {
withStatusCode(408)
}
} else {
def jsonResponse = JsonOutput.toJson(["commands": commands])
respond {
withStatusCode(200)
withContent(jsonResponse)
}
}
Loading
Loading