diff --git a/fbpcf/engine/communication/IPartyCommunicationAgent.h b/fbpcf/engine/communication/IPartyCommunicationAgent.h index 2e49fd75..82c5b44f 100644 --- a/fbpcf/engine/communication/IPartyCommunicationAgent.h +++ b/fbpcf/engine/communication/IPartyCommunicationAgent.h @@ -16,8 +16,11 @@ #error "Machine must be little endian" #endif -namespace fbpcf::engine::communication { +namespace fbpcf::engine::util { +class EmpNetworkAdapter; +} +namespace fbpcf::engine::communication { /** * This is the network API between two parties. * NOTE: sendT/receiveT only work when the two parties have the same endianness @@ -89,6 +92,8 @@ class IPartyCommunicationAgent { virtual std::pair getTrafficStatistics() const = 0; private: + friend class util::EmpNetworkAdapter; + // convert a vector of bits into a vector of bytes static std::vector compressToBytes( const std::vector& bits) { @@ -125,6 +130,10 @@ class IPartyCommunicationAgent { } return bits; } + + virtual void recvImpl(void* data, int nBytes) = 0; + + virtual void sendImpl(const void* data, int nBytes) = 0; }; template <> diff --git a/fbpcf/engine/communication/InMemoryPartyCommunicationAgentHost.cpp b/fbpcf/engine/communication/InMemoryPartyCommunicationAgentHost.cpp index 97225c63..569ec8a1 100644 --- a/fbpcf/engine/communication/InMemoryPartyCommunicationAgentHost.cpp +++ b/fbpcf/engine/communication/InMemoryPartyCommunicationAgentHost.cpp @@ -9,20 +9,33 @@ namespace fbpcf::engine::communication { +void InMemoryPartyCommunicationAgent::sendImpl(const void* data, int nBytes) { + std::vector buffer(nBytes); + memcpy(buffer.data(), data, nBytes); + host_.send(myId_, buffer); + sentData_ += nBytes; +} + void InMemoryPartyCommunicationAgent::send( const std::vector& data) { - host_.send(myId_, data); - sentData_ += data.size(); + sendImpl(static_cast(data.data()), data.size()); } -std::vector InMemoryPartyCommunicationAgent::receive( - size_t size) { - auto result = host_.receive(myId_, size); - if (result.size() != size) { +void InMemoryPartyCommunicationAgent::recvImpl(void* data, int nBytes) { + auto result = host_.receive(myId_, nBytes); + + if (result.size() != nBytes) { throw std::runtime_error("unexpected message size!"); } - receivedData_ += size; - return result; + memcpy(data, result.data(), nBytes); + receivedData_ += nBytes; +} + +std::vector InMemoryPartyCommunicationAgent::receive( + size_t size) { + std::vector v(size); + recvImpl(static_cast(v.data()), size); + return v; } InMemoryPartyCommunicationAgentHost::InMemoryPartyCommunicationAgentHost() { diff --git a/fbpcf/engine/communication/InMemoryPartyCommunicationAgentHost.h b/fbpcf/engine/communication/InMemoryPartyCommunicationAgentHost.h index 88470da2..7a9ab1d2 100644 --- a/fbpcf/engine/communication/InMemoryPartyCommunicationAgentHost.h +++ b/fbpcf/engine/communication/InMemoryPartyCommunicationAgentHost.h @@ -45,6 +45,10 @@ class InMemoryPartyCommunicationAgent final : public IPartyCommunicationAgent { return {sentData_, receivedData_}; } + void recvImpl(void* data, int nBytes) override; + + void sendImpl(const void* data, int nBytes) override; + private: InMemoryPartyCommunicationAgentHost& host_; int myId_; diff --git a/fbpcf/engine/communication/SocketPartyCommunicationAgent.cpp b/fbpcf/engine/communication/SocketPartyCommunicationAgent.cpp index 86960a3c..b22365cc 100644 --- a/fbpcf/engine/communication/SocketPartyCommunicationAgent.cpp +++ b/fbpcf/engine/communication/SocketPartyCommunicationAgent.cpp @@ -87,42 +87,49 @@ SocketPartyCommunicationAgent::~SocketPartyCommunicationAgent() { } } -void SocketPartyCommunicationAgent::send( - const std::vector& data) { +void SocketPartyCommunicationAgent::sendImpl(const void* data, int nBytes) { size_t bytesWritten; if (!ssl_) { - bytesWritten = - fwrite(data.data(), sizeof(unsigned char), data.size(), outgoingPort_); + bytesWritten = fwrite(data, sizeof(unsigned char), nBytes, outgoingPort_); } else { - bytesWritten = SSL_write(ssl_, (void*)data.data(), data.size()); + bytesWritten = SSL_write(ssl_, data, nBytes); } - assert(bytesWritten == data.size()); + assert(bytesWritten == nBytes); sentData_ += bytesWritten; if (!ssl_) { fflush(outgoingPort_); } } -std::vector SocketPartyCommunicationAgent::receive(size_t size) { +void SocketPartyCommunicationAgent::send( + const std::vector& data) { + sendImpl(static_cast(data.data()), data.size()); +} + +void SocketPartyCommunicationAgent::recvImpl(void* data, int nBytes) { size_t bytesRead = 0; - std::vector rst(size); if (!ssl_) { - bytesRead = fread(rst.data(), sizeof(unsigned char), size, incomingPort_); + bytesRead = fread(data, sizeof(unsigned char), nBytes, incomingPort_); } else { // fread is blocking, but SSL_read is nonblocking. This discrepancy // can cause issues at the application level. We need to make sure that // both APIs behave consistently, so here we add a loop to ensure we // mimick blocking behavior. - while (bytesRead < size) { + while (bytesRead < nBytes) { bytesRead += SSL_read( ssl_, - rst.data() + (bytesRead * sizeof(unsigned char)), - size - bytesRead); + (unsigned char*)data + (bytesRead * sizeof(unsigned char)), + nBytes - bytesRead); } } - assert(bytesRead == size); + assert(bytesRead == nBytes); receivedData_ += bytesRead; +} + +std::vector SocketPartyCommunicationAgent::receive(size_t size) { + std::vector rst(size); + recvImpl(static_cast(rst.data()), size); return rst; } diff --git a/fbpcf/engine/communication/SocketPartyCommunicationAgent.h b/fbpcf/engine/communication/SocketPartyCommunicationAgent.h index df0495f7..edfa13b5 100644 --- a/fbpcf/engine/communication/SocketPartyCommunicationAgent.h +++ b/fbpcf/engine/communication/SocketPartyCommunicationAgent.h @@ -57,6 +57,10 @@ class SocketPartyCommunicationAgent final : public IPartyCommunicationAgent { return {sentData_, receivedData_}; } + void recvImpl(void* data, int nBytes) override; + + void sendImpl(const void* data, int nBytes) override; + private: void openServerPort(int portNo); void openClientPort(const std::string& serverAddress, int portNo); diff --git a/fbpcf/engine/util/EmpNetworkAdapter.h b/fbpcf/engine/util/EmpNetworkAdapter.h index d8041e61..848af6c4 100644 --- a/fbpcf/engine/util/EmpNetworkAdapter.h +++ b/fbpcf/engine/util/EmpNetworkAdapter.h @@ -22,15 +22,12 @@ class EmpNetworkAdapter { explicit EmpNetworkAdapter(communication::IPartyCommunicationAgent& agent) : agent_(agent) {} - void send_data(const void* data, int nByte) { - std::vector buffer(nByte); - memcpy(buffer.data(), data, nByte); - agent_.send(buffer); + void send_data(const void* data, int nBytes) { + agent_.sendImpl(data, nBytes); } - void recv_data(void* data, int nByte) { - auto buffer = agent_.receive(nByte); - memcpy(data, buffer.data(), nByte); + void recv_data(void* data, int nBytes) { + agent_.recvImpl(data, nBytes); } void send_block(const __m128i* data, int nBlock) {