diff --git a/native/com_wolfssl_WolfSSLSession.c b/native/com_wolfssl_WolfSSLSession.c index ed9d909e..5b34262a 100644 --- a/native/com_wolfssl_WolfSSLSession.c +++ b/native/com_wolfssl_WolfSSLSession.c @@ -1002,16 +1002,104 @@ JNIEXPORT jint JNICALL Java_com_wolfssl_WolfSSLSession_write } } -JNIEXPORT jint JNICALL Java_com_wolfssl_WolfSSLSession_read - (JNIEnv* jenv, jobject jcl, jlong sslPtr, jbyteArray raw, jint offset, - jint length, jint timeout) +/** + * Read len bytes from wolfSSL_read() back into provided output buffer. + * + * Internal function called by WolfSSLSession.read() calls. + * + * If wolfSSL_get_fd(ssl) returns a socket descriptor, try to wait for + * data with select()/poll() up to provided timeout. + * + * Returns number of bytes read on success, or negative on error. + */ +static int SSLReadNonblockingWithSelectPoll(WOLFSSL* ssl, byte* out, + int length, int timeout) { - byte* data = NULL; - int size = 0, ret, err, sockfd; + int size, ret, err, sockfd; int pollRx = 0; int pollTx = 0; wolfSSL_Mutex* jniSessLock = NULL; SSLAppData* appData = NULL; + + if (ssl == NULL || out == NULL) { + return BAD_FUNC_ARG; + } + + /* get session mutex from SSL app data */ + appData = (SSLAppData*)wolfSSL_get_app_data(ssl); + if (appData == NULL) { + return WOLFSSL_FAILURE; + } + + jniSessLock = appData->jniSessLock; + if (jniSessLock == NULL) { + return WOLFSSL_FAILURE; + } + + do { + /* lock mutex around session I/O before read attempt */ + if (wc_LockMutex(jniSessLock) != 0) { + size = WOLFSSL_FAILURE; + break; + } + + size = wolfSSL_read(ssl, out, length); + err = wolfSSL_get_error(ssl, size); + + /* unlock mutex around session I/O after read attempt */ + if (wc_UnLockMutex(jniSessLock) != 0) { + size = WOLFSSL_FAILURE; + break; + } + + if (size < 0 && + ((err == SSL_ERROR_WANT_READ) || (err == SSL_ERROR_WANT_WRITE))) { + + sockfd = wolfSSL_get_fd(ssl); + if (sockfd == -1) { + /* For I/O that does not use sockets, sockfd may be -1, + * skip try to call select() */ + break; + } + + if (err == SSL_ERROR_WANT_READ) { + pollRx = 1; + } + else if (err == SSL_ERROR_WANT_WRITE) { + pollTx = 1; + } + + #if defined(WOLFJNI_USE_IO_SELECT) || defined(USE_WINDOWS_API) + ret = socketSelect(sockfd, timeout, pollRx); + #else + ret = socketPoll(sockfd, timeout, pollRx, pollTx); + #endif + if ((ret == WOLFJNI_IO_EVENT_RECV_READY) || + (ret == WOLFJNI_IO_EVENT_SEND_READY)) { + /* loop around and try wolfSSL_read() again */ + continue; + } else { + /* Java will throw SocketTimeoutException or + * SocketException if ret equals + * WOLFJNI_IO_EVENT_TIMEOUT, WOLFJNI_IO_EVENT_FD_CLOSED + * WOLFJNI_IO_EVENT_ERROR, WOLFJNI_IO_EVENT_POLLHUP or + * WOLFJNI_IO_EVENT_FAIL */ + size = ret; + break; + } + } + + } while (err == SSL_ERROR_WANT_WRITE || err == SSL_ERROR_WANT_READ); + + return size; +} + +JNIEXPORT jint JNICALL Java_com_wolfssl_WolfSSLSession_read__J_3BIII + (JNIEnv* jenv, jobject jcl, jlong sslPtr, jbyteArray raw, jint offset, + jint length, jint timeout) +{ + int size = 0; + byte* data = NULL; WOLFSSL* ssl = (WOLFSSL*)(uintptr_t)sslPtr; (void)jcl; @@ -1027,79 +1115,178 @@ JNIEXPORT jint JNICALL Java_com_wolfssl_WolfSSLSession_read return SSL_FAILURE; } - /* get session mutex from SSL app data */ - appData = (SSLAppData*)wolfSSL_get_app_data(ssl); - if (appData == NULL) { + size = SSLReadNonblockingWithSelectPoll(ssl, data + offset, + (int)length, (int)timeout); + + if (size < 0) { (*jenv)->ReleaseByteArrayElements(jenv, raw, (jbyte*)data, - JNI_ABORT); - return WOLFSSL_FAILURE; + JNI_ABORT); } + else { + /* JNI_COMMIT commits the data but does not free the local array + * 0 is used here to both commit and free */ + (*jenv)->ReleaseByteArrayElements(jenv, raw, (jbyte*)data, 0); + } + } - jniSessLock = appData->jniSessLock; - if (jniSessLock == NULL) { - (*jenv)->ReleaseByteArrayElements(jenv, raw, (jbyte*)data, - JNI_ABORT); - return WOLFSSL_FAILURE; + return size; +} + +JNIEXPORT jint JNICALL Java_com_wolfssl_WolfSSLSession_read__JLjava_nio_ByteBuffer_2II + (JNIEnv* jenv, jobject jcl, jlong sslPtr, jobject buf, jint length, jint timeout) +{ + int size = 0; + int maxOutputSz; + int outSz = length; + byte* data = NULL; + WOLFSSL* ssl = (WOLFSSL*)(uintptr_t)sslPtr; + + jclass excClass; + jclass buffClass; + jmethodID positionMeth; + jmethodID limitMeth; + jmethodID hasArrayMeth; + jmethodID arrayMeth; + jmethodID setPositionMeth; + + jint position; + jint limit; + jboolean hasArray; + jbyteArray bufArr; + + (void)jcl; + + if (jenv == NULL || ssl == NULL || buf == NULL) { + return BAD_FUNC_ARG; + } + + if (length > 0) { + /* Get WolfSSLException class */ + excClass = (*jenv)->FindClass(jenv, "com/wolfssl/WolfSSLException"); + if ((*jenv)->ExceptionOccurred(jenv)) { + (*jenv)->ExceptionDescribe(jenv); + (*jenv)->ExceptionClear(jenv); + return -1; } - do { - /* lock mutex around session I/O before read attempt */ - if (wc_LockMutex(jniSessLock) != 0) { - size = WOLFSSL_FAILURE; - break; - } + /* Get ByteBuffer class */ + buffClass = (*jenv)->GetObjectClass(jenv, buf); + if (buffClass == NULL) { + (*jenv)->ThrowNew(jenv, excClass, + "Failed to find ByteBuffer class in native read()"); + return -1; + } - size = wolfSSL_read(ssl, data + offset, length); - err = wolfSSL_get_error(ssl, size); + /* Get ByteBuffer position */ + positionMeth = (*jenv)->GetMethodID(jenv, buffClass, "position", "()I"); + if (positionMeth == NULL) { + if ((*jenv)->ExceptionOccurred(jenv)) { + (*jenv)->ExceptionDescribe(jenv); + (*jenv)->ExceptionClear(jenv); + } + (*jenv)->ThrowNew(jenv, excClass, + "Failed to find ByteBuffer position() method in native read()"); + return -1; + } + position = (*jenv)->CallIntMethod(jenv, buf, positionMeth); - /* unlock mutex around session I/O after read attempt */ - if (wc_UnLockMutex(jniSessLock) != 0) { - size = WOLFSSL_FAILURE; - break; + /* Get ByteBuffer limit */ + limitMeth = (*jenv)->GetMethodID(jenv, buffClass, "limit", "()I"); + if (limitMeth == NULL) { + if ((*jenv)->ExceptionOccurred(jenv)) { + (*jenv)->ExceptionDescribe(jenv); + (*jenv)->ExceptionClear(jenv); } + (*jenv)->ThrowNew(jenv, excClass, + "Failed to find ByteBuffer limit() method in native read()"); + return -1; + } + limit = (*jenv)->CallIntMethod(jenv, buf, limitMeth); - if (size < 0 && ((err == SSL_ERROR_WANT_READ) || \ - (err == SSL_ERROR_WANT_WRITE))) { + /* Get and call ByteBuffer.hasArray() before calling array() */ + hasArrayMeth = (*jenv)->GetMethodID(jenv, buffClass, "hasArray", "()Z"); + if (hasArrayMeth == NULL) { + if ((*jenv)->ExceptionOccurred(jenv)) { + (*jenv)->ExceptionDescribe(jenv); + (*jenv)->ExceptionClear(jenv); + } + (*jenv)->ThrowNew(jenv, excClass, + "Failed to find ByteBuffer hasArray() method in native read()"); + return -1; + } - sockfd = wolfSSL_get_fd(ssl); - if (sockfd == -1) { - /* For I/O that does not use sockets, sockfd may be -1, - * skip try to call select() */ - break; - } + /* ByteBuffer.hasArray() does not throw any exceptions */ + hasArray = (*jenv)->CallBooleanMethod(jenv, buf, hasArrayMeth); + if (!hasArray) { + (*jenv)->ThrowNew(jenv, excClass, + "ByteBuffer.hasArray() is false in native read()"); + return BAD_FUNC_ARG; + } - if (err == SSL_ERROR_WANT_READ) { - pollRx = 1; - } - else if (err == SSL_ERROR_WANT_WRITE) { - pollTx = 1; - } + /* Only read up to maximum space we have in this ByteBuffer */ + maxOutputSz = (limit - position); + if (outSz > maxOutputSz) { + outSz = maxOutputSz; + } - #if defined(WOLFJNI_USE_IO_SELECT) || defined(USE_WINDOWS_API) - ret = socketSelect(sockfd, (int)timeout, pollRx); - #else - ret = socketPoll(sockfd, (int)timeout, pollRx, pollTx); - #endif - if ((ret == WOLFJNI_IO_EVENT_RECV_READY) || - (ret == WOLFJNI_IO_EVENT_SEND_READY)) { - /* loop around and try wolfSSL_read() again */ - continue; - } else { - /* Java will throw SocketTimeoutException or - * SocketException if ret equals - * WOLFJNI_IO_EVENT_TIMEOUT, WOLFJNI_IO_EVENT_FD_CLOSED - * WOLFJNI_IO_EVENT_ERROR, WOLFJNI_IO_EVENT_POLLHUP or - * WOLFJNI_IO_EVENT_FAIL */ - size = ret; - break; - } + /* Get reference to underlying byte[] from ByteBuffer */ + arrayMeth = (*jenv)->GetMethodID(jenv, buffClass, "array", "()[B"); + if (arrayMeth == NULL) { + if ((*jenv)->ExceptionOccurred(jenv)) { + (*jenv)->ExceptionDescribe(jenv); + (*jenv)->ExceptionClear(jenv); } + (*jenv)->ThrowNew(jenv, excClass, + "Failed to find ByteBuffer array() method in native read()"); + return -1; + } + bufArr = (jbyteArray)(*jenv)->CallObjectMethod(jenv, buf, arrayMeth); - } while (err == SSL_ERROR_WANT_WRITE || err == SSL_ERROR_WANT_READ); + /* Get array elements */ + data = (byte*)(*jenv)->GetByteArrayElements(jenv, bufArr, NULL); + if ((*jenv)->ExceptionOccurred(jenv)) { + (*jenv)->ExceptionDescribe(jenv); + (*jenv)->ExceptionClear(jenv); + (*jenv)->ThrowNew(jenv, excClass, + "Exception when calling ByteBuffer.array() in native read()"); + return -1; + } + + + if (data != NULL) { + size = SSLReadNonblockingWithSelectPoll(ssl, data + position, + maxOutputSz, (int)timeout); - /* JNI_COMMIT commits the data but does not free the local array - * 0 is used here to both commit and free */ - (*jenv)->ReleaseByteArrayElements(jenv, raw, (jbyte*)data, 0); + /* Relase array elements */ + if (size < 0) { + (*jenv)->ReleaseByteArrayElements(jenv, bufArr, (jbyte*)data, + JNI_ABORT); + } + else { + /* JNI_COMMIT commits the data but does not free the local array + * 0 is used here to both commit and free */ + (*jenv)->ReleaseByteArrayElements(jenv, bufArr, + (jbyte*)data, 0); + + /* Update ByteBuffer position() based on bytes written */ + setPositionMeth = (*jenv)->GetMethodID(jenv, buffClass, + "position", "(I)Ljava/nio/Buffer;"); + if (setPositionMeth == NULL) { + if ((*jenv)->ExceptionOccurred(jenv)) { + (*jenv)->ExceptionDescribe(jenv); + (*jenv)->ExceptionClear(jenv); + } + (*jenv)->ThrowNew(jenv, excClass, + "Failed to set ByteBuffer position() from " + "native read()"); + size = -1; + } + else { + (*jenv)->CallVoidMethod(jenv, buf, setPositionMeth, + position + size); + } + } + } } return size; diff --git a/native/com_wolfssl_WolfSSLSession.h b/native/com_wolfssl_WolfSSLSession.h index e7941971..639db768 100644 --- a/native/com_wolfssl_WolfSSLSession.h +++ b/native/com_wolfssl_WolfSSLSession.h @@ -100,9 +100,17 @@ JNIEXPORT jint JNICALL Java_com_wolfssl_WolfSSLSession_write * Method: read * Signature: (J[BIII)I */ -JNIEXPORT jint JNICALL Java_com_wolfssl_WolfSSLSession_read +JNIEXPORT jint JNICALL Java_com_wolfssl_WolfSSLSession_read__J_3BIII (JNIEnv *, jobject, jlong, jbyteArray, jint, jint, jint); +/* + * Class: com_wolfssl_WolfSSLSession + * Method: read + * Signature: (JLjava/nio/ByteBuffer;II)I + */ +JNIEXPORT jint JNICALL Java_com_wolfssl_WolfSSLSession_read__JLjava_nio_ByteBuffer_2II + (JNIEnv *, jobject, jlong, jobject, jint, jint); + /* * Class: com_wolfssl_WolfSSLSession * Method: accept diff --git a/src/java/com/wolfssl/WolfSSLSession.java b/src/java/com/wolfssl/WolfSSLSession.java index cb5d2a05..659f1ce1 100644 --- a/src/java/com/wolfssl/WolfSSLSession.java +++ b/src/java/com/wolfssl/WolfSSLSession.java @@ -28,6 +28,7 @@ import java.net.SocketException; import java.net.SocketTimeoutException; import java.lang.StringBuilder; +import java.nio.ByteBuffer; import java.nio.charset.StandardCharsets; import com.wolfssl.WolfSSLException; @@ -253,6 +254,8 @@ private native int write(long ssl, byte[] data, int offset, int length, int timeout); private native int read(long ssl, byte[] data, int offset, int sz, int timeout); + private native int read(long ssl, ByteBuffer data, int sz, int timeout) + throws WolfSSLException; private native int accept(long ssl, int timeout); private native void freeSSL(long ssl); private native int shutdownSSL(long ssl, int timeout); @@ -1112,6 +1115,86 @@ public int read(byte[] data, int offset, int sz, int timeout) return ret; } + /** + * Reads bytes from the SSL session and returns the read bytes into + * the provided ByteBuffer, using socket timeout value in milliseconds. + * + * The bytes read are removed from the internal receive buffer. + *
+ * If necessary, read()
will negotiate an SSL/TLS session
+ * if the handshake has not already been performed yet by connect()
+ *
or accept()
.
+ *
+ * The SSL/TLS protocol uses SSL records which have a maximum size of
+ * 16kB. As such, wolfSSL needs to read an entire SSL record internally
+ * before it is able to process and decrypt the record. Because of this,
+ * a call to read()
will only be able to return the
+ * maximum buffer size which has been decrypted at the time of calling.
+ * There may be additional not-yet-decrypted data waiting in the internal
+ * wolfSSL receive buffer which will be retrieved and decrypted with the
+ * next call to read()
.
+ *
+ * @param data ByteBuffer where the data read from the SSL connection
+ * will be placed. position() will be updated after this
+ * method writes data to the ByteBuffer.
+ * @param sz number of bytes to read into data
,
+ * may be adjusted to the maximum space in data if that is
+ * smaller than this size.
+ * @param timeout read timeout, milliseconds.
+ * @return the number of bytes read upon success. SSL_FAILURE
+ *
will be returned upon failure which may be caused
+ * by either a clean (close notify alert) shutdown or just
+ * that the peer closed the connection.
+ * SSL_FATAL_ERROR
upon failure when either an error
+ * occurred or, when using non-blocking sockets, the
+ * SSL_ERROR_WANT_READ or SSL_ERROR_WANT_WRITE
+ * error was received and the application needs to call
+ * read()
again. Use getError
to
+ * get a specific error code.
+ * BAD_FUNC_ARC
when bad arguments are used.
+ * @throws IllegalStateException WolfSSLContext has been freed
+ * @throws SocketTimeoutException if socket timeout occurs
+ * @throws SocketException Native socket select/poll() failed
+ */
+ public int read(ByteBuffer data, int sz, int timeout)
+ throws IllegalStateException, SocketTimeoutException, SocketException {
+
+ int ret;
+ long localPtr;
+
+ confirmObjectIsActive();
+
+ /* Fix for Infer scan, since not synchronizing on sslLock for
+ * access to this.sslPtr, see note below */
+ synchronized (sslLock) {
+ localPtr = this.sslPtr;
+ }
+
+ WolfSSLDebug.log(getClass(), WolfSSLDebug.Component.JNI,
+ WolfSSLDebug.INFO, localPtr, "entered read(ByteBuffer, " +
+ "sz: " + sz + ", timeout: " + timeout + ")");
+
+ /* not synchronizing on sslLock here since JNI read() locks
+ * session mutex around native wolfSSL_read() call. If sslLock
+ * is locked here, since we call select() inside native JNI we
+ * could timeout waiting for corresponding write() operation to
+ * occur if needed */
+ try {
+ ret = read(localPtr, data, sz, timeout);
+ } catch (WolfSSLException e) {
+ /* JNI code may throw WolfSSLException on JNI specific errors */
+ throw new SocketException(e.getMessage());
+ }
+
+ WolfSSLDebug.log(getClass(), WolfSSLDebug.Component.JNI,
+ WolfSSLDebug.INFO, localPtr, "read() ret: " + ret +
+ ", err: " + getError(ret));
+
+ throwExceptionFromIOReturnValue(ret, "wolfSSL_read()");
+
+ return ret;
+ }
+
/**
* Waits for an SSL client to initiate the SSL/TLS handshake.
* This method is called on the server side. When it is called, the
diff --git a/src/java/com/wolfssl/provider/jsse/WolfSSLEngine.java b/src/java/com/wolfssl/provider/jsse/WolfSSLEngine.java
index 368a3dd5..12d61fca 100644
--- a/src/java/com/wolfssl/provider/jsse/WolfSSLEngine.java
+++ b/src/java/com/wolfssl/provider/jsse/WolfSSLEngine.java
@@ -818,15 +818,23 @@ private synchronized int RecvAppData(ByteBuffer[] out, int ofst, int length)
int ret = 0;
int idx = 0; /* index into out[] array */
int err = 0;
- byte[] tmp;
+ byte[] tmp = null;
- /* create read buffer of max output size */
+ /* Calculate maximum output size across ByteBuffer arrays */
maxOutSz = getTotalOutputSize(out, ofst, length);
- tmp = new byte[maxOutSz];
synchronized (ioLock) {
try {
- ret = this.ssl.read(tmp, maxOutSz);
+ /* If we only have one ByteBuffer, skip allocating
+ * separate intermediate byte[] and write directly to underlying
+ * ByteBuffer array */
+ if (out.length == 1) {
+ ret = this.ssl.read(out[0], maxOutSz, 0);
+ }
+ else {
+ tmp = new byte[maxOutSz];
+ ret = this.ssl.read(tmp, maxOutSz);
+ }
} catch (SocketTimeoutException | SocketException e) {
throw new SSLException(e);
}
@@ -883,27 +891,32 @@ private synchronized int RecvAppData(ByteBuffer[] out, int ofst, int length)
}
}
else {
- /* write processed data into output buffers */
- for (i = 0; i < ret;) {
- if (idx + ofst >= length) {
- /* no more output buffers left */
- break;
- }
+ if (out.length == 1) {
+ totalRead = ret;
+ }
+ else {
+ /* write processed data into output buffers */
+ for (i = 0; i < ret;) {
+ if (idx + ofst >= length) {
+ /* no more output buffers left */
+ break;
+ }
- bufSpace = out[idx + ofst].remaining();
- if (bufSpace == 0) {
- /* no more space in current out buffer, advance */
- idx++;
- continue;
- }
+ bufSpace = out[idx + ofst].remaining();
+ if (bufSpace == 0) {
+ /* no more space in current out buffer, advance */
+ idx++;
+ continue;
+ }
- sz = (bufSpace >= (ret - i)) ? (ret - i) : bufSpace;
- out[idx + ofst].put(tmp, i, sz);
- i += sz;
- totalRead += sz;
+ sz = (bufSpace >= (ret - i)) ? (ret - i) : bufSpace;
+ out[idx + ofst].put(tmp, i, sz);
+ i += sz;
+ totalRead += sz;
- if ((ret - i) > 0) {
- idx++; /* go to next output buffer */
+ if ((ret - i) > 0) {
+ idx++; /* go to next output buffer */
+ }
}
}
}