Skip to content

Commit

Permalink
Make cid accessible from DtslServer (#26)
Browse files Browse the repository at this point in the history
* make cid accessible from DtslServer

* cleaning. Added get cid to DtslServerTransport as well

* executor for get cid
  • Loading branch information
topiasjokiniemi-nordic authored Jul 19, 2023
1 parent bc7b435 commit 3f0c028
Show file tree
Hide file tree
Showing 5 changed files with 63 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,10 @@ class DtlsServer(
private val sessions = mutableMapOf<InetSocketAddress, DtlsState>()
private val cidSize = sslConfig.cidSupplier.next().size
val numberOfSessions get() = sessions.size
fun getSessionCid(inet: InetSocketAddress): ByteArray? {
val dtlsState = sessions[inet] as? DtlsSession
return dtlsState?.sessionContext?.cid
}

fun handleReceived(adr: InetSocketAddress, buf: ByteBuffer): ReceiveResult {
val cid by lazy { SslContext.peekCID(cidSize, buf) }
Expand Down Expand Up @@ -215,7 +219,8 @@ class DtlsServer(
val sessionContext: DtlsSessionContext
get() = DtlsSessionContext(
peerCertificateSubject = ctx.peerCertificateSubject,
authenticationContext = authenticationContext
authenticationContext = authenticationContext,
cid = if (ctx.ownCid?.isEmpty() != true) ctx.ownCid else ctx.peerCid
)

init {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -114,4 +114,6 @@ class DtlsServerTransport private constructor(
executor.supply {
dtlsServer.putSessionAuthenticationContext(adr, key, value)
}

fun getSessionCid(adr: InetSocketAddress) = executor.supply { dtlsServer.getSessionCid(adr) }
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,34 @@ typealias AuthenticationContext = Map<String, String>

data class DtlsSessionContext @JvmOverloads constructor(
val authenticationContext: AuthenticationContext = emptyMap(),
val peerCertificateSubject: String? = null
val peerCertificateSubject: String? = null,
val cid: ByteArray? = null
) {
companion object {
@JvmField
val EMPTY = DtlsSessionContext()
}

override fun equals(other: Any?): Boolean {
if (this === other) return true
if (javaClass != other?.javaClass) return false

other as DtlsSessionContext

if (authenticationContext != other.authenticationContext) return false
if (peerCertificateSubject != other.peerCertificateSubject) return false
if (cid != null) {
if (other.cid == null) return false
if (!cid.contentEquals(other.cid)) return false
} else if (other.cid != null) return false

return true
}

override fun hashCode(): Int {
var result = authenticationContext.hashCode()
result = 31 * result + (peerCertificateSubject?.hashCode() ?: 0)
result = 31 * result + (cid?.contentHashCode() ?: 0)
return result
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,6 @@ class DtlsServerTest {
// given
val clientSession = clientHandshake()
assertEquals(1, dtlsServer.numberOfSessions)

// and nothing is sent to server

// when, inactivity
Expand All @@ -227,6 +226,24 @@ class DtlsServerTest {
clientSession.close()
}

@Test
fun `should find session cid`() {
// given
val clientSession = clientHandshake()
val cid = dtlsServer.getSessionCid(localAddress(2_5684))
assert(cid!!.isNotEmpty())
clientSession.close()
}

@Test
fun `shouldn't find session cid`() {
// given
val clientSession = clientHandshake()
val cid = dtlsServer.getSessionCid(localAddress(1234))
assertEquals(null, cid)
clientSession.close()
}

private fun clientHandshake(): SslSession {
val send: (ByteBuffer) -> Unit = { dtlsServer.handleReceived(localAddress(2_5684), it) }
val cliHandshake = clientConf.newContext(localAddress(5684))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import org.awaitility.kotlin.await
import org.junit.jupiter.api.AfterEach
import org.junit.jupiter.api.Assertions.assertEquals
import org.junit.jupiter.api.Assertions.assertFalse
import org.junit.jupiter.api.Assertions.assertNull
import org.junit.jupiter.api.Assertions.assertTrue
import org.junit.jupiter.api.Test
import org.opencoap.ssl.CertificateAuth
Expand Down Expand Up @@ -438,6 +439,15 @@ class DtlsServerTransportTest {
assertTrue(server.executor() is ScheduledThreadPoolExecutor)
}

@Test
fun `should not return cid`() {
server = DtlsServerTransport.create(conf, lifecycleCallbacks = sslLifecycleCallbacks).listen(echoHandler)
val client = DtlsTransmitter.connect(server, clientConfig).await()
client.send("hello!")
val cid = server.getSessionCid(localAddress(1234)).await()
assertNull(cid)
}

@Test
fun `should set and use session context`() {
// given
Expand All @@ -456,8 +466,8 @@ class DtlsServerTransportTest {
client.send("msg2")

// then
assertEquals(DtlsSessionContext(mapOf("auth" to "id:dev-007")), server.receive(1.seconds).await().sessionContext)
assertEquals(DtlsSessionContext(mapOf("auth" to "id:dev-007")), server.receive(1.seconds).await().sessionContext)
assertEquals(mapOf("auth" to "id:dev-007"), server.receive(1.seconds).await().sessionContext.authenticationContext)
assertEquals(mapOf("auth" to "id:dev-007"), server.receive(1.seconds).await().sessionContext.authenticationContext)

client.close()
}
Expand Down

0 comments on commit 3f0c028

Please sign in to comment.