Skip to content

Commit

Permalink
Added socket::set_blocking(bool)
Browse files Browse the repository at this point in the history
NOTE: May need different implementation on Windows(?)
  • Loading branch information
snej committed Sep 5, 2019
1 parent 3b4f65b commit 1ca3726
Show file tree
Hide file tree
Showing 5 changed files with 156 additions and 89 deletions.
1 change: 1 addition & 0 deletions include/sockpp/platform.h
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@
#include <unistd.h>
#include <sys/socket.h>
#include <arpa/inet.h>
#include <fcntl.h>
#include <netdb.h>
#include <signal.h>
#include <errno.h>
Expand Down
4 changes: 4 additions & 0 deletions include/sockpp/socket.h
Original file line number Diff line number Diff line change
Expand Up @@ -337,6 +337,10 @@ class socket
* @return The address of the remote peer, if this socket is connected.
*/
sock_address_any peer_address() const;
/**
* Puts the socket into nonblocking (false) or blocking (true) I/O mode.
*/
virtual bool set_blocking(bool blocking);
/**
* Gets the value of a socket option.
*
Expand Down
6 changes: 1 addition & 5 deletions include/sockpp/tls_socket.h
Original file line number Diff line number Diff line change
Expand Up @@ -100,21 +100,17 @@ namespace sockpp {
// I/O primitives must be reimplemented in subclasses:

virtual ssize_t read(void *buf, size_t n) override = 0;

virtual ioresult read_r(void *buf, size_t n) override = 0;

virtual bool read_timeout(const std::chrono::microseconds& to) override = 0;

virtual ssize_t write(const void *buf, size_t n) override = 0;

virtual ioresult write_r(const void *buf, size_t n) override = 0;

virtual bool write_timeout(const std::chrono::microseconds& to) override = 0;

virtual ssize_t write(const std::vector<iovec> &ranges) override {
return ranges.empty() ? 0 : write(ranges[0].iov_base, ranges[0].iov_len);
}

virtual bool set_blocking(bool blocking) override = 0;

virtual void close() override {
if (stream_) {
Expand Down
216 changes: 132 additions & 84 deletions src/mbedtls_context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@


// TODO: Better logging(?)
#define log(FMT,...) fprintf(stderr, FMT "\n", ## __VA_ARGS__)
#define log(FMT,...) fprintf(stderr, "TLS: " FMT "\n", ## __VA_ARGS__)


namespace sockpp {
Expand Down Expand Up @@ -123,30 +123,32 @@ namespace sockpp {
return;
}

if (check_mbed_ret(mbedtls_ssl_setup(&ssl_, context_.ssl_config_.get()),
if (check_mbed_setup(mbedtls_ssl_setup(&ssl_, context_.ssl_config_.get()),
"mbedtls_ssl_setup"))
return;
if (!hostname.empty() && check_mbed_ret(mbedtls_ssl_set_hostname(&ssl_, hostname.c_str()),
if (!hostname.empty() && check_mbed_setup(mbedtls_ssl_set_hostname(&ssl_, hostname.c_str()),
"mbedtls_ssl_set_hostname"))
return;

mbedtls_ssl_set_bio(&ssl_, this,
[](void *ctx, const uint8_t *buf, size_t len) {
return ((mbedtls_socket*)ctx)->bio_send(buf, len); },
nullptr,
[](void *ctx, uint8_t *buf, size_t len, uint32_t timeout) {
return ((mbedtls_socket*)ctx)->bio_recv_timeout(buf, len, timeout); });
open_ = true;
#if defined(_WIN32)
// Winsock does not allow us to tell if a socket is nonblocking, so assume it isn't
bool blocking = true;
#else
int flags = fcntl(stream().handle(), F_GETFL, 0);
bool blocking = (flags < 0 || (flags & O_NONBLOCK) == 0);
#endif
setup_bio(blocking);

// Run the TLS handshake:
int status;
do {
open_ = true; // temporarily, so BIO methods won't fail
status = mbedtls_ssl_handshake(&ssl_);
open_ = false;
} while (status == MBEDTLS_ERR_SSL_WANT_READ || status == MBEDTLS_ERR_SSL_WANT_WRITE
|| status == MBEDTLS_ERR_SSL_CRYPTO_IN_PROGRESS);
if (check_mbed_ret(status, "mbedtls_ssl_handshake") != 0) {
if (check_mbed_setup(status, "mbedtls_ssl_handshake") != 0)
return;
}

uint32_t verify_flags = mbedtls_ssl_get_verify_result(&ssl_);
if (verify_flags != 0 && verify_flags != uint32_t(-1)
Expand All @@ -155,10 +157,25 @@ namespace sockpp {
mbedtls_x509_crt_verify_info(vrfy_buf, sizeof( vrfy_buf ), "", verify_flags);
log("Cert verify failed: %s", vrfy_buf );
clear(MBEDTLS_ERR_X509_CERT_VERIFY_FAILED);
open_ = false;
reset();
return;
}
open_ = true;
}


void setup_bio(bool blocking) {
mbedtls_ssl_send_t *f_send = [](void *ctx, const uint8_t *buf, size_t len) {
return ((mbedtls_socket*)ctx)->bio_send(buf, len); };
mbedtls_ssl_recv_t *f_recv = nullptr;
mbedtls_ssl_recv_timeout_t *f_recv_timeout = nullptr;
if (blocking)
f_recv_timeout = [](void *ctx, uint8_t *buf, size_t len, uint32_t timeout) {
return ((mbedtls_socket*)ctx)->bio_recv_timeout(buf, len, timeout); };
else
f_recv = [](void *ctx, uint8_t *buf, size_t len) {
return ((mbedtls_socket*)ctx)->bio_recv(buf, len); };
mbedtls_ssl_set_bio(&ssl_, this, f_send, f_recv, f_recv_timeout);
}


Expand All @@ -169,18 +186,18 @@ namespace sockpp {
}


int check_mbed_ret(int ret, const char *fn) {
if (ret != 0) {
log_mbed_ret(ret, fn);
clear(ret); // sets last_error
reset(); // marks me as closed/invalid
stream().close();
virtual void close() override {
if (open_) {
mbedtls_ssl_close_notify(&ssl_);
open_ = false;
}
return ret;
tls_socket::close();
}


// -------- certificate / trust API


uint32_t peer_certificate_status() override {
return mbedtls_ssl_get_verify_result(&ssl_);
}
Expand Down Expand Up @@ -218,40 +235,6 @@ namespace sockpp {
// -------- stream_socket I/O


static int translate_mbed_err(int mbedErr) {
switch (mbedErr) {
case MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY:
return 0;
case MBEDTLS_ERR_SSL_WANT_READ:
case MBEDTLS_ERR_SSL_WANT_WRITE:
return EWOULDBLOCK;
case MBEDTLS_ERR_NET_CONN_RESET:
return ECONNRESET;
case MBEDTLS_ERR_NET_RECV_FAILED:
case MBEDTLS_ERR_NET_SEND_FAILED:
return EIO;
default:
return mbedErr;
}
}


inline ssize_t check_mbed_io(int mbedResult) {
int result = translate_mbed_err(mbedResult);
if (result < 0) {
clear(result); // sets last_error
result = -1;
}
return result;
}


static inline ioresult ioresult_from_mbed(int mbedResult) {
mbedResult = translate_mbed_err(mbedResult);
return mbedResult < 0 ? ioresult(0, mbedResult) : ioresult(mbedResult, 0);
}


ssize_t read(void *buf, size_t length) override {
return check_mbed_io( mbedtls_ssl_read(&ssl_, (uint8_t*)buf, length) );
}
Expand All @@ -268,11 +251,15 @@ namespace sockpp {


ssize_t write(const void *buf, size_t length) override {
if (length == 0)
return 0;
return check_mbed_io( mbedtls_ssl_write(&ssl_, (const uint8_t*)buf, length) );
}


ioresult write_r(const void *buf, size_t length) override {
if (length == 0)
return {};
return ioresult_from_mbed( mbedtls_ssl_write(&ssl_, (const uint8_t*)buf, length) );
}

Expand All @@ -282,62 +269,123 @@ namespace sockpp {
}


// -------- mbedTLS BIO callbacks
bool set_blocking(bool blocking) override {
bool ok = stream().set_blocking(blocking);
if (ok)
setup_bio(blocking);
return ok;
}


template <bool reading>
static int bio_return_value(ioresult result) {
if (result.count >= 0)
return (int)result.count;
switch (result.error) {
case EPIPE:
case ECONNRESET:
return MBEDTLS_ERR_NET_CONN_RESET;
case EINTR:
#if defined(EAGAIN)
case EAGAIN:
#endif
#if defined(EWOULDBLOCK) && EWOULDBLOCK != EAGAIN
case EWOULDBLOCK:
#endif
return reading ? MBEDTLS_ERR_SSL_WANT_READ
: MBEDTLS_ERR_SSL_WANT_WRITE;
default:
return reading ? MBEDTLS_ERR_NET_RECV_FAILED
: MBEDTLS_ERR_NET_SEND_FAILED;
}
}
// -------- mbedTLS BIO callbacks


int bio_send(const void* buf, size_t length) {
if (!open_)
return 0;
return MBEDTLS_ERR_NET_CONN_RESET;
return bio_return_value<false>(stream().write_r(buf, length));
}


int bio_recv(void* buf, size_t length) {
if (!open_)
return MBEDTLS_ERR_NET_CONN_RESET;
return bio_return_value<true>(stream().read_r(buf, length));
}


int bio_recv_timeout(void* buf, size_t length, uint32_t timeout) {
if (!open_)
return 0;
return MBEDTLS_ERR_NET_CONN_RESET;
if (timeout > 0)
stream().read_timeout(chrono::milliseconds(timeout));

int n = bio_return_value<true>(stream().read_r(buf, length));
int n = bio_recv(buf, length);

if (timeout > 0)
stream().read_timeout(chrono::hours(1000)); //FIXME: How do I turn off a timeout?
return n;
}


virtual void close() override {
if (open_) {
mbedtls_ssl_close_notify(&ssl_);
// -------- error handling


// Translates mbedTLS error code to POSIX (errno)
static int translate_mbed_err(int mbedErr) {
switch (mbedErr) {
case MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY:
return 0;
case MBEDTLS_ERR_SSL_WANT_READ:
case MBEDTLS_ERR_SSL_WANT_WRITE:
log(">>> mbedtls_socket returning EWOULDBLOCK");
return EWOULDBLOCK;
case MBEDTLS_ERR_NET_CONN_RESET:
return ECONNRESET;
case MBEDTLS_ERR_NET_RECV_FAILED:
case MBEDTLS_ERR_NET_SEND_FAILED:
return EIO;
default:
return mbedErr;
}
}


// Handles an mbedTLS error return value during setup, closing me on error
int check_mbed_setup(int ret, const char *fn) {
if (ret != 0) {
log_mbed_ret(ret, fn);
clear(ret); // sets last_error
reset(); // marks me as closed/invalid
stream().close();
open_ = false;
}
tls_socket::close();
return ret;
}


// Handles an mbedTLS read/write return value, storing any error in last_error
inline ssize_t check_mbed_io(int mbedResult) {
int result = translate_mbed_err(mbedResult);
if (result < 0) {
clear(result); // sets last_error
result = -1;
}
return result;
}


// Handles an mbedTLS read/write return value, converting it to an ioresult.
static inline ioresult ioresult_from_mbed(int mbedResult) {
mbedResult = translate_mbed_err(mbedResult);
return mbedResult < 0 ? ioresult(0, mbedResult) : ioresult(mbedResult, 0);
}


// Translates ioresult to an mbedTLS error code to return from a BIO function.
template <bool reading>
static int bio_return_value(ioresult result) {
if (result.count >= 0)
return (int)result.count;
switch (result.error) {
case EPIPE:
case ECONNRESET:
return MBEDTLS_ERR_NET_CONN_RESET;
case EINTR:
case EWOULDBLOCK:
#if defined(EAGAIN) && EAGAIN != EWOULDBLOCK // these are usually synonyms
case EAGAIN:
#endif
log(">>> BIO returning MBEDTLS_ERR_SSL_WANT_%s", reading ?"READ":"WRITE");
return reading ? MBEDTLS_ERR_SSL_WANT_READ
: MBEDTLS_ERR_SSL_WANT_WRITE;
default:
return reading ? MBEDTLS_ERR_NET_RECV_FAILED
: MBEDTLS_ERR_NET_SEND_FAILED;
}
}


};


Expand Down
18 changes: 18 additions & 0 deletions src/socket.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,24 @@ sock_address_any socket::peer_address() const
return sock_address_any(addrStore, len);
}

// --------------------------------------------------------------------------
// Puts the socket into nonblocking or blocking I/O mode.

bool socket::set_blocking(bool blocking) {
#if defined(_WIN32)
u_long mode = !blocking;
return check_ret_bool(::ioctlsocket(handle_, FIONBIO, &mode));
#else
int flags = check_ret(::fcntl(handle_, F_GETFL, 0));
if (flags < 0)
return false;
int newFlags = blocking ? (flags & ~O_NONBLOCK) : (flags | O_NONBLOCK);
if (newFlags == flags)
return true;
return check_ret_bool(::fcntl(handle_, F_SETFL, newFlags));
#endif
}

// --------------------------------------------------------------------------

bool socket::get_option(int level, int optname, void* optval, socklen_t* optlen) const
Expand Down

0 comments on commit 1ca3726

Please sign in to comment.