Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Validate DTLS message before decryption #59

Open
wants to merge 14 commits into
base: master
Choose a base branch
from
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ class DtlsChannelHandler @JvmOverloads constructor(

private fun loadSession(result: DtlsServer.ReceiveResult.CidSessionMissing, msg: DatagramPacket, ctx: ChannelHandlerContext) {
sessionStore.read(result.cid)
.thenApplyAsync({ sessBuf -> dtlsServer.loadSession(sessBuf, msg.sender(), result.cid) }, ctx.executor())
.thenApplyAsync({ sessBuf -> dtlsServer.loadSession(sessBuf, msg.sender(), result.cid, msg.content().nioBuffer()) }, ctx.executor())
.whenComplete { isLoaded: Boolean?, _ ->
if (isLoaded == true) {
channelRead(ctx, msg)
Expand Down
2 changes: 2 additions & 0 deletions kotlin-mbedtls/src/main/kotlin/org/opencoap/ssl/MbedtlsApi.kt
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ internal object MbedtlsApi {
external fun mbedtls_ssl_get_peer_cid(sslContext: Pointer, enabled: Pointer, peerCid: Pointer, peerCidLen: Pointer): Int
external fun mbedtls_ssl_context_save(sslContext: Pointer, buf: ByteArray, bufLen: Int, outputLen: ByteArray): Int
external fun mbedtls_ssl_context_load(sslContext: Pointer, buf: ByteArray, len: Int): Int
external fun mbedtls_ssl_check_record(sslContext: Pointer, buf: Memory, bufLen: Int): Int
external fun mbedtls_ssl_conf_ca_chain(sslConfig: Pointer, caChain: Pointer, caCrl: Pointer?)
external fun mbedtls_ssl_conf_own_cert(sslConfig: Pointer, ownCert: Memory, pkKey: Pointer): Int
external fun mbedtls_ssl_set_mtu(sslContext: Pointer, mtu: Int)
Expand All @@ -96,6 +97,7 @@ internal object MbedtlsApi {
const val MBEDTLS_SSL_TRANSPORT_DATAGRAM = 1
const val MBEDTLS_SSL_VERIFY_NONE = 0
const val MBEDTLS_SSL_VERIFY_REQUIRED = 2
const val MBEDTLS_ERR_SSL_UNEXPECTED_RECORD = -0x6700

// ----- net_sockets.h -----
val MBEDTLS_ERR_NET_RECV_FAILED = -0x004C
Expand Down
21 changes: 21 additions & 0 deletions kotlin-mbedtls/src/main/kotlin/org/opencoap/ssl/SslContext.kt
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package org.opencoap.ssl

import com.sun.jna.Memory
import org.opencoap.ssl.MbedtlsApi.MBEDTLS_ERR_SSL_UNEXPECTED_RECORD
import org.opencoap.ssl.MbedtlsApi.mbedtls_ssl_close_notify
import org.opencoap.ssl.MbedtlsApi.mbedtls_ssl_context_save
import org.opencoap.ssl.MbedtlsApi.mbedtls_ssl_free
Expand All @@ -27,6 +28,7 @@ import org.opencoap.ssl.MbedtlsApi.mbedtls_ssl_handshake
import org.opencoap.ssl.MbedtlsApi.mbedtls_ssl_read
import org.opencoap.ssl.MbedtlsApi.mbedtls_ssl_write
import org.opencoap.ssl.MbedtlsApi.verify
import org.opencoap.ssl.transport.cloneToMemory
import org.opencoap.ssl.transport.toHex
import org.slf4j.LoggerFactory
import java.io.Closeable
Expand Down Expand Up @@ -164,6 +166,20 @@ class SslSession internal constructor(
plainBuffer.limit(size + plainBuffer.position())
}

fun checkRecord(encBuffer: ByteBuffer): VerificationResult {
val memory = encBuffer.cloneToMemory()
szysas marked this conversation as resolved.
Show resolved Hide resolved
try {
val result = MbedtlsApi.mbedtls_ssl_check_record(sslContext, memory, memory.size().toInt())
return if (result == 0 || result != MBEDTLS_ERR_SSL_UNEXPECTED_RECORD) {
VerificationResult.Valid("Success")
} else {
VerificationResult.Invalid(SslException.from(result).localizedMessage)
}
} finally {
memory.close()
}
}

fun decrypt(encBuffer: ByteBuffer, send: (ByteBuffer) -> Unit): ByteBuffer {
val buf = ByteBuffer.allocate(encBuffer.remaining())
decrypt(encBuffer, buf, send)
Expand Down Expand Up @@ -215,4 +231,9 @@ class SslSession internal constructor(
override fun close() {
mbedtls_ssl_free(sslContext)
}

sealed interface VerificationResult {
data class Valid(val message: String) : VerificationResult
data class Invalid(val message: String) : VerificationResult
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

package org.opencoap.ssl.transport

import com.sun.jna.Memory
import java.nio.ByteBuffer

internal fun ByteArray.toHex(): String {
Expand All @@ -29,6 +30,16 @@ fun ByteBuffer.copy(): ByteBuffer {
return bb
}

fun ByteBuffer.cloneToMemory(): Memory {
JuhaPekkaa marked this conversation as resolved.
Show resolved Hide resolved
this.mark() // saves the original position
val remaining = this.remaining()
val memory = Memory(remaining.toLong())
val intermediateBuffer: ByteBuffer = memory.getByteBuffer(0, remaining.toLong())
intermediateBuffer.put(this)
this.reset()
return memory
}

fun ByteBuffer.isNotEmpty(): Boolean = this.hasRemaining()
fun ByteBuffer.isEmpty(): Boolean = !this.hasRemaining()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,12 @@ class DtlsServer(
private fun closeSession(addr: InetSocketAddress) {
sessions.remove(addr)?.apply {
storeAndClose()
logger.info("[{}] [CID:{}] DTLS session was stored", peerAddress, (this as? DtlsSession)?.sessionContext?.cid?.toHex() ?: "na")
logger.info(
"[{}] [CID:{}] DTLS session was stored",
peerAddress,
(this as? DtlsSession)?.sessionContext?.cid?.toHex()
?: "na"
)
}
}

Expand All @@ -138,18 +143,25 @@ class DtlsServer(
updateSessionAuthenticationContext(adr, ctx.authenticationContext)
}

fun loadSession(sessBuf: SessionWithContext?, adr: InetSocketAddress, cid: ByteArray): Boolean {
fun loadSession(sessBuf: SessionWithContext?, adr: InetSocketAddress, cid: ByteArray, dtlsPacket: ByteBuffer): Boolean {
return try {
if (sessBuf == null) {
logger.warn("[{}] [CID:{}] DTLS session not found", adr, cid.toHex())
reportMessageDrop(adr)
false
} else {
sessions[adr] = DtlsSession(sslConfig.loadSession(cid, sessBuf.sessionBlob, adr), adr, sessBuf.authenticationContext, sessBuf.sessionStartTimestamp)
true
return false
}

val sslSession = sslConfig.loadSession(cid, sessBuf.sessionBlob, adr)
val verificationResult = sslSession.checkRecord(dtlsPacket)
if (verificationResult is SslSession.VerificationResult.Invalid) {
logger.warn("[{}] [CID:{}] Record verification failed: {}", adr, cid.toHex(), verificationResult.message)
reportMessageDrop(adr)
return false
}
} catch (ex: SslException) {
logger.warn("[{}] [CID:{}] Failed to load session: {}", adr, cid.toHex(), ex.message)
sessions[adr] = DtlsSession(sslSession, adr, sessBuf.authenticationContext, sessBuf.sessionStartTimestamp)
true
} catch (ex: Exception) {
logger.error("[{}] [CID:{}] DTLS failed to load session: {}", adr, cid.toHex(), ex.message)
reportMessageDrop(adr)
false
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ class DtlsServerTransport private constructor(
val copyBuf = buf.copy()

sessionStore.read(result.cid).thenApplyAsync(
{ sessBuf -> dtlsServer.loadSession(sessBuf, adr, result.cid) },
{ sessBuf -> dtlsServer.loadSession(sessBuf, adr, result.cid, copyBuf) },
executor
).thenCompose { isLoaded ->
if (isLoaded) {
Expand Down
23 changes: 23 additions & 0 deletions kotlin-mbedtls/src/test/kotlin/org/opencoap/ssl/SslContextTest.kt
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,29 @@ class SslContextTest {
assertEquals("perse", serverSession.decrypt(encryptedDtls2, noSend).decodeToString())
}

@Test
fun `should check record is valid authentic and decrypt`() {
val clientSession = clientConf.loadSession(byteArrayOf(), StoredSessionPair.cliSession, localAddress(2_5684))
val serverSession = serverConf.loadSession(byteArrayOf(), StoredSessionPair.srvSession, localAddress(1_5684))

val encryptedDtls = clientSession.encrypt("auto".toByteBuffer())

val verificationResult = serverSession.checkRecord(encryptedDtls)
assertTrue(verificationResult is SslSession.VerificationResult.Valid)
assertEquals("auto", serverSession.decrypt(encryptedDtls, noSend).decodeToString())
}

@Test
fun `should check record is invalid when record is unexpected and replayed`() {
val clientSession = clientConf.loadSession(byteArrayOf(), StoredSessionPair.cliSession, localAddress(2_5684))
val serverSession = serverConf.loadSession(byteArrayOf(), StoredSessionPair.srvSession, localAddress(1_5684))
val encryptedDtls = clientSession.encrypt("auto".toByteBuffer())

serverSession.decrypt(encryptedDtls, noSend)
val result = serverSession.checkRecord(encryptedDtls.rewind() as ByteBuffer)
assertTrue(result is SslSession.VerificationResult.Invalid)
}

@Test
fun `should exchange data with direct byte buffer`() {
val clientSession = clientConf.loadSession(byteArrayOf(), StoredSessionPair.cliSession, localAddress(2_5684))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@

package org.opencoap.ssl.transport

import com.sun.jna.Memory
import org.junit.jupiter.api.Assertions.assertArrayEquals
import org.junit.jupiter.api.Assertions.assertEquals
import org.junit.jupiter.api.Assertions.assertTrue
import org.junit.jupiter.api.Test
Expand Down Expand Up @@ -70,4 +72,24 @@ class BytesExtensionsTest {
buf.position(2)
assertEquals("dupa", buf.decodeToString())
}

@Test
fun `should clone buffer to memory`() {
val originalData = byteArrayOf(1, 2, 3, 4, 5)
val byteBuffer = ByteBuffer.wrap(originalData)

val originalPosition = byteBuffer.position()
val originalLimit = byteBuffer.limit()
val originalCapacity = byteBuffer.capacity()

val memory: Memory = byteBuffer.cloneToMemory()

val clonedData = ByteArray(originalData.size)
memory.read(0, clonedData, 0, originalData.size)
assertArrayEquals(originalData, clonedData)

assertEquals(originalPosition, byteBuffer.position(), "Buffer position should not change")
assertEquals(originalLimit, byteBuffer.limit(), "Buffer limit should not change")
assertEquals(originalCapacity, byteBuffer.capacity(), "Buffer capacity should not change")
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ class DtlsServerTest {
assertTrue(dtlsServer.handleReceived(localAddress(2_5684), dtlsPacket) is ReceiveResult.CidSessionMissing)

// when
dtlsServer.loadSession(SessionWithContext(StoredSessionPair.srvSession, mapOf(), Instant.ofEpochSecond(123456789)), localAddress(2_5684), "f935adc57425e1b214f8640d56e0c733".decodeHex())
dtlsServer.loadSession(SessionWithContext(StoredSessionPair.srvSession, mapOf(), Instant.ofEpochSecond(123456789)), localAddress(2_5684), "f935adc57425e1b214f8640d56e0c733".decodeHex(), dtlsPacket)

// then
val dtlsPacketIn = (dtlsServer.handleReceived(localAddress(2_5684), dtlsPacket) as ReceiveResult.Decrypted).packet
Expand Down
Loading