diff --git a/kotlin-mbedtls-netty/src/main/kotlin/org/opencoap/ssl/netty/DtlsChannelHandler.kt b/kotlin-mbedtls-netty/src/main/kotlin/org/opencoap/ssl/netty/DtlsChannelHandler.kt index 8afb963..343a0ea 100644 --- a/kotlin-mbedtls-netty/src/main/kotlin/org/opencoap/ssl/netty/DtlsChannelHandler.kt +++ b/kotlin-mbedtls-netty/src/main/kotlin/org/opencoap/ssl/netty/DtlsChannelHandler.kt @@ -77,7 +77,7 @@ class DtlsChannelHandler @JvmOverloads constructor( private fun loadSession(result: DtlsServer.ReceiveResult.CidSessionMissing, msg: DatagramPacket, ctx: ChannelHandlerContext) { sessionStore.read(result.cid) - .thenApplyAsync({ sessBuf -> dtlsServer.loadSession(sessBuf, msg.sender(), result.cid) }, ctx.executor()) + .thenApplyAsync({ sessBuf -> dtlsServer.loadSession(sessBuf, msg.sender(), result.cid, msg.content().nioBuffer()) }, ctx.executor()) .whenComplete { isLoaded: Boolean?, _ -> if (isLoaded == true) { channelRead(ctx, msg) diff --git a/kotlin-mbedtls/src/main/kotlin/org/opencoap/ssl/MbedtlsApi.kt b/kotlin-mbedtls/src/main/kotlin/org/opencoap/ssl/MbedtlsApi.kt index 63ecdf4..8825fc6 100644 --- a/kotlin-mbedtls/src/main/kotlin/org/opencoap/ssl/MbedtlsApi.kt +++ b/kotlin-mbedtls/src/main/kotlin/org/opencoap/ssl/MbedtlsApi.kt @@ -78,6 +78,7 @@ internal object MbedtlsApi { external fun mbedtls_ssl_get_peer_cid(sslContext: Pointer, enabled: Pointer, peerCid: Pointer, peerCidLen: Pointer): Int external fun mbedtls_ssl_context_save(sslContext: Pointer, buf: ByteArray, bufLen: Int, outputLen: ByteArray): Int external fun mbedtls_ssl_context_load(sslContext: Pointer, buf: ByteArray, len: Int): Int + external fun mbedtls_ssl_check_record(sslContext: Pointer, buf: Memory, bufLen: Int): Int external fun mbedtls_ssl_conf_ca_chain(sslConfig: Pointer, caChain: Pointer, caCrl: Pointer?) external fun mbedtls_ssl_conf_own_cert(sslConfig: Pointer, ownCert: Memory, pkKey: Pointer): Int external fun mbedtls_ssl_set_mtu(sslContext: Pointer, mtu: Int) @@ -96,6 +97,7 @@ internal object MbedtlsApi { const val MBEDTLS_SSL_TRANSPORT_DATAGRAM = 1 const val MBEDTLS_SSL_VERIFY_NONE = 0 const val MBEDTLS_SSL_VERIFY_REQUIRED = 2 + const val MBEDTLS_ERR_SSL_UNEXPECTED_RECORD = -0x6700 // ----- net_sockets.h ----- val MBEDTLS_ERR_NET_RECV_FAILED = -0x004C diff --git a/kotlin-mbedtls/src/main/kotlin/org/opencoap/ssl/SslContext.kt b/kotlin-mbedtls/src/main/kotlin/org/opencoap/ssl/SslContext.kt index 1ffa731..8349d15 100644 --- a/kotlin-mbedtls/src/main/kotlin/org/opencoap/ssl/SslContext.kt +++ b/kotlin-mbedtls/src/main/kotlin/org/opencoap/ssl/SslContext.kt @@ -17,6 +17,7 @@ package org.opencoap.ssl import com.sun.jna.Memory +import org.opencoap.ssl.MbedtlsApi.MBEDTLS_ERR_SSL_UNEXPECTED_RECORD import org.opencoap.ssl.MbedtlsApi.mbedtls_ssl_close_notify import org.opencoap.ssl.MbedtlsApi.mbedtls_ssl_context_save import org.opencoap.ssl.MbedtlsApi.mbedtls_ssl_free @@ -27,6 +28,7 @@ import org.opencoap.ssl.MbedtlsApi.mbedtls_ssl_handshake import org.opencoap.ssl.MbedtlsApi.mbedtls_ssl_read import org.opencoap.ssl.MbedtlsApi.mbedtls_ssl_write import org.opencoap.ssl.MbedtlsApi.verify +import org.opencoap.ssl.transport.cloneToMemory import org.opencoap.ssl.transport.toHex import org.slf4j.LoggerFactory import java.io.Closeable @@ -164,6 +166,20 @@ class SslSession internal constructor( plainBuffer.limit(size + plainBuffer.position()) } + fun checkRecord(encBuffer: ByteBuffer): VerificationResult { + val memory = encBuffer.cloneToMemory() + try { + val result = MbedtlsApi.mbedtls_ssl_check_record(sslContext, memory, memory.size().toInt()) + return if (result == 0 || result != MBEDTLS_ERR_SSL_UNEXPECTED_RECORD) { + VerificationResult.Valid("Success") + } else { + VerificationResult.Invalid(SslException.from(result).localizedMessage) + } + } finally { + memory.close() + } + } + fun decrypt(encBuffer: ByteBuffer, send: (ByteBuffer) -> Unit): ByteBuffer { val buf = ByteBuffer.allocate(encBuffer.remaining()) decrypt(encBuffer, buf, send) @@ -215,4 +231,9 @@ class SslSession internal constructor( override fun close() { mbedtls_ssl_free(sslContext) } + + sealed interface VerificationResult { + data class Valid(val message: String) : VerificationResult + data class Invalid(val message: String) : VerificationResult + } } diff --git a/kotlin-mbedtls/src/main/kotlin/org/opencoap/ssl/transport/BytesExtensions.kt b/kotlin-mbedtls/src/main/kotlin/org/opencoap/ssl/transport/BytesExtensions.kt index dace1ef..cfff70c 100644 --- a/kotlin-mbedtls/src/main/kotlin/org/opencoap/ssl/transport/BytesExtensions.kt +++ b/kotlin-mbedtls/src/main/kotlin/org/opencoap/ssl/transport/BytesExtensions.kt @@ -16,6 +16,7 @@ package org.opencoap.ssl.transport +import com.sun.jna.Memory import java.nio.ByteBuffer internal fun ByteArray.toHex(): String { @@ -29,6 +30,16 @@ fun ByteBuffer.copy(): ByteBuffer { return bb } +fun ByteBuffer.cloneToMemory(): Memory { + this.mark() // saves the original position + val remaining = this.remaining() + val memory = Memory(remaining.toLong()) + val intermediateBuffer: ByteBuffer = memory.getByteBuffer(0, remaining.toLong()) + intermediateBuffer.put(this) + this.reset() + return memory +} + fun ByteBuffer.isNotEmpty(): Boolean = this.hasRemaining() fun ByteBuffer.isEmpty(): Boolean = !this.hasRemaining() diff --git a/kotlin-mbedtls/src/main/kotlin/org/opencoap/ssl/transport/DtlsServer.kt b/kotlin-mbedtls/src/main/kotlin/org/opencoap/ssl/transport/DtlsServer.kt index d3c9e8b..26fc5ee 100644 --- a/kotlin-mbedtls/src/main/kotlin/org/opencoap/ssl/transport/DtlsServer.kt +++ b/kotlin-mbedtls/src/main/kotlin/org/opencoap/ssl/transport/DtlsServer.kt @@ -125,7 +125,12 @@ class DtlsServer( private fun closeSession(addr: InetSocketAddress) { sessions.remove(addr)?.apply { storeAndClose() - logger.info("[{}] [CID:{}] DTLS session was stored", peerAddress, (this as? DtlsSession)?.sessionContext?.cid?.toHex() ?: "na") + logger.info( + "[{}] [CID:{}] DTLS session was stored", + peerAddress, + (this as? DtlsSession)?.sessionContext?.cid?.toHex() + ?: "na" + ) } } @@ -138,18 +143,25 @@ class DtlsServer( updateSessionAuthenticationContext(adr, ctx.authenticationContext) } - fun loadSession(sessBuf: SessionWithContext?, adr: InetSocketAddress, cid: ByteArray): Boolean { + fun loadSession(sessBuf: SessionWithContext?, adr: InetSocketAddress, cid: ByteArray, dtlsPacket: ByteBuffer): Boolean { return try { if (sessBuf == null) { logger.warn("[{}] [CID:{}] DTLS session not found", adr, cid.toHex()) reportMessageDrop(adr) - false - } else { - sessions[adr] = DtlsSession(sslConfig.loadSession(cid, sessBuf.sessionBlob, adr), adr, sessBuf.authenticationContext, sessBuf.sessionStartTimestamp) - true + return false + } + + val sslSession = sslConfig.loadSession(cid, sessBuf.sessionBlob, adr) + val verificationResult = sslSession.checkRecord(dtlsPacket) + if (verificationResult is SslSession.VerificationResult.Invalid) { + logger.warn("[{}] [CID:{}] Record verification failed: {}", adr, cid.toHex(), verificationResult.message) + reportMessageDrop(adr) + return false } - } catch (ex: SslException) { - logger.warn("[{}] [CID:{}] Failed to load session: {}", adr, cid.toHex(), ex.message) + sessions[adr] = DtlsSession(sslSession, adr, sessBuf.authenticationContext, sessBuf.sessionStartTimestamp) + true + } catch (ex: Exception) { + logger.error("[{}] [CID:{}] DTLS failed to load session: {}", adr, cid.toHex(), ex.message) reportMessageDrop(adr) false } diff --git a/kotlin-mbedtls/src/main/kotlin/org/opencoap/ssl/transport/DtlsServerTransport.kt b/kotlin-mbedtls/src/main/kotlin/org/opencoap/ssl/transport/DtlsServerTransport.kt index 78eb06d..04445a0 100644 --- a/kotlin-mbedtls/src/main/kotlin/org/opencoap/ssl/transport/DtlsServerTransport.kt +++ b/kotlin-mbedtls/src/main/kotlin/org/opencoap/ssl/transport/DtlsServerTransport.kt @@ -79,7 +79,7 @@ class DtlsServerTransport private constructor( val copyBuf = buf.copy() sessionStore.read(result.cid).thenApplyAsync( - { sessBuf -> dtlsServer.loadSession(sessBuf, adr, result.cid) }, + { sessBuf -> dtlsServer.loadSession(sessBuf, adr, result.cid, copyBuf) }, executor ).thenCompose { isLoaded -> if (isLoaded) { diff --git a/kotlin-mbedtls/src/test/kotlin/org/opencoap/ssl/SslContextTest.kt b/kotlin-mbedtls/src/test/kotlin/org/opencoap/ssl/SslContextTest.kt index 612b001..68847e4 100644 --- a/kotlin-mbedtls/src/test/kotlin/org/opencoap/ssl/SslContextTest.kt +++ b/kotlin-mbedtls/src/test/kotlin/org/opencoap/ssl/SslContextTest.kt @@ -123,6 +123,29 @@ class SslContextTest { assertEquals("perse", serverSession.decrypt(encryptedDtls2, noSend).decodeToString()) } + @Test + fun `should check record is valid authentic and decrypt`() { + val clientSession = clientConf.loadSession(byteArrayOf(), StoredSessionPair.cliSession, localAddress(2_5684)) + val serverSession = serverConf.loadSession(byteArrayOf(), StoredSessionPair.srvSession, localAddress(1_5684)) + + val encryptedDtls = clientSession.encrypt("auto".toByteBuffer()) + + val verificationResult = serverSession.checkRecord(encryptedDtls) + assertTrue(verificationResult is SslSession.VerificationResult.Valid) + assertEquals("auto", serverSession.decrypt(encryptedDtls, noSend).decodeToString()) + } + + @Test + fun `should check record is invalid when record is unexpected and replayed`() { + val clientSession = clientConf.loadSession(byteArrayOf(), StoredSessionPair.cliSession, localAddress(2_5684)) + val serverSession = serverConf.loadSession(byteArrayOf(), StoredSessionPair.srvSession, localAddress(1_5684)) + val encryptedDtls = clientSession.encrypt("auto".toByteBuffer()) + + serverSession.decrypt(encryptedDtls, noSend) + val result = serverSession.checkRecord(encryptedDtls.rewind() as ByteBuffer) + assertTrue(result is SslSession.VerificationResult.Invalid) + } + @Test fun `should exchange data with direct byte buffer`() { val clientSession = clientConf.loadSession(byteArrayOf(), StoredSessionPair.cliSession, localAddress(2_5684)) diff --git a/kotlin-mbedtls/src/test/kotlin/org/opencoap/ssl/transport/BytesExtensionsTest.kt b/kotlin-mbedtls/src/test/kotlin/org/opencoap/ssl/transport/BytesExtensionsTest.kt index 4788007..cfc9a70 100644 --- a/kotlin-mbedtls/src/test/kotlin/org/opencoap/ssl/transport/BytesExtensionsTest.kt +++ b/kotlin-mbedtls/src/test/kotlin/org/opencoap/ssl/transport/BytesExtensionsTest.kt @@ -16,6 +16,8 @@ package org.opencoap.ssl.transport +import com.sun.jna.Memory +import org.junit.jupiter.api.Assertions.assertArrayEquals import org.junit.jupiter.api.Assertions.assertEquals import org.junit.jupiter.api.Assertions.assertTrue import org.junit.jupiter.api.Test @@ -70,4 +72,24 @@ class BytesExtensionsTest { buf.position(2) assertEquals("dupa", buf.decodeToString()) } + + @Test + fun `should clone buffer to memory`() { + val originalData = byteArrayOf(1, 2, 3, 4, 5) + val byteBuffer = ByteBuffer.wrap(originalData) + + val originalPosition = byteBuffer.position() + val originalLimit = byteBuffer.limit() + val originalCapacity = byteBuffer.capacity() + + val memory: Memory = byteBuffer.cloneToMemory() + + val clonedData = ByteArray(originalData.size) + memory.read(0, clonedData, 0, originalData.size) + assertArrayEquals(originalData, clonedData) + + assertEquals(originalPosition, byteBuffer.position(), "Buffer position should not change") + assertEquals(originalLimit, byteBuffer.limit(), "Buffer limit should not change") + assertEquals(originalCapacity, byteBuffer.capacity(), "Buffer capacity should not change") + } } diff --git a/kotlin-mbedtls/src/test/kotlin/org/opencoap/ssl/transport/DtlsServerTest.kt b/kotlin-mbedtls/src/test/kotlin/org/opencoap/ssl/transport/DtlsServerTest.kt index 3d27623..9fa0caa 100644 --- a/kotlin-mbedtls/src/test/kotlin/org/opencoap/ssl/transport/DtlsServerTest.kt +++ b/kotlin-mbedtls/src/test/kotlin/org/opencoap/ssl/transport/DtlsServerTest.kt @@ -85,7 +85,7 @@ class DtlsServerTest { assertTrue(dtlsServer.handleReceived(localAddress(2_5684), dtlsPacket) is ReceiveResult.CidSessionMissing) // when - dtlsServer.loadSession(SessionWithContext(StoredSessionPair.srvSession, mapOf(), Instant.ofEpochSecond(123456789)), localAddress(2_5684), "f935adc57425e1b214f8640d56e0c733".decodeHex()) + dtlsServer.loadSession(SessionWithContext(StoredSessionPair.srvSession, mapOf(), Instant.ofEpochSecond(123456789)), localAddress(2_5684), "f935adc57425e1b214f8640d56e0c733".decodeHex(), dtlsPacket) // then val dtlsPacketIn = (dtlsServer.handleReceived(localAddress(2_5684), dtlsPacket) as ReceiveResult.Decrypted).packet