Skip to content

Commit

Permalink
Use openSSL to implement TLS communication (#185)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #185

NOTE:This is a relatively complex diff, but it is necessary because we need all functionality to be implemented to be able to test end to end (client setup, server setup, read, write, and destroy). Otherwise, each individual diff would be untestable.

This diff sets up end to end TLS communication between the partner and publisher.
- Use SSL_CTX_use_certificate_file and SSL_CTX_use_Privatekey_file to load credentials.
- Use SSL_accept to listen for handshake requests from a client
- On the client side, use SSL_connect to initiate a handshake
- Both parties use SSL_read and SSL_write to communicate

There are a few other nuances (blocking vs nonblocking reads, passphrases) that are all explained with inline comments.

There are a few NON-goals of this diff
1) We are not testing whether this works on PC infra. That will happen in the future.
2) We are not analyzing performance regressions.

Reviewed By: RuiyuZhu

Differential Revision: D35555496

fbshipit-source-id: d64dc5758eed9cd8e9b27794836fea67b893256a
  • Loading branch information
adshastri authored and facebook-github-bot committed Apr 20, 2022
1 parent 879f86f commit f3d2e03
Show file tree
Hide file tree
Showing 3 changed files with 224 additions and 19 deletions.
199 changes: 183 additions & 16 deletions fbpcf/engine/communication/SocketPartyCommunicationAgent.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,52 +11,118 @@
#include <assert.h>
#include <netdb.h>
#include <netinet/in.h>
#include <openssl/err.h>
#include <openssl/ssl.h>
#include <string.h>
#include <sys/socket.h>
#include <unistd.h>
#include <stdexcept>
#include <string>
#include <cerrno>
#include <fstream>
#include <istream>

#include <folly/String.h>
#include "folly/logging/xlog.h"

namespace fbpcf::engine::communication {

const std::string CERT_FILE = "cert.pem";
const std::string PRIVATE_KEY_FILE = "key.pem";
const std::string PASSPHRASE_FILE = "passphrase.pem";

/*
Per openSSL documentation, this callback is used to provide
the passphrase to open the private key file. See
https://www.openssl.org/docs/manmaster/man3/SSL_CTX_set_default_passwd_cb.html.
*/
static int
passwordCallback(char* buf, int size, int /* rwflag */, void* userdata) {
strncpy(buf, (char*)userdata, size);
buf[size - 1] = '\0';
return strlen((char*)userdata);
}

/*
This function is only used temporarily since we only have self
signed certificates available. In the future, when we implement
a Private CA, this callback should not be used.
*/
static int callbackToSkipVerificationOfSelfSignedCert_UNSAFE(
X509_STORE_CTX* /* ctx */,
void* /* data */) {
return 1; // always pass cert verification
}

SocketPartyCommunicationAgent::SocketPartyCommunicationAgent(
int portNo,
bool useTls,
std::string tlsDir)
: sentData_(0), receivedData_(0) {
openServerPort(portNo);
: sentData_(0), receivedData_(0), ssl_(nullptr) {
if (useTls) {
openServerPortWithTls(portNo, tlsDir);
} else {
openServerPort(portNo);
}
}

SocketPartyCommunicationAgent::SocketPartyCommunicationAgent(
const std::string& serverAddress,
int portNo,
bool useTls,
std::string tlsDir)
: sentData_(0), receivedData_(0) {
openClientPort(serverAddress, portNo);
: sentData_(0), receivedData_(0), ssl_(nullptr) {
if (useTls) {
openClientPortWithTls(serverAddress, portNo, tlsDir);
} else {
openClientPort(serverAddress, portNo);
}
}

SocketPartyCommunicationAgent::~SocketPartyCommunicationAgent() {
fclose(outgoingPort_);
fclose(incomingPort_);
if (!ssl_) {
fclose(outgoingPort_);
fclose(incomingPort_);
} else {
SSL_shutdown(ssl_);
SSL_free(ssl_);
}
}

void SocketPartyCommunicationAgent::send(
const std::vector<unsigned char>& data) {
auto s =
fwrite(data.data(), sizeof(unsigned char), data.size(), outgoingPort_);
assert(s == data.size());
sentData_ += s;
fflush(outgoingPort_);
size_t bytesWritten;
if (!ssl_) {
bytesWritten =
fwrite(data.data(), sizeof(unsigned char), data.size(), outgoingPort_);
} else {
bytesWritten = SSL_write(ssl_, (void*)data.data(), data.size());
}
assert(bytesWritten == data.size());
sentData_ += bytesWritten;
if (!ssl_) {
fflush(outgoingPort_);
}
}

std::vector<unsigned char> SocketPartyCommunicationAgent::receive(size_t size) {
size_t bytesRead = 0;
std::vector<unsigned char> rst(size);
auto s = fread(rst.data(), sizeof(unsigned char), size, incomingPort_);
assert(s == size);
receivedData_ += s;

if (!ssl_) {
bytesRead = fread(rst.data(), sizeof(unsigned char), size, incomingPort_);
} else {
// fread is blocking, but SSL_read is nonblocking. This discrepancy
// can cause issues at the application level. We need to make sure that
// both APIs behave consistently, so here we add a loop to ensure we
// mimick blocking behavior.
while (bytesRead < size) {
bytesRead += SSL_read(
ssl_,
rst.data() + (bytesRead * sizeof(unsigned char)),
size - bytesRead);
}
}
assert(bytesRead == size);
receivedData_ += bytesRead;
return rst;
}

Expand Down Expand Up @@ -97,6 +163,107 @@ void SocketPartyCommunicationAgent::openClientPort(
return;
}

void SocketPartyCommunicationAgent::openServerPortWithTls(
int portNo,
std::string tlsDir) {
LOG(INFO) << "try to connect as server at port " << portNo << " with TLS";
const SSL_METHOD* method;
SSL_CTX* ctx;

method = TLS_server_method();
ctx = SSL_CTX_new(method);

// Set passphrase for reading key.pem
SSL_CTX_set_default_passwd_cb(ctx, passwordCallback);

auto passphrase_file = tlsDir + "/" + PASSPHRASE_FILE;
std::ifstream file_ptr(passphrase_file);
std::string passphrase_string = "";
file_ptr >> passphrase_string;
file_ptr.close();
SSL_CTX_set_default_passwd_cb_userdata(ctx, (void*)passphrase_string.c_str());

if (ctx == nullptr) {
LOG(INFO) << folly::errnoStr(errno);
throw std::runtime_error("Could not create tls context");
}

// Load the certificate file
if (SSL_CTX_use_certificate_file(
ctx, (tlsDir + "/" + CERT_FILE).c_str(), SSL_FILETYPE_PEM) <= 0) {
LOG(INFO) << folly::errnoStr(errno);
throw std::runtime_error("Error using certificate file");
}

// Load the private key file
if (SSL_CTX_use_PrivateKey_file(
ctx, (tlsDir + "/" + PRIVATE_KEY_FILE).c_str(), SSL_FILETYPE_PEM) <=
0) {
LOG(INFO) << folly::errnoStr(errno);
throw std::runtime_error("Error using private key file");
}

auto acceptedConnection = receiveFromClient(portNo);

const auto ssl = SSL_new(ctx);
SSL_set_fd(ssl, acceptedConnection);

// Accept handshake from client
if (SSL_accept(ssl) <= 0) {
LOG(INFO) << folly::errnoStr(errno);
throw std::runtime_error("Error on accepting ssl");
}

LOG(INFO) << "connected as server at port " << portNo << " with TLS";

ssl_ = ssl;
}

void SocketPartyCommunicationAgent::openClientPortWithTls(
const std::string& serverAddress,
int portNo,
std::string /* tls_dir */) {
XLOGF(
INFO,
"try to connect as client to {} at port {} with TLS",
serverAddress,
portNo);
const SSL_METHOD* method = TLS_client_method();
SSL_CTX* ctx = SSL_CTX_new(method);

// set cert verification callback for self signed certs
// comment above has more information
SSL_CTX_set_cert_verify_callback(
ctx, callbackToSkipVerificationOfSelfSignedCert_UNSAFE, nullptr);

if (ctx == nullptr) {
LOG(INFO) << folly::errnoStr(errno);
throw std::runtime_error("could not create tls context");
}

SSL* ssl = SSL_new(ctx);

if (ssl == nullptr) {
LOG(INFO) << folly::errnoStr(errno);
throw std::runtime_error("could not create tls object");
}

const auto sockfd = connectToHost(serverAddress, portNo);

SSL_set_fd(ssl, sockfd);

// initiate handshake with server
const int status = SSL_connect(ssl);
if (status != 1) {
LOG(INFO) << folly::errnoStr(errno);
throw std::runtime_error("could not complete tls handshake");
}

XLOGF(INFO, "connected as client to {} at port {}", serverAddress, portNo);

ssl_ = ssl;
}

int SocketPartyCommunicationAgent::connectToHost(
const std::string& serverAddress,
int portNo) {
Expand Down
9 changes: 9 additions & 0 deletions fbpcf/engine/communication/SocketPartyCommunicationAgent.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

#pragma once

#include <openssl/ssl.h>
#include <string>

#include "fbpcf/engine/communication/IPartyCommunicationAgent.h"
Expand Down Expand Up @@ -60,6 +61,12 @@ class SocketPartyCommunicationAgent final : public IPartyCommunicationAgent {
void openServerPort(int portNo);
void openClientPort(const std::string& serverAddress, int portNo);

void openServerPortWithTls(int portNo, std::string tlsDir);
void openClientPortWithTls(
const std::string& serverAddress,
int portNo,
std::string tlsDir);

/*
* helper functions for shared code between TLS and non-TLS implementations
*/
Expand All @@ -70,6 +77,8 @@ class SocketPartyCommunicationAgent final : public IPartyCommunicationAgent {

uint64_t sentData_;
uint64_t receivedData_;

SSL* ssl_;
};

} // namespace fbpcf::engine::communication
35 changes: 32 additions & 3 deletions fbpcf/engine/communication/test/PartyCommunicationAgentTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include <emmintrin.h>
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include <filesystem>
#include <memory>
#include <mutex>
#include <random>
Expand All @@ -18,6 +19,7 @@
#include "fbpcf/engine/communication/InMemoryPartyCommunicationAgentHost.h"
#include "fbpcf/engine/communication/SocketPartyCommunicationAgentFactory.h"
#include "fbpcf/engine/communication/test/AgentFactoryCreationHelper.h"
#include "fbpcf/engine/communication/test/TlsCommunicationUtils.h"

namespace fbpcf::engine::communication {

Expand Down Expand Up @@ -59,7 +61,10 @@ TEST(InMemoryPartyCommunicationAgentTest, testSendAndReceive) {
thread0.join();
}

TEST(SocketPartyCommunicationAgentTest, testSendAndReceive) {
TEST(SocketPartyCommunicationAgentTest, testSendAndReceiveWithTls) {
auto tempdir = std::filesystem::temp_directory_path();
setUpTlsFiles(tempdir);

std::random_device rd;
std::default_random_engine defEngine(rd());
std::uniform_int_distribution<int> intDistro(10000, 25000);
Expand All @@ -72,17 +77,41 @@ TEST(SocketPartyCommunicationAgentTest, testSendAndReceive) {
std::map<int, SocketPartyCommunicationAgentFactory::PartyInfo> partyInfo = {
{0, {"127.0.0.1", intDistro(defEngine)}},
{1, {"127.0.0.1", intDistro(defEngine)}}};

auto factory0 = std::make_unique<SocketPartyCommunicationAgentFactory>(
0, partyInfo, true, tempdir);
auto factory1 = std::make_unique<SocketPartyCommunicationAgentFactory>(
1, partyInfo, true, tempdir);

int size = 1048576; // 1024 ^ 2
auto thread0 = std::thread(testAgentFactory, 0, size, std::move(factory0));
auto thread1 = std::thread(testAgentFactory, 1, size, std::move(factory1));

thread1.join();
thread0.join();

deleteTlsFiles(tempdir);
}

TEST(SocketPartyCommunicationAgentTest, testSendAndReceiveWithoutTls) {
std::random_device rd;
std::default_random_engine defEngine(rd());
std::uniform_int_distribution<int> intDistro(10000, 25000);

std::map<int, SocketPartyCommunicationAgentFactory::PartyInfo> partyInfo = {
{0, {"127.0.0.1", intDistro(defEngine)}},
{1, {"127.0.0.1", intDistro(defEngine)}}};

auto factory0 =
std::make_unique<SocketPartyCommunicationAgentFactory>(0, partyInfo);
auto factory1 =
std::make_unique<SocketPartyCommunicationAgentFactory>(1, partyInfo);

int size = 1024;
int size = 1048576; // 1024 ^ 2
auto thread0 = std::thread(testAgentFactory, 0, size, std::move(factory0));
auto thread1 = std::thread(testAgentFactory, 1, size, std::move(factory1));

thread1.join();
thread0.join();
}

} // namespace fbpcf::engine::communication

0 comments on commit f3d2e03

Please sign in to comment.