diff --git a/native/com_wolfssl_WolfSSLSession.c b/native/com_wolfssl_WolfSSLSession.c index ed9d909e..5b34262a 100644 --- a/native/com_wolfssl_WolfSSLSession.c +++ b/native/com_wolfssl_WolfSSLSession.c @@ -1002,16 +1002,104 @@ JNIEXPORT jint JNICALL Java_com_wolfssl_WolfSSLSession_write } } -JNIEXPORT jint JNICALL Java_com_wolfssl_WolfSSLSession_read - (JNIEnv* jenv, jobject jcl, jlong sslPtr, jbyteArray raw, jint offset, - jint length, jint timeout) +/** + * Read len bytes from wolfSSL_read() back into provided output buffer. + * + * Internal function called by WolfSSLSession.read() calls. + * + * If wolfSSL_get_fd(ssl) returns a socket descriptor, try to wait for + * data with select()/poll() up to provided timeout. + * + * Returns number of bytes read on success, or negative on error. + */ +static int SSLReadNonblockingWithSelectPoll(WOLFSSL* ssl, byte* out, + int length, int timeout) { - byte* data = NULL; - int size = 0, ret, err, sockfd; + int size, ret, err, sockfd; int pollRx = 0; int pollTx = 0; wolfSSL_Mutex* jniSessLock = NULL; SSLAppData* appData = NULL; + + if (ssl == NULL || out == NULL) { + return BAD_FUNC_ARG; + } + + /* get session mutex from SSL app data */ + appData = (SSLAppData*)wolfSSL_get_app_data(ssl); + if (appData == NULL) { + return WOLFSSL_FAILURE; + } + + jniSessLock = appData->jniSessLock; + if (jniSessLock == NULL) { + return WOLFSSL_FAILURE; + } + + do { + /* lock mutex around session I/O before read attempt */ + if (wc_LockMutex(jniSessLock) != 0) { + size = WOLFSSL_FAILURE; + break; + } + + size = wolfSSL_read(ssl, out, length); + err = wolfSSL_get_error(ssl, size); + + /* unlock mutex around session I/O after read attempt */ + if (wc_UnLockMutex(jniSessLock) != 0) { + size = WOLFSSL_FAILURE; + break; + } + + if (size < 0 && + ((err == SSL_ERROR_WANT_READ) || (err == SSL_ERROR_WANT_WRITE))) { + + sockfd = wolfSSL_get_fd(ssl); + if (sockfd == -1) { + /* For I/O that does not use sockets, sockfd may be -1, + * skip try to call select() */ + break; + } + + if (err == SSL_ERROR_WANT_READ) { + pollRx = 1; + } + else if (err == SSL_ERROR_WANT_WRITE) { + pollTx = 1; + } + + #if defined(WOLFJNI_USE_IO_SELECT) || defined(USE_WINDOWS_API) + ret = socketSelect(sockfd, timeout, pollRx); + #else + ret = socketPoll(sockfd, timeout, pollRx, pollTx); + #endif + if ((ret == WOLFJNI_IO_EVENT_RECV_READY) || + (ret == WOLFJNI_IO_EVENT_SEND_READY)) { + /* loop around and try wolfSSL_read() again */ + continue; + } else { + /* Java will throw SocketTimeoutException or + * SocketException if ret equals + * WOLFJNI_IO_EVENT_TIMEOUT, WOLFJNI_IO_EVENT_FD_CLOSED + * WOLFJNI_IO_EVENT_ERROR, WOLFJNI_IO_EVENT_POLLHUP or + * WOLFJNI_IO_EVENT_FAIL */ + size = ret; + break; + } + } + + } while (err == SSL_ERROR_WANT_WRITE || err == SSL_ERROR_WANT_READ); + + return size; +} + +JNIEXPORT jint JNICALL Java_com_wolfssl_WolfSSLSession_read__J_3BIII + (JNIEnv* jenv, jobject jcl, jlong sslPtr, jbyteArray raw, jint offset, + jint length, jint timeout) +{ + int size = 0; + byte* data = NULL; WOLFSSL* ssl = (WOLFSSL*)(uintptr_t)sslPtr; (void)jcl; @@ -1027,79 +1115,178 @@ JNIEXPORT jint JNICALL Java_com_wolfssl_WolfSSLSession_read return SSL_FAILURE; } - /* get session mutex from SSL app data */ - appData = (SSLAppData*)wolfSSL_get_app_data(ssl); - if (appData == NULL) { + size = SSLReadNonblockingWithSelectPoll(ssl, data + offset, + (int)length, (int)timeout); + + if (size < 0) { (*jenv)->ReleaseByteArrayElements(jenv, raw, (jbyte*)data, - JNI_ABORT); - return WOLFSSL_FAILURE; + JNI_ABORT); } + else { + /* JNI_COMMIT commits the data but does not free the local array + * 0 is used here to both commit and free */ + (*jenv)->ReleaseByteArrayElements(jenv, raw, (jbyte*)data, 0); + } + } - jniSessLock = appData->jniSessLock; - if (jniSessLock == NULL) { - (*jenv)->ReleaseByteArrayElements(jenv, raw, (jbyte*)data, - JNI_ABORT); - return WOLFSSL_FAILURE; + return size; +} + +JNIEXPORT jint JNICALL Java_com_wolfssl_WolfSSLSession_read__JLjava_nio_ByteBuffer_2II + (JNIEnv* jenv, jobject jcl, jlong sslPtr, jobject buf, jint length, jint timeout) +{ + int size = 0; + int maxOutputSz; + int outSz = length; + byte* data = NULL; + WOLFSSL* ssl = (WOLFSSL*)(uintptr_t)sslPtr; + + jclass excClass; + jclass buffClass; + jmethodID positionMeth; + jmethodID limitMeth; + jmethodID hasArrayMeth; + jmethodID arrayMeth; + jmethodID setPositionMeth; + + jint position; + jint limit; + jboolean hasArray; + jbyteArray bufArr; + + (void)jcl; + + if (jenv == NULL || ssl == NULL || buf == NULL) { + return BAD_FUNC_ARG; + } + + if (length > 0) { + /* Get WolfSSLException class */ + excClass = (*jenv)->FindClass(jenv, "com/wolfssl/WolfSSLException"); + if ((*jenv)->ExceptionOccurred(jenv)) { + (*jenv)->ExceptionDescribe(jenv); + (*jenv)->ExceptionClear(jenv); + return -1; } - do { - /* lock mutex around session I/O before read attempt */ - if (wc_LockMutex(jniSessLock) != 0) { - size = WOLFSSL_FAILURE; - break; - } + /* Get ByteBuffer class */ + buffClass = (*jenv)->GetObjectClass(jenv, buf); + if (buffClass == NULL) { + (*jenv)->ThrowNew(jenv, excClass, + "Failed to find ByteBuffer class in native read()"); + return -1; + } - size = wolfSSL_read(ssl, data + offset, length); - err = wolfSSL_get_error(ssl, size); + /* Get ByteBuffer position */ + positionMeth = (*jenv)->GetMethodID(jenv, buffClass, "position", "()I"); + if (positionMeth == NULL) { + if ((*jenv)->ExceptionOccurred(jenv)) { + (*jenv)->ExceptionDescribe(jenv); + (*jenv)->ExceptionClear(jenv); + } + (*jenv)->ThrowNew(jenv, excClass, + "Failed to find ByteBuffer position() method in native read()"); + return -1; + } + position = (*jenv)->CallIntMethod(jenv, buf, positionMeth); - /* unlock mutex around session I/O after read attempt */ - if (wc_UnLockMutex(jniSessLock) != 0) { - size = WOLFSSL_FAILURE; - break; + /* Get ByteBuffer limit */ + limitMeth = (*jenv)->GetMethodID(jenv, buffClass, "limit", "()I"); + if (limitMeth == NULL) { + if ((*jenv)->ExceptionOccurred(jenv)) { + (*jenv)->ExceptionDescribe(jenv); + (*jenv)->ExceptionClear(jenv); } + (*jenv)->ThrowNew(jenv, excClass, + "Failed to find ByteBuffer limit() method in native read()"); + return -1; + } + limit = (*jenv)->CallIntMethod(jenv, buf, limitMeth); - if (size < 0 && ((err == SSL_ERROR_WANT_READ) || \ - (err == SSL_ERROR_WANT_WRITE))) { + /* Get and call ByteBuffer.hasArray() before calling array() */ + hasArrayMeth = (*jenv)->GetMethodID(jenv, buffClass, "hasArray", "()Z"); + if (hasArrayMeth == NULL) { + if ((*jenv)->ExceptionOccurred(jenv)) { + (*jenv)->ExceptionDescribe(jenv); + (*jenv)->ExceptionClear(jenv); + } + (*jenv)->ThrowNew(jenv, excClass, + "Failed to find ByteBuffer hasArray() method in native read()"); + return -1; + } - sockfd = wolfSSL_get_fd(ssl); - if (sockfd == -1) { - /* For I/O that does not use sockets, sockfd may be -1, - * skip try to call select() */ - break; - } + /* ByteBuffer.hasArray() does not throw any exceptions */ + hasArray = (*jenv)->CallBooleanMethod(jenv, buf, hasArrayMeth); + if (!hasArray) { + (*jenv)->ThrowNew(jenv, excClass, + "ByteBuffer.hasArray() is false in native read()"); + return BAD_FUNC_ARG; + } - if (err == SSL_ERROR_WANT_READ) { - pollRx = 1; - } - else if (err == SSL_ERROR_WANT_WRITE) { - pollTx = 1; - } + /* Only read up to maximum space we have in this ByteBuffer */ + maxOutputSz = (limit - position); + if (outSz > maxOutputSz) { + outSz = maxOutputSz; + } - #if defined(WOLFJNI_USE_IO_SELECT) || defined(USE_WINDOWS_API) - ret = socketSelect(sockfd, (int)timeout, pollRx); - #else - ret = socketPoll(sockfd, (int)timeout, pollRx, pollTx); - #endif - if ((ret == WOLFJNI_IO_EVENT_RECV_READY) || - (ret == WOLFJNI_IO_EVENT_SEND_READY)) { - /* loop around and try wolfSSL_read() again */ - continue; - } else { - /* Java will throw SocketTimeoutException or - * SocketException if ret equals - * WOLFJNI_IO_EVENT_TIMEOUT, WOLFJNI_IO_EVENT_FD_CLOSED - * WOLFJNI_IO_EVENT_ERROR, WOLFJNI_IO_EVENT_POLLHUP or - * WOLFJNI_IO_EVENT_FAIL */ - size = ret; - break; - } + /* Get reference to underlying byte[] from ByteBuffer */ + arrayMeth = (*jenv)->GetMethodID(jenv, buffClass, "array", "()[B"); + if (arrayMeth == NULL) { + if ((*jenv)->ExceptionOccurred(jenv)) { + (*jenv)->ExceptionDescribe(jenv); + (*jenv)->ExceptionClear(jenv); } + (*jenv)->ThrowNew(jenv, excClass, + "Failed to find ByteBuffer array() method in native read()"); + return -1; + } + bufArr = (jbyteArray)(*jenv)->CallObjectMethod(jenv, buf, arrayMeth); - } while (err == SSL_ERROR_WANT_WRITE || err == SSL_ERROR_WANT_READ); + /* Get array elements */ + data = (byte*)(*jenv)->GetByteArrayElements(jenv, bufArr, NULL); + if ((*jenv)->ExceptionOccurred(jenv)) { + (*jenv)->ExceptionDescribe(jenv); + (*jenv)->ExceptionClear(jenv); + (*jenv)->ThrowNew(jenv, excClass, + "Exception when calling ByteBuffer.array() in native read()"); + return -1; + } + + + if (data != NULL) { + size = SSLReadNonblockingWithSelectPoll(ssl, data + position, + maxOutputSz, (int)timeout); - /* JNI_COMMIT commits the data but does not free the local array - * 0 is used here to both commit and free */ - (*jenv)->ReleaseByteArrayElements(jenv, raw, (jbyte*)data, 0); + /* Relase array elements */ + if (size < 0) { + (*jenv)->ReleaseByteArrayElements(jenv, bufArr, (jbyte*)data, + JNI_ABORT); + } + else { + /* JNI_COMMIT commits the data but does not free the local array + * 0 is used here to both commit and free */ + (*jenv)->ReleaseByteArrayElements(jenv, bufArr, + (jbyte*)data, 0); + + /* Update ByteBuffer position() based on bytes written */ + setPositionMeth = (*jenv)->GetMethodID(jenv, buffClass, + "position", "(I)Ljava/nio/Buffer;"); + if (setPositionMeth == NULL) { + if ((*jenv)->ExceptionOccurred(jenv)) { + (*jenv)->ExceptionDescribe(jenv); + (*jenv)->ExceptionClear(jenv); + } + (*jenv)->ThrowNew(jenv, excClass, + "Failed to set ByteBuffer position() from " + "native read()"); + size = -1; + } + else { + (*jenv)->CallVoidMethod(jenv, buf, setPositionMeth, + position + size); + } + } + } } return size; diff --git a/native/com_wolfssl_WolfSSLSession.h b/native/com_wolfssl_WolfSSLSession.h index e7941971..639db768 100644 --- a/native/com_wolfssl_WolfSSLSession.h +++ b/native/com_wolfssl_WolfSSLSession.h @@ -100,9 +100,17 @@ JNIEXPORT jint JNICALL Java_com_wolfssl_WolfSSLSession_write * Method: read * Signature: (J[BIII)I */ -JNIEXPORT jint JNICALL Java_com_wolfssl_WolfSSLSession_read +JNIEXPORT jint JNICALL Java_com_wolfssl_WolfSSLSession_read__J_3BIII (JNIEnv *, jobject, jlong, jbyteArray, jint, jint, jint); +/* + * Class: com_wolfssl_WolfSSLSession + * Method: read + * Signature: (JLjava/nio/ByteBuffer;II)I + */ +JNIEXPORT jint JNICALL Java_com_wolfssl_WolfSSLSession_read__JLjava_nio_ByteBuffer_2II + (JNIEnv *, jobject, jlong, jobject, jint, jint); + /* * Class: com_wolfssl_WolfSSLSession * Method: accept diff --git a/src/java/com/wolfssl/WolfSSLSession.java b/src/java/com/wolfssl/WolfSSLSession.java index cb5d2a05..659f1ce1 100644 --- a/src/java/com/wolfssl/WolfSSLSession.java +++ b/src/java/com/wolfssl/WolfSSLSession.java @@ -28,6 +28,7 @@ import java.net.SocketException; import java.net.SocketTimeoutException; import java.lang.StringBuilder; +import java.nio.ByteBuffer; import java.nio.charset.StandardCharsets; import com.wolfssl.WolfSSLException; @@ -253,6 +254,8 @@ private native int write(long ssl, byte[] data, int offset, int length, int timeout); private native int read(long ssl, byte[] data, int offset, int sz, int timeout); + private native int read(long ssl, ByteBuffer data, int sz, int timeout) + throws WolfSSLException; private native int accept(long ssl, int timeout); private native void freeSSL(long ssl); private native int shutdownSSL(long ssl, int timeout); @@ -1112,6 +1115,86 @@ public int read(byte[] data, int offset, int sz, int timeout) return ret; } + /** + * Reads bytes from the SSL session and returns the read bytes into + * the provided ByteBuffer, using socket timeout value in milliseconds. + * + * The bytes read are removed from the internal receive buffer. + *

+ * If necessary, read() will negotiate an SSL/TLS session + * if the handshake has not already been performed yet by connect() + * or accept(). + *

+ * The SSL/TLS protocol uses SSL records which have a maximum size of + * 16kB. As such, wolfSSL needs to read an entire SSL record internally + * before it is able to process and decrypt the record. Because of this, + * a call to read() will only be able to return the + * maximum buffer size which has been decrypted at the time of calling. + * There may be additional not-yet-decrypted data waiting in the internal + * wolfSSL receive buffer which will be retrieved and decrypted with the + * next call to read(). + * + * @param data ByteBuffer where the data read from the SSL connection + * will be placed. position() will be updated after this + * method writes data to the ByteBuffer. + * @param sz number of bytes to read into data, + * may be adjusted to the maximum space in data if that is + * smaller than this size. + * @param timeout read timeout, milliseconds. + * @return the number of bytes read upon success. SSL_FAILURE + * will be returned upon failure which may be caused + * by either a clean (close notify alert) shutdown or just + * that the peer closed the connection. + * SSL_FATAL_ERROR upon failure when either an error + * occurred or, when using non-blocking sockets, the + * SSL_ERROR_WANT_READ or SSL_ERROR_WANT_WRITE + * error was received and the application needs to call + * read() again. Use getError to + * get a specific error code. + * BAD_FUNC_ARC when bad arguments are used. + * @throws IllegalStateException WolfSSLContext has been freed + * @throws SocketTimeoutException if socket timeout occurs + * @throws SocketException Native socket select/poll() failed + */ + public int read(ByteBuffer data, int sz, int timeout) + throws IllegalStateException, SocketTimeoutException, SocketException { + + int ret; + long localPtr; + + confirmObjectIsActive(); + + /* Fix for Infer scan, since not synchronizing on sslLock for + * access to this.sslPtr, see note below */ + synchronized (sslLock) { + localPtr = this.sslPtr; + } + + WolfSSLDebug.log(getClass(), WolfSSLDebug.Component.JNI, + WolfSSLDebug.INFO, localPtr, "entered read(ByteBuffer, " + + "sz: " + sz + ", timeout: " + timeout + ")"); + + /* not synchronizing on sslLock here since JNI read() locks + * session mutex around native wolfSSL_read() call. If sslLock + * is locked here, since we call select() inside native JNI we + * could timeout waiting for corresponding write() operation to + * occur if needed */ + try { + ret = read(localPtr, data, sz, timeout); + } catch (WolfSSLException e) { + /* JNI code may throw WolfSSLException on JNI specific errors */ + throw new SocketException(e.getMessage()); + } + + WolfSSLDebug.log(getClass(), WolfSSLDebug.Component.JNI, + WolfSSLDebug.INFO, localPtr, "read() ret: " + ret + + ", err: " + getError(ret)); + + throwExceptionFromIOReturnValue(ret, "wolfSSL_read()"); + + return ret; + } + /** * Waits for an SSL client to initiate the SSL/TLS handshake. * This method is called on the server side. When it is called, the diff --git a/src/java/com/wolfssl/provider/jsse/WolfSSLEngine.java b/src/java/com/wolfssl/provider/jsse/WolfSSLEngine.java index 368a3dd5..12d61fca 100644 --- a/src/java/com/wolfssl/provider/jsse/WolfSSLEngine.java +++ b/src/java/com/wolfssl/provider/jsse/WolfSSLEngine.java @@ -818,15 +818,23 @@ private synchronized int RecvAppData(ByteBuffer[] out, int ofst, int length) int ret = 0; int idx = 0; /* index into out[] array */ int err = 0; - byte[] tmp; + byte[] tmp = null; - /* create read buffer of max output size */ + /* Calculate maximum output size across ByteBuffer arrays */ maxOutSz = getTotalOutputSize(out, ofst, length); - tmp = new byte[maxOutSz]; synchronized (ioLock) { try { - ret = this.ssl.read(tmp, maxOutSz); + /* If we only have one ByteBuffer, skip allocating + * separate intermediate byte[] and write directly to underlying + * ByteBuffer array */ + if (out.length == 1) { + ret = this.ssl.read(out[0], maxOutSz, 0); + } + else { + tmp = new byte[maxOutSz]; + ret = this.ssl.read(tmp, maxOutSz); + } } catch (SocketTimeoutException | SocketException e) { throw new SSLException(e); } @@ -883,27 +891,32 @@ private synchronized int RecvAppData(ByteBuffer[] out, int ofst, int length) } } else { - /* write processed data into output buffers */ - for (i = 0; i < ret;) { - if (idx + ofst >= length) { - /* no more output buffers left */ - break; - } + if (out.length == 1) { + totalRead = ret; + } + else { + /* write processed data into output buffers */ + for (i = 0; i < ret;) { + if (idx + ofst >= length) { + /* no more output buffers left */ + break; + } - bufSpace = out[idx + ofst].remaining(); - if (bufSpace == 0) { - /* no more space in current out buffer, advance */ - idx++; - continue; - } + bufSpace = out[idx + ofst].remaining(); + if (bufSpace == 0) { + /* no more space in current out buffer, advance */ + idx++; + continue; + } - sz = (bufSpace >= (ret - i)) ? (ret - i) : bufSpace; - out[idx + ofst].put(tmp, i, sz); - i += sz; - totalRead += sz; + sz = (bufSpace >= (ret - i)) ? (ret - i) : bufSpace; + out[idx + ofst].put(tmp, i, sz); + i += sz; + totalRead += sz; - if ((ret - i) > 0) { - idx++; /* go to next output buffer */ + if ((ret - i) > 0) { + idx++; /* go to next output buffer */ + } } } }