Skip to content

Commit

Permalink
JNI/JSSE: optimize out array creation in WolfSSLEngine RecvAppData(),…
Browse files Browse the repository at this point in the history
… pass ByteBuffer down to JNI directly
  • Loading branch information
cconlon committed Dec 23, 2024
1 parent 9db7ff1 commit 2c6868c
Show file tree
Hide file tree
Showing 4 changed files with 378 additions and 86 deletions.
314 changes: 251 additions & 63 deletions native/com_wolfssl_WolfSSLSession.c
Original file line number Diff line number Diff line change
Expand Up @@ -1002,16 +1002,104 @@ JNIEXPORT jint JNICALL Java_com_wolfssl_WolfSSLSession_write
}
}

JNIEXPORT jint JNICALL Java_com_wolfssl_WolfSSLSession_read
(JNIEnv* jenv, jobject jcl, jlong sslPtr, jbyteArray raw, jint offset,
jint length, jint timeout)
/**
* Read len bytes from wolfSSL_read() back into provided output buffer.
*
* Internal function called by WolfSSLSession.read() calls.
*
* If wolfSSL_get_fd(ssl) returns a socket descriptor, try to wait for
* data with select()/poll() up to provided timeout.
*
* Returns number of bytes read on success, or negative on error.
*/
static int SSLReadNonblockingWithSelectPoll(WOLFSSL* ssl, byte* out,
int length, int timeout)
{
byte* data = NULL;
int size = 0, ret, err, sockfd;
int size, ret, err, sockfd;
int pollRx = 0;
int pollTx = 0;
wolfSSL_Mutex* jniSessLock = NULL;
SSLAppData* appData = NULL;

if (ssl == NULL || out == NULL) {
return BAD_FUNC_ARG;
}

/* get session mutex from SSL app data */
appData = (SSLAppData*)wolfSSL_get_app_data(ssl);
if (appData == NULL) {
return WOLFSSL_FAILURE;
}

jniSessLock = appData->jniSessLock;
if (jniSessLock == NULL) {
return WOLFSSL_FAILURE;
}

do {
/* lock mutex around session I/O before read attempt */
if (wc_LockMutex(jniSessLock) != 0) {
size = WOLFSSL_FAILURE;
break;
}

size = wolfSSL_read(ssl, out, length);
err = wolfSSL_get_error(ssl, size);

/* unlock mutex around session I/O after read attempt */
if (wc_UnLockMutex(jniSessLock) != 0) {
size = WOLFSSL_FAILURE;
break;
}

if (size < 0 &&
((err == SSL_ERROR_WANT_READ) || (err == SSL_ERROR_WANT_WRITE))) {

sockfd = wolfSSL_get_fd(ssl);
if (sockfd == -1) {
/* For I/O that does not use sockets, sockfd may be -1,
* skip try to call select() */
break;
}

if (err == SSL_ERROR_WANT_READ) {
pollRx = 1;
}
else if (err == SSL_ERROR_WANT_WRITE) {
pollTx = 1;
}

#if defined(WOLFJNI_USE_IO_SELECT) || defined(USE_WINDOWS_API)
ret = socketSelect(sockfd, timeout, pollRx);
#else
ret = socketPoll(sockfd, timeout, pollRx, pollTx);
#endif
if ((ret == WOLFJNI_IO_EVENT_RECV_READY) ||
(ret == WOLFJNI_IO_EVENT_SEND_READY)) {
/* loop around and try wolfSSL_read() again */
continue;
} else {
/* Java will throw SocketTimeoutException or
* SocketException if ret equals
* WOLFJNI_IO_EVENT_TIMEOUT, WOLFJNI_IO_EVENT_FD_CLOSED
* WOLFJNI_IO_EVENT_ERROR, WOLFJNI_IO_EVENT_POLLHUP or
* WOLFJNI_IO_EVENT_FAIL */
size = ret;
break;
}
}

} while (err == SSL_ERROR_WANT_WRITE || err == SSL_ERROR_WANT_READ);

return size;
}

JNIEXPORT jint JNICALL Java_com_wolfssl_WolfSSLSession_read__J_3BIII
(JNIEnv* jenv, jobject jcl, jlong sslPtr, jbyteArray raw, jint offset,
jint length, jint timeout)
{
int size = 0;
byte* data = NULL;
WOLFSSL* ssl = (WOLFSSL*)(uintptr_t)sslPtr;
(void)jcl;

Expand All @@ -1027,79 +1115,179 @@ JNIEXPORT jint JNICALL Java_com_wolfssl_WolfSSLSession_read
return SSL_FAILURE;
}

/* get session mutex from SSL app data */
appData = (SSLAppData*)wolfSSL_get_app_data(ssl);
if (appData == NULL) {
size = SSLReadNonblockingWithSelectPoll(ssl, data + offset,
(int)length, (int)timeout);

if (size < 0) {
(*jenv)->ReleaseByteArrayElements(jenv, raw, (jbyte*)data,
JNI_ABORT);
return WOLFSSL_FAILURE;
JNI_ABORT);
}
else {
/* JNI_COMMIT commits the data but does not free the local array
* 0 is used here to both commit and free */
(*jenv)->ReleaseByteArrayElements(jenv, raw, (jbyte*)data, 0);
}
}

jniSessLock = appData->jniSessLock;
if (jniSessLock == NULL) {
(*jenv)->ReleaseByteArrayElements(jenv, raw, (jbyte*)data,
JNI_ABORT);
return WOLFSSL_FAILURE;
return size;
}

JNIEXPORT jint JNICALL Java_com_wolfssl_WolfSSLSession_read__JLjava_nio_ByteBuffer_2II
(JNIEnv* jenv, jobject jcl, jlong sslPtr, jobject buf, jint length, jint timeout)
{
int size = 0;
int maxOutputSz;
int outSz = length;
byte* data = NULL;
WOLFSSL* ssl = (WOLFSSL*)(uintptr_t)sslPtr;

jclass excClass;
jclass buffClass;
jmethodID positionMeth;
jmethodID limitMeth;
jmethodID hasArrayMeth;
jmethodID arrayMeth;
jmethodID setPositionMeth;

jint position;
jint limit;
jboolean hasArray;
jbyteArray bufArr;

(void)jcl;

if (jenv == NULL || ssl == NULL || buf == NULL) {
return BAD_FUNC_ARG;
}

if (length > 0) {
/* Get WolfSSLException class */
excClass = (*jenv)->FindClass(jenv, "com/wolfssl/WolfSSLException");
if ((*jenv)->ExceptionOccurred(jenv)) {
(*jenv)->ExceptionDescribe(jenv);
(*jenv)->ExceptionClear(jenv);
return -1;
}

do {
/* lock mutex around session I/O before read attempt */
if (wc_LockMutex(jniSessLock) != 0) {
size = WOLFSSL_FAILURE;
break;
}
/* Get ByteBuffer class */
buffClass = (*jenv)->GetObjectClass(jenv, buf);
if (buffClass == NULL) {
(*jenv)->ThrowNew(jenv, excClass,
"Failed to find ByteBuffer class in native read()");
return -1;
}

size = wolfSSL_read(ssl, data + offset, length);
err = wolfSSL_get_error(ssl, size);
/* Get ByteBuffer position */
positionMeth = (*jenv)->GetMethodID(jenv, buffClass, "position", "()I");
if (positionMeth == NULL) {
if ((*jenv)->ExceptionOccurred(jenv)) {
(*jenv)->ExceptionDescribe(jenv);
(*jenv)->ExceptionClear(jenv);
}
(*jenv)->ThrowNew(jenv, excClass,
"Failed to find ByteBuffer position() method in native read()");
return -1;
}
position = (*jenv)->CallIntMethod(jenv, buf, positionMeth);

/* unlock mutex around session I/O after read attempt */
if (wc_UnLockMutex(jniSessLock) != 0) {
size = WOLFSSL_FAILURE;
break;
/* Get ByteBuffer limit */
limitMeth = (*jenv)->GetMethodID(jenv, buffClass, "limit", "()I");
if (limitMeth == NULL) {
if ((*jenv)->ExceptionOccurred(jenv)) {
(*jenv)->ExceptionDescribe(jenv);
(*jenv)->ExceptionClear(jenv);
}
(*jenv)->ThrowNew(jenv, excClass,
"Failed to find ByteBuffer limit() method in native read()");
return -1;
}
limit = (*jenv)->CallIntMethod(jenv, buf, limitMeth);

if (size < 0 && ((err == SSL_ERROR_WANT_READ) || \
(err == SSL_ERROR_WANT_WRITE))) {
/* Get and call ByteBuffer.hasArray() before calling array() */
hasArrayMeth = (*jenv)->GetMethodID(jenv, buffClass, "hasArray", "()Z");
if (hasArrayMeth == NULL) {
if ((*jenv)->ExceptionOccurred(jenv)) {
(*jenv)->ExceptionDescribe(jenv);
(*jenv)->ExceptionClear(jenv);
}
(*jenv)->ThrowNew(jenv, excClass,
"Failed to find ByteBuffer hasArray() method in native read()");
return -1;
}

sockfd = wolfSSL_get_fd(ssl);
if (sockfd == -1) {
/* For I/O that does not use sockets, sockfd may be -1,
* skip try to call select() */
break;
}
/* ByteBuffer.hasArray() does not throw any exceptions */
hasArray = (*jenv)->CallBooleanMethod(jenv, buf, hasArrayMeth);
if (!hasArray) {
(*jenv)->ThrowNew(jenv, excClass,
"ByteBuffer.hasArray() is false in native read()");
return BAD_FUNC_ARG;
}

if (err == SSL_ERROR_WANT_READ) {
pollRx = 1;
}
else if (err == SSL_ERROR_WANT_WRITE) {
pollTx = 1;
}
/* Only read up to maximum space we have in this ByteBuffer */
maxOutputSz = (limit - position);
if (outSz > maxOutputSz) {
outSz = maxOutputSz;
}

#if defined(WOLFJNI_USE_IO_SELECT) || defined(USE_WINDOWS_API)
ret = socketSelect(sockfd, (int)timeout, pollRx);
#else
ret = socketPoll(sockfd, (int)timeout, pollRx, pollTx);
#endif
if ((ret == WOLFJNI_IO_EVENT_RECV_READY) ||
(ret == WOLFJNI_IO_EVENT_SEND_READY)) {
/* loop around and try wolfSSL_read() again */
continue;
} else {
/* Java will throw SocketTimeoutException or
* SocketException if ret equals
* WOLFJNI_IO_EVENT_TIMEOUT, WOLFJNI_IO_EVENT_FD_CLOSED
* WOLFJNI_IO_EVENT_ERROR, WOLFJNI_IO_EVENT_POLLHUP or
* WOLFJNI_IO_EVENT_FAIL */
size = ret;
break;
}
/* Get reference to underlying byte[] from ByteBuffer */
arrayMeth = (*jenv)->GetMethodID(jenv, buffClass, "array", "()[B");
if (arrayMeth == NULL) {
if ((*jenv)->ExceptionOccurred(jenv)) {
(*jenv)->ExceptionDescribe(jenv);
(*jenv)->ExceptionClear(jenv);
}
(*jenv)->ThrowNew(jenv, excClass,
"Failed to find ByteBuffer array() method in native read()");
return -1;
}
bufArr = (jbyteArray)(*jenv)->CallObjectMethod(jenv, buf, arrayMeth);

} while (err == SSL_ERROR_WANT_WRITE || err == SSL_ERROR_WANT_READ);
/* Get array elements */
data = (byte*)(*jenv)->GetByteArrayElements(jenv, bufArr, NULL);
if ((*jenv)->ExceptionOccurred(jenv)) {
(*jenv)->ExceptionDescribe(jenv);
(*jenv)->ExceptionClear(jenv);
(*jenv)->ThrowNew(jenv, excClass,
"Exception when calling ByteBuffer.array() in native read()");
return -1;
}


if (data != NULL) {
size = SSLReadNonblockingWithSelectPoll(ssl, data + position,
maxOutputSz, (int)timeout);

/* JNI_COMMIT commits the data but does not free the local array
* 0 is used here to both commit and free */
(*jenv)->ReleaseByteArrayElements(jenv, raw, (jbyte*)data, 0);
/* Relase array elements */
if (size < 0) {
(*jenv)->ReleaseByteArrayElements(jenv, bufArr, (jbyte*)data,
JNI_ABORT);
}
else {
/* JNI_COMMIT commits the data but does not free the local array
* 0 is used here to both commit and free */
(*jenv)->ReleaseByteArrayElements(jenv, bufArr,
(jbyte*)data, 0);

/* Update ByteBuffer position() based on bytes written */
setPositionMeth = (*jenv)->GetMethodID(jenv, buffClass,
"position", "(I)Ljava/nio/Buffer;");
//"position", "(I)V");
if (setPositionMeth == NULL) {
if ((*jenv)->ExceptionOccurred(jenv)) {
(*jenv)->ExceptionDescribe(jenv);
(*jenv)->ExceptionClear(jenv);
}
(*jenv)->ThrowNew(jenv, excClass,
"Failed to set ByteBuffer position() from "
"native read()");
size = -1;
}
else {
(*jenv)->CallVoidMethod(jenv, buf, setPositionMeth,
position + size);
}
}
}
}

return size;
Expand Down
10 changes: 9 additions & 1 deletion native/com_wolfssl_WolfSSLSession.h

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit 2c6868c

Please sign in to comment.