Skip to content

Commit

Permalink
Run DTLS session lifecycle probes asynchronously (#25)
Browse files Browse the repository at this point in the history
* Call metrics callbacks asynchronously

* Fixing tests

* Set finish timestamp on handshake failure
  • Loading branch information
akolosov-n authored Aug 1, 2023
1 parent 3f0c028 commit 5fc13c8
Show file tree
Hide file tree
Showing 5 changed files with 34 additions and 32 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -54,10 +54,10 @@ class DtlsServerMetricsCallbacks(
handshakesInitiated.increment()
}

override fun handshakeFinished(adr: InetSocketAddress, hanshakeStartTimestamp: Long, reason: DtlsSessionLifecycleCallbacks.Reason, throwable: Throwable?) = when {
override fun handshakeFinished(adr: InetSocketAddress, hanshakeStartTimestamp: Long, hanshakeFinishTimestamp: Long, reason: DtlsSessionLifecycleCallbacks.Reason, throwable: Throwable?) = when {
throwable is HelloVerifyRequired -> {} // Skip HelloVerifyRequired handshake states
reason == DtlsSessionLifecycleCallbacks.Reason.SUCCEEDED ->
handshakesSucceeded.record(System.currentTimeMillis() - hanshakeStartTimestamp, TimeUnit.MILLISECONDS)
handshakesSucceeded.record(hanshakeFinishTimestamp - hanshakeStartTimestamp, TimeUnit.MILLISECONDS)
reason == DtlsSessionLifecycleCallbacks.Reason.FAILED ->
handshakesFailedBuilder.reasonTag(throwable).register(registry).increment()
reason == DtlsSessionLifecycleCallbacks.Reason.EXPIRED ->
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ class SslHandshakeContext internal constructor(
) : SslContext {
private val logger = LoggerFactory.getLogger(javaClass)
val startTimestamp: Long = System.currentTimeMillis()
var finishTimestamp: Long = 0
private var stepTimeout: Duration = Duration.ZERO

fun step(send: (ByteBuffer) -> Unit): SslContext = step0(null, send)
Expand All @@ -86,10 +87,12 @@ class SslHandshakeContext internal constructor(
MbedtlsApi.MBEDTLS_ERR_SSL_WANT_READ -> return this
MbedtlsApi.MBEDTLS_ERR_SSL_HELLO_VERIFY_REQUIRED -> throw HelloVerifyRequired
0 -> SslSession(conf, sslContext, cid).also {
logger.info("[{}] DTLS connected in {}ms {}", peerAdr, System.currentTimeMillis() - startTimestamp, it)
finishTimestamp = System.currentTimeMillis()
logger.info("[{}] DTLS connected in {}ms {}", peerAdr, finishTimestamp - startTimestamp, it)
}

else -> throw SslException.from(ret).also {
finishTimestamp = System.currentTimeMillis()
logger.debug("[{}] DTLS failed handshake: {}", peerAdr, it.message)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -172,8 +172,8 @@ class DtlsServer(
}

is SslSession -> {
reportHandshakeFinished(DtlsSessionLifecycleCallbacks.Reason.SUCCEEDED)
sessions[peerAddress] = DtlsSession(newCtx, peerAddress)
reportHandshakeFinished(DtlsSessionLifecycleCallbacks.Reason.SUCCEEDED)
}
}
} catch (ex: Exception) {
Expand All @@ -185,28 +185,28 @@ class DtlsServer(
else ->
logger.error(ex.toString(), ex)
}
reportHandshakeFinished(DtlsSessionLifecycleCallbacks.Reason.FAILED, ex)
closeAndRemove()
reportHandshakeFinished(DtlsSessionLifecycleCallbacks.Reason.FAILED, ex)
}
return ReceiveResult.Handled
}

fun timeout() {
reportHandshakeFinished(DtlsSessionLifecycleCallbacks.Reason.EXPIRED)
closeAndRemove()
logger.warn("[{}] DTLS handshake expired", peerAddress)
reportHandshakeFinished(DtlsSessionLifecycleCallbacks.Reason.EXPIRED)
}

override fun storeAndClose0() = close()

override fun close() = ctx.close()

private fun reportHandshakeStarted() {
lifecycleCallbacks.handshakeStarted(peerAddress)
executor.supply { lifecycleCallbacks.handshakeStarted(peerAddress) }
}

private fun reportHandshakeFinished(reason: DtlsSessionLifecycleCallbacks.Reason, err: Throwable? = null) {
lifecycleCallbacks.handshakeFinished(peerAddress, ctx.startTimestamp, reason, err)
executor.supply { lifecycleCallbacks.handshakeFinished(peerAddress, ctx.startTimestamp, ctx.finishTimestamp, reason, err) }
}
}

Expand Down Expand Up @@ -272,26 +272,26 @@ class DtlsServer(
try {
return ctx.encrypt(plainPacket)
} catch (ex: SslException) {
closeAndRemove()
logger.warn("[{}] DTLS failed: {}", peerAddress, ex.message)
reportSessionFinished(DtlsSessionLifecycleCallbacks.Reason.FAILED, ex)
closeAndRemove()
throw ex
}
}

fun timeout() {
lifecycleCallbacks.sessionFinished(peerAddress, DtlsSessionLifecycleCallbacks.Reason.EXPIRED)
sessions.remove(peerAddress, this)
logger.info("[{}] DTLS connection expired", peerAddress)
storeAndClose()
logger.info("[{}] DTLS connection expired", peerAddress)
lifecycleCallbacks.sessionFinished(peerAddress, DtlsSessionLifecycleCallbacks.Reason.EXPIRED)
}

private fun reportSessionStarted() {
lifecycleCallbacks.sessionStarted(peerAddress, ctx.cipherSuite, ctx.reloaded)
executor.supply { lifecycleCallbacks.sessionStarted(peerAddress, ctx.cipherSuite, ctx.reloaded) }
}

private fun reportSessionFinished(reason: DtlsSessionLifecycleCallbacks.Reason, err: Throwable? = null) {
lifecycleCallbacks.sessionFinished(peerAddress, reason, err)
executor.supply { lifecycleCallbacks.sessionFinished(peerAddress, reason, err) }
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ interface DtlsSessionLifecycleCallbacks {
SUCCEEDED, FAILED, CLOSED, EXPIRED
}
fun handshakeStarted(adr: InetSocketAddress) = Unit
fun handshakeFinished(adr: InetSocketAddress, hanshakeStartTimestamp: Long, reason: Reason, throwable: Throwable? = null) = Unit
fun handshakeFinished(adr: InetSocketAddress, hanshakeStartTimestamp: Long, hanshakeFinishTimestamp: Long, reason: Reason, throwable: Throwable? = null) = Unit
fun sessionStarted(adr: InetSocketAddress, cipherSuite: String, reloaded: Boolean) = Unit
fun sessionFinished(adr: InetSocketAddress, reason: Reason, throwable: Throwable? = null) = Unit
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ import io.mockk.clearMocks
import io.mockk.confirmVerified
import io.mockk.mockk
import io.mockk.verify
import io.mockk.verifyOrder
import org.awaitility.kotlin.await
import org.junit.jupiter.api.AfterEach
import org.junit.jupiter.api.Assertions.assertEquals
Expand Down Expand Up @@ -113,11 +112,11 @@ class DtlsServerTransportTest {
val clientAddress = client.localAddress()
client.close()

verifyOrder {
verify {
sslLifecycleCallbacks.handshakeStarted(clientAddress)
sslLifecycleCallbacks.handshakeFinished(clientAddress, any(), DtlsSessionLifecycleCallbacks.Reason.FAILED, ofType(HelloVerifyRequired::class))
sslLifecycleCallbacks.handshakeFinished(clientAddress, any(), any(), DtlsSessionLifecycleCallbacks.Reason.FAILED, ofType(HelloVerifyRequired::class))
sslLifecycleCallbacks.handshakeStarted(clientAddress)
sslLifecycleCallbacks.handshakeFinished(clientAddress, any(), DtlsSessionLifecycleCallbacks.Reason.SUCCEEDED)
sslLifecycleCallbacks.handshakeFinished(clientAddress, any(), any(), DtlsSessionLifecycleCallbacks.Reason.SUCCEEDED)
sslLifecycleCallbacks.sessionStarted(clientAddress, any(), false)
}

Expand Down Expand Up @@ -165,11 +164,11 @@ class DtlsServerTransportTest {
assertEquals(0, server.numberOfSessions())
}
assertFalse(srvReceive.isDone)
verifyOrder {
verify {
sslLifecycleCallbacks.handshakeStarted(any())
sslLifecycleCallbacks.handshakeFinished(any(), any(), DtlsSessionLifecycleCallbacks.Reason.FAILED, ofType(HelloVerifyRequired::class))
sslLifecycleCallbacks.handshakeFinished(any(), any(), any(), DtlsSessionLifecycleCallbacks.Reason.FAILED, ofType(HelloVerifyRequired::class))
sslLifecycleCallbacks.handshakeStarted(any())
sslLifecycleCallbacks.handshakeFinished(any(), any(), DtlsSessionLifecycleCallbacks.Reason.FAILED, ofType(SslException::class))
sslLifecycleCallbacks.handshakeFinished(any(), any(), any(), DtlsSessionLifecycleCallbacks.Reason.FAILED, ofType(SslException::class))
}

verify(exactly = 0) {
Expand All @@ -193,9 +192,9 @@ class DtlsServerTransportTest {

client.close()

verifyOrder {
verify {
sslLifecycleCallbacks.handshakeStarted(any())
sslLifecycleCallbacks.handshakeFinished(any(), any(), DtlsSessionLifecycleCallbacks.Reason.SUCCEEDED)
sslLifecycleCallbacks.handshakeFinished(any(), any(), any(), DtlsSessionLifecycleCallbacks.Reason.SUCCEEDED)
sslLifecycleCallbacks.sessionStarted(any(), any(), any())
}
}
Expand Down Expand Up @@ -237,11 +236,11 @@ class DtlsServerTransportTest {
cliChannel.close()

verify(atMost = 100) {
sslLifecycleCallbacks.handshakeFinished(any(), any(), DtlsSessionLifecycleCallbacks.Reason.FAILED, ofType(SslException::class))
sslLifecycleCallbacks.handshakeFinished(any(), any(), any(), DtlsSessionLifecycleCallbacks.Reason.FAILED, ofType(SslException::class))
}

verify(exactly = 0) {
sslLifecycleCallbacks.handshakeFinished(any(), any(), DtlsSessionLifecycleCallbacks.Reason.FAILED, ofType(HelloVerifyRequired::class))
sslLifecycleCallbacks.handshakeFinished(any(), any(), any(), DtlsSessionLifecycleCallbacks.Reason.FAILED, ofType(HelloVerifyRequired::class))
}
}

Expand Down Expand Up @@ -286,7 +285,7 @@ class DtlsServerTransportTest {
assertEquals(0, server.numberOfSessions())
}

verifyOrder {
verify {
sslLifecycleCallbacks.sessionStarted(any(), any(), any())
sslLifecycleCallbacks.sessionFinished(any(), DtlsSessionLifecycleCallbacks.Reason.CLOSED)
}
Expand All @@ -308,12 +307,12 @@ class DtlsServerTransportTest {

// No handshake failures other than HelloVerifyRequired
verify(exactly = 0) {
sslLifecycleCallbacks.handshakeFinished(any(), any(), DtlsSessionLifecycleCallbacks.Reason.FAILED, not(ofType(HelloVerifyRequired::class)))
sslLifecycleCallbacks.handshakeFinished(any(), any(), any(), DtlsSessionLifecycleCallbacks.Reason.FAILED, not(ofType(HelloVerifyRequired::class)))
}

// One successful handshake must happen
verify(exactly = 1) {
sslLifecycleCallbacks.handshakeFinished(any(), any(), DtlsSessionLifecycleCallbacks.Reason.SUCCEEDED)
sslLifecycleCallbacks.handshakeFinished(any(), any(), any(), DtlsSessionLifecycleCallbacks.Reason.SUCCEEDED)
}
}

Expand All @@ -335,7 +334,7 @@ class DtlsServerTransportTest {
cli.close()

verify(exactly = 1) {
sslLifecycleCallbacks.handshakeFinished(any(), any(), DtlsSessionLifecycleCallbacks.Reason.FAILED, and(ofType(SslException::class), not(ofType(HelloVerifyRequired::class))))
sslLifecycleCallbacks.handshakeFinished(any(), any(), any(), DtlsSessionLifecycleCallbacks.Reason.FAILED, and(ofType(SslException::class), not(ofType(HelloVerifyRequired::class))))
}
}

Expand Down Expand Up @@ -387,11 +386,11 @@ class DtlsServerTransportTest {
}
client.close()

verifyOrder() {
verify() {
sslLifecycleCallbacks.handshakeStarted(any())
sslLifecycleCallbacks.handshakeFinished(any(), any(), DtlsSessionLifecycleCallbacks.Reason.FAILED, ofType(HelloVerifyRequired::class))
sslLifecycleCallbacks.handshakeFinished(any(), any(), any(), DtlsSessionLifecycleCallbacks.Reason.FAILED, ofType(HelloVerifyRequired::class))
sslLifecycleCallbacks.handshakeStarted(any())
sslLifecycleCallbacks.handshakeFinished(any(), any(), DtlsSessionLifecycleCallbacks.Reason.SUCCEEDED)
sslLifecycleCallbacks.handshakeFinished(any(), any(), any(), DtlsSessionLifecycleCallbacks.Reason.SUCCEEDED)
sslLifecycleCallbacks.sessionStarted(any(), any(), false)
sslLifecycleCallbacks.sessionFinished(any(), DtlsSessionLifecycleCallbacks.Reason.EXPIRED)
sslLifecycleCallbacks.sessionStarted(any(), any(), true)
Expand Down

0 comments on commit 5fc13c8

Please sign in to comment.