Skip to content

Commit

Permalink
Support for DTLS session authentication context updates via outbound …
Browse files Browse the repository at this point in the history
…datagram context
  • Loading branch information
akolosov-n committed Aug 21, 2024
1 parent f8aa937 commit e773e04
Show file tree
Hide file tree
Showing 6 changed files with 81 additions and 17 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -92,11 +92,7 @@ class DtlsChannelHandler @JvmOverloads constructor(
when (msg) {
is DatagramPacketWithContext -> {
write(msg, promise, ctx)
if (msg.sessionContext.sessionExpirationHint) {
promise.toCompletableFuture().thenAccept {
dtlsServer.closeSession(msg.recipient())
}
}
handleDtlsContext(msg, promise)
}
is DatagramPacket -> write(msg, promise, ctx)
is SessionAuthenticationContext -> {
Expand All @@ -114,6 +110,22 @@ class DtlsChannelHandler @JvmOverloads constructor(
}
}

private fun handleDtlsContext(msg: DatagramPacketWithContext, promise: ChannelPromise) {
val sessCtx = msg.sessionContext
if (sessCtx.sessionSuspensionHint) {
promise.toCompletableFuture().thenAccept {
dtlsServer.closeSession(msg.recipient())
}
}
if (sessCtx.authenticationContext.isNotEmpty()) {
promise.toCompletableFuture().thenAccept {
sessCtx.authenticationContext.forEach { (key, value) ->
dtlsServer.putSessionAuthenticationContext(msg.recipient(), key, value)
}
}
}
}

private fun write(msg: DatagramPacket, promise: ChannelPromise, ctx: ChannelHandlerContext) {
msg.useAndRelease {
val plainContent = msg.content().nioBuffer()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@ class EchoHandler : ChannelInboundHandlerAdapter() {
val authContext = (sessionContext.authenticationContext["AUTH"] ?: "")
val dgramContent = dgram.content().toByteArray()
val goToSleep = dgramContent.toString(Charset.defaultCharset()).endsWith(":sleep")
val newAuthContext = dgramContent.toString(Charset.defaultCharset())
.takeIf { it.startsWith("auth:") }
?.substringAfter(":")

val reply = ctx.alloc().buffer(dgramContent.size + 20)
reply.writeBytes(echoPrefix)
Expand All @@ -43,7 +46,10 @@ class EchoHandler : ChannelInboundHandlerAdapter() {
reply,
dgram.sender(),
null,
sessionContext.copy(sessionExpirationHint = goToSleep)
sessionContext.copy(
authenticationContext = newAuthContext?.let { mapOf("AUTH" to it) } ?: emptyMap(),
sessionSuspensionHint = goToSleep
)
)
)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,25 @@ class NettyTest {
client.close()
}

@Test
fun `should forward authentication context passed inside outbound datagram`() {
// connect and handshake
val client = NettyTransportAdapter.connect(clientConf, srvAddress).mapToString()

assertTrue(client.send("hi").await())
assertEquals("ECHO:hi", client.receive(5.seconds).await())

// when
assertTrue(client.send("auth:007:").await())
assertEquals("ECHO:auth:007:", client.receive(5.seconds).await())

// then
assertTrue(client.send("hi").await())
assertEquals("ECHO:007:hi", client.receive(5.seconds).await())

client.close()
}

@Test
fun `should fail to forward authentication context for non existing client`() {
assertThatThrownBy {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,15 +97,21 @@ class DtlsServerTransport private constructor(

when {
encPacket == null -> completedFuture(false)
packet.sessionContext.sessionExpirationHint -> {
transport.send(encPacket).thenApply { isSuccess ->
if (isSuccess) {
dtlsServer.closeSession(packet.peerAddress)
}
isSuccess
else -> transport.send(encPacket)
}.thenApply { isSuccess ->
if (!isSuccess) return@thenApply false

val sessCtx = packet.sessionContext
if (sessCtx.sessionSuspensionHint) {
dtlsServer.closeSession(packet.peerAddress)
}
if (sessCtx.authenticationContext.isNotEmpty()) {
sessCtx.authenticationContext.forEach { (key, value) ->
dtlsServer.putSessionAuthenticationContext(packet.peerAddress, key, value)
}
}
else -> transport.send(encPacket)

true
}
}.thenCompose(Function.identity())

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ data class DtlsSessionContext @JvmOverloads constructor(
val peerCertificateSubject: String? = null,
val cid: ByteArray? = null,
val sessionStartTimestamp: Instant? = null,
val sessionExpirationHint: Boolean = false
val sessionSuspensionHint: Boolean = false
) {
companion object {
@JvmField
Expand All @@ -45,7 +45,7 @@ data class DtlsSessionContext @JvmOverloads constructor(
if (!cid.contentEquals(other.cid)) return false
} else if (other.cid != null) return false
if (sessionStartTimestamp != other.sessionStartTimestamp) return false
if (sessionExpirationHint != other.sessionExpirationHint) return false
if (sessionSuspensionHint != other.sessionSuspensionHint) return false

return true
}
Expand All @@ -55,7 +55,7 @@ data class DtlsSessionContext @JvmOverloads constructor(
result = 31 * result + (peerCertificateSubject?.hashCode() ?: 0)
result = 31 * result + (cid?.contentHashCode() ?: 0)
result = 31 * result + (sessionStartTimestamp?.hashCode() ?: 0)
result = 31 * result + (sessionExpirationHint.hashCode())
result = 31 * result + (sessionSuspensionHint.hashCode())
return result
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,14 @@ class DtlsServerTransportTest {
} else if (msg.startsWith("Authenticate:")) {
server.putSessionAuthenticationContext(packet.peerAddress, "auth", msg.substring(12))
server.send(Packet("OK".toByteBuffer(), packet.peerAddress))
} else if (msg.startsWith("AuthenticateWithContext:")) {
server.send(
Packet(
"OK".toByteBuffer(),
packet.peerAddress,
DtlsSessionContext(authenticationContext = mapOf("auth" to msg.substring(23)))
)
)
} else {
val ctx = (packet.sessionContext.authenticationContext["auth"] ?: "")
server.send(packet.map { "$msg:resp$ctx".toByteBuffer() })
Expand Down Expand Up @@ -479,6 +487,19 @@ class DtlsServerTransportTest {
client.close()
}

@Test
fun `should set and use session context passed inside outbound datagram`() {
server = DtlsServerTransport.create(conf, expireAfter = 100.millis, sessionStore = sessionStore, lifecycleCallbacks = sslLifecycleCallbacks).listen(echoHandler)
// client connected
val client = DtlsTransmitter.connect(server, clientConfig).await()
client.send("AuthenticateWithContext:dev-007")
assertEquals("OK", client.receiveString())
client.send("hi")
assertEquals("hi:resp:dev-007", client.receiveString())

client.close()
}

@Test
fun `server should store session if hinted to do so`() {
// given
Expand All @@ -491,7 +512,7 @@ class DtlsServerTransportTest {
assertEquals("dupa", client.receive(1.seconds).await())

client.send("sleep")
server.send(Packet("sleep".toByteBuffer(), serverReceived.await().peerAddress, sessionContext = DtlsSessionContext(sessionExpirationHint = true)))
server.send(Packet("sleep".toByteBuffer(), serverReceived.await().peerAddress, sessionContext = DtlsSessionContext(sessionSuspensionHint = true)))
assertEquals("sleep", client.receive(1.seconds).await())

await.atMost(5.seconds).untilAsserted {
Expand Down

0 comments on commit e773e04

Please sign in to comment.