Skip to content

Commit

Permalink
feat: include full certificate chain (#114)
Browse files Browse the repository at this point in the history
Includes the full cert chain provided by CDA
when writing the cert file to the EMQX
work directory.
  • Loading branch information
jcosentino11 authored Oct 13, 2022
1 parent 8118d4b commit 89797d0
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 17 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@ class CertificateUpdatesHandler : public GG::SubscribeToCertificateUpdatesStream

// TODO: move these out of public
void OnStreamEvent(GG::CertificateUpdateEvent *response) override;
CertWriteStatus writeCertsToFiles(const Aws::Crt::String &privateKeyValue, const Aws::Crt::String &certValue);
CertWriteStatus writeCertsToFiles(const Aws::Crt::String &privateKeyValue, const Aws::Crt::String &certValue,
const Aws::Crt::Vector<Aws::Crt::String> &caCerts);

private:
const std::unique_ptr<std::filesystem::path> basePath;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@ static const std::filesystem::path EMQX_KEY_PATH = std::filesystem::path{"key.pe
static const std::filesystem::path EMQX_PEM_PATH = std::filesystem::path{"cert.pem"};

CertWriteStatus CertificateUpdatesHandler::writeCertsToFiles(const Aws::Crt::String &privateKeyValue,
const Aws::Crt::String &certValue) {
const Aws::Crt::String &certValue,
const Aws::Crt::Vector<Aws::Crt::String> &caCerts) {

if (!basePath) {
return CertWriteStatus::WRITE_ERROR_BASE_PATH;
Expand All @@ -37,8 +38,12 @@ CertWriteStatus CertificateUpdatesHandler::writeCertsToFiles(const Aws::Crt::Str
out_key_path << privateKeyValue.c_str();
out_key_path.close();

// write the entire cert chain as a pem file
auto out_pem_path = std::ofstream(path / EMQX_PEM_PATH);
out_pem_path << certValue.c_str();
for (const auto &caCert : caCerts) {
out_pem_path << std::endl << caCert.c_str();
}
out_pem_path.close();
} catch (std::exception &e) {
// TODO: unit test for this branch
Expand Down Expand Up @@ -81,7 +86,7 @@ void CertificateUpdatesHandler::OnStreamEvent(GG::CertificateUpdateEvent *respon
return;
}

const CertWriteStatus writeStatus = writeCertsToFiles(privateKey.value(), cert.value());
const CertWriteStatus writeStatus = writeCertsToFiles(privateKey.value(), cert.value(), allCAs.value());
if (writeStatus != CertWriteStatus::WRITE_SUCCESS) {
LOG_E(CERT_UPDATER_SUBJECT, "Failed to write certificates to files with code %d", (int)writeStatus);
return;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
#include <gtest/gtest.h>
#include <iostream>

#include "cda_integration.h"
#include "private/certificate_updater.h"
#include "test_utils.hpp"

Expand All @@ -24,22 +23,21 @@ class CertificateUpdatesHandlerTester : public ::testing::Test {
virtual void SetUp();
virtual void TearDown();

protected:
std::unique_ptr<std::filesystem::path> testPath = std::make_unique<std::filesystem::path>(filePath);

CertificateUpdatesHandler *handler;
CertificateUpdate testCertUpdate;
Optional<CertificateUpdate> optionalTestCertUpdate;
CertificateUpdateEvent *testResponse;
Vector<String> cas;
Vector<String> caCerts;
};

void CertificateUpdatesHandlerTester::SetUp() {
delete_certs();
testCertUpdate = CertificateUpdate();
optionalTestCertUpdate = Optional<CertificateUpdate>(testCertUpdate);
testResponse = new CertificateUpdateEvent();
cas = Vector<String>();
cas.emplace_back(testCACert);
caCerts = Vector<String>(testCACerts.begin(), testCACerts.end());
}

void CertificateUpdatesHandlerTester::TearDown() {
Expand Down Expand Up @@ -70,20 +68,26 @@ TEST_F(CertificateUpdatesHandlerTester, OnStreamEventTestNoUpdate) {
TEST_F(CertificateUpdatesHandlerTester, OnStreamEventTestValidResponse) {
optionalTestCertUpdate->SetPrivateKey(testPrivateKey.c_str());
optionalTestCertUpdate->SetCertificate(testCert.c_str());
optionalTestCertUpdate->SetCaCertificates(cas);
optionalTestCertUpdate->SetCaCertificates(caCerts);
testResponse->SetCertificateUpdate(optionalTestCertUpdate.value());

auto subscription_callback = std::make_unique<std::function<void(CertificateUpdateEvent *)>>(
[](CertificateUpdateEvent *) { std::cout << "callback" << std::endl; });
handler = new CertificateUpdatesHandler(std::move(testPath), std::move(subscription_callback));
handler->OnStreamEvent(testResponse);
EXPECT_TRUE(std::filesystem::exists(privateKeyFilePath));
EXPECT_EQ(readLines(privateKeyFilePath).front(), testPrivateKey);

auto expectedCertChain = std::vector<std::string>(testCACerts);
expectedCertChain.insert(expectedCertChain.begin(), testCert);

EXPECT_TRUE(std::filesystem::exists(certFilePath));
EXPECT_EQ(readLines(certFilePath), expectedCertChain);
}

TEST_F(CertificateUpdatesHandlerTester, OnStreamEventTestNoPrivateKey) {
optionalTestCertUpdate->SetCertificate(testCert.c_str());
optionalTestCertUpdate->SetCaCertificates(cas);
optionalTestCertUpdate->SetCaCertificates(caCerts);
testResponse->SetCertificateUpdate(optionalTestCertUpdate.value());

handler = new CertificateUpdatesHandler(std::move(testPath), nullptr);
Expand All @@ -94,7 +98,7 @@ TEST_F(CertificateUpdatesHandlerTester, OnStreamEventTestNoPrivateKey) {

TEST_F(CertificateUpdatesHandlerTester, OnStreamEventTestNoCert) {
optionalTestCertUpdate->SetPrivateKey(testPrivateKey.c_str());
optionalTestCertUpdate->SetCaCertificates(cas);
optionalTestCertUpdate->SetCaCertificates(caCerts);
testResponse->SetCertificateUpdate(optionalTestCertUpdate.value());

handler = new CertificateUpdatesHandler(std::move(testPath), nullptr);
Expand All @@ -117,7 +121,7 @@ TEST_F(CertificateUpdatesHandlerTester, OnStreamEventTestNoCaCert) {
TEST_F(CertificateUpdatesHandlerTester, OnStreamEventTestInvalidCertWrite) {
optionalTestCertUpdate->SetPrivateKey(testPrivateKey.c_str());
optionalTestCertUpdate->SetCertificate(testCert.c_str());
optionalTestCertUpdate->SetCaCertificates(cas);
optionalTestCertUpdate->SetCaCertificates(caCerts);
testResponse->SetCertificateUpdate(optionalTestCertUpdate.value());

handler = new CertificateUpdatesHandler(nullptr, nullptr);
Expand All @@ -132,7 +136,7 @@ TEST_F(CertificateUpdatesHandlerTester, WriteCertsToFilesTestNullBasePath) {
String crtPrivateKeyString(testPrivateKey);
String crtCertString(testCert);

CertWriteStatus retVal = handler->writeCertsToFiles(crtPrivateKeyString, crtCertString);
CertWriteStatus retVal = handler->writeCertsToFiles(crtPrivateKeyString, crtCertString, caCerts);
EXPECT_EQ(retVal, CertWriteStatus::WRITE_ERROR_BASE_PATH);
EXPECT_FALSE(std::filesystem::exists(privateKeyFilePath));
EXPECT_FALSE(std::filesystem::exists(certFilePath));
Expand All @@ -144,10 +148,8 @@ TEST_F(CertificateUpdatesHandlerTester, WriteCertsToFilesTestInvalidDir) {

String crtPrivateKeyString(testPrivateKey);
String crtCertString(testCert);
auto cas = Vector<String>();
cas.emplace_back(testCACert);

CertWriteStatus retVal = handler->writeCertsToFiles(crtPrivateKeyString, crtCertString);
CertWriteStatus retVal = handler->writeCertsToFiles(crtPrivateKeyString, crtCertString, caCerts);
EXPECT_EQ(retVal, CertWriteStatus::WRITE_ERROR_DIR_PATH);
EXPECT_FALSE(std::filesystem::exists(privateKeyFilePath));
EXPECT_FALSE(std::filesystem::exists(certFilePath));
Expand Down
17 changes: 16 additions & 1 deletion port_driver/tests/unit/src/cda_integration/test_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

#include <logger.h>
#include <string>
#include <fstream>
#include <aws/greengrass/GreengrassCoreIpcClient.h>

namespace GG = Aws::Greengrass;
Expand All @@ -16,7 +17,10 @@ static const std::string privateKeyFilePath = filePath + "/key.pem";
static const std::string certFilePath = filePath + "/cert.pem";
static const std::string testPrivateKey = "testPrivateKey";
static const std::string testCert = "testCert";
static const std::string testCACert = "testCACert";
static const std::vector<std::string> testCACerts = {
"intermediate",
"root"
};

[[maybe_unused]]
static struct aws_logger our_logger {};
Expand All @@ -26,6 +30,17 @@ static struct aws_logger_standard_options logger_options = {
.file = stderr,
};

[[maybe_unused]]
static std::vector<std::string> readLines(std::string filename) {
std::ifstream file(filename);
std::vector<std::string> lines;
std::string line;
while (std::getline(file, line)) {
lines.push_back(line);
}
return lines;
}

static const void delete_file(std::string fileName) {
if (!std::filesystem::remove(fileName)) {
std::cout << "Failed to delete " << fileName << std::endl;
Expand Down

0 comments on commit 89797d0

Please sign in to comment.